neural_tangents.stax.Flatten

neural_tangents.stax.Flatten(batch_axis=0, batch_axis_out=0)

Layer construction function for flattening all non-batch dimensions.

Based on jax.example_libraries.stax.Flatten, but allows the input and output batch axes to be specified.

Parameters
  • batch_axis (int) – Specifies the input batch dimension. Defaults to 0, the leading axis.

  • batch_axis_out (int) – Specifies the output batch dimension. Defaults to 0, the leading axis.
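The following minimal NumPy sketch shows how the two parameters interact. It is an assumption-based illustration, not the library's code. `flatten_sketch` is a hypothetical name, and the sketch assumes the output is 2-D, with the batch on axis 0 when `batch_axis_out=0` and on axis 1 when `batch_axis_out=1`. `jax.numpy.moveaxis` and `reshape` behave the same way as the NumPy calls used here.

```python
import numpy as np


def flatten_sketch(x, batch_axis=0, batch_axis_out=0):
    """Hypothetical illustration of Flatten's shape semantics (not the library code).

    Assumes a 2-D output: (batch, features) if batch_axis_out is 0 (or -2),
    (features, batch) if batch_axis_out is 1 (or -1).
    """
    if batch_axis_out not in (0, 1, -1, -2):
        raise ValueError(f"batch_axis_out must be 0 or 1, got {batch_axis_out}.")
    # Move the input batch dimension to the front.
    x = np.moveaxis(x, batch_axis, 0)
    # Collapse every remaining (non-batch) dimension into one feature axis.
    out = x.reshape(x.shape[0], -1)
    # Put the batch dimension on the requested output axis.
    return out if batch_axis_out in (0, -2) else out.T
```

For example, under these assumptions an NHWC image batch of shape (8, 32, 32, 3) maps to (8, 3072) with the defaults and to (3072, 8) with `batch_axis_out=1`.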

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns

(init_fn, apply_fn, kernel_fn).
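The `init_fn` and `apply_fn` in the returned tuple follow the jax.example_libraries.stax convention. `init_fn(rng, input_shape)` returns `(output_shape, params)`, and `apply_fn(params, inputs, **kwargs)` returns the outputs. The sketch below mirrors that contract for the default axes only. `flatten_pair_sketch` is a hypothetical name, `kernel_fn` is deliberately omitted, and the code is not the neural_tangents implementation.

```python
import functools
import operator

import numpy as np


def flatten_pair_sketch():
    """Hypothetical (init_fn, apply_fn) pair for batch_axis=0, batch_axis_out=0."""

    def init_fn(rng, input_shape):
        # Shape inference only: no data is needed, and flattening has no parameters.
        features = functools.reduce(operator.mul, input_shape[1:], 1)
        return (input_shape[0], features), ()

    def apply_fn(params, inputs, **kwargs):
        # Keep the leading batch axis and collapse everything else.
        return inputs.reshape(inputs.shape[0], -1)

    return init_fn, apply_fn
```

Because flattening has no trainable weights, `params` is an empty tuple. `init_fn` can therefore infer output shapes for downstream layers without touching any data.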