neural_tangents.stax.FanInConcat

neural_tangents.stax.FanInConcat(axis=-1)

Layer construction function for a fan-in concatenation layer.

Based on jax.example_libraries.stax.FanInConcat.

Parameters

axis (int) – Specifies the axis along which input tensors should be concatenated.

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns

(init_fn, apply_fn, kernel_fn).
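
Example

The concatenation behaviour can be illustrated with the jax.example_libraries.stax.FanInConcat layer that this function is based on. This is a minimal sketch, not the neural_tangents implementation: the jax layer returns only (init_fn, apply_fn), without the kernel_fn listed above, and the arrays a and b are hypothetical branch outputs made up for the illustration.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# The layer has no trainable parameters; it only joins its inputs.
init_fn, apply_fn = stax.FanInConcat(axis=-1)

# Two hypothetical branch outputs that agree on every axis except the last.
a = jnp.ones((4, 3))
b = jnp.zeros((4, 5))

# init_fn takes a list of input shapes and returns (output_shape, params).
out_shape, params = init_fn(jax.random.PRNGKey(0), [a.shape, b.shape])

# apply_fn takes a list of inputs and concatenates them along `axis`.
out = apply_fn(params, [a, b])

print(out_shape, out.shape)
```

The inputs must have matching sizes on every axis other than the one being concatenated; the output size along axis is the sum of the input sizes along it.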