Layer construction function for flattening all non-batch dimensions.
Based on jax.example_libraries.stax.Flatten, but allows the batch axes to be specified.
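With the default batch_axis=0 and batch_axis_out=0, the flattening should reduce to the reshape that jax.example_libraries.stax.Flatten performs: keep the leading (batch) axis and merge all remaining axes into one. Below is a minimal NumPy sketch of that default case only. The helper name `flatten_default` is hypothetical and belongs to neither library.

```python
import numpy as np

def flatten_default(x):
    """Hypothetical helper: keep axis 0 (the batch) and merge every other axis."""
    # -1 lets NumPy infer the product of all non-batch dimensions.
    return np.reshape(x, (x.shape[0], -1))

images = np.zeros((8, 32, 32, 3))  # e.g. a batch of 8 NHWC images
flat = flatten_default(images)
print(flat.shape)  # (8, 3072)
```

The non-batch axes are merged in row-major (C) order, so for each example the last axis varies fastest in the flattened vector.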
Parameters:
batch_axis (int) – Specifies the input batch dimension. Defaults to 0, the leading axis.
batch_axis_out (int) – Specifies the output batch dimension. Defaults to 0, the leading axis.
Return type: Tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]
Returns: (init_fn, apply_fn, kernel_fn).
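The sketch below illustrates how batch_axis and batch_axis_out might interact, written in the (init_fn, apply_fn) convention of jax.example_libraries.stax, where init_fn(rng, input_shape) returns (output_shape, params). It is an illustration under stated assumptions, not the library's implementation. The name `flatten_layer` is hypothetical; the output is assumed to be 2-D, so batch_axis_out may only be 0 or 1 (or the negative forms -2 and -1); the non-batch axes are assumed to be merged in their original order; and kernel_fn and mask_fn are not modeled.

```python
import numpy as np

def flatten_layer(batch_axis=0, batch_axis_out=0):
    """Hypothetical stax-style Flatten with configurable batch axes.

    Returns (init_fn, apply_fn); kernel_fn and mask_fn are not modeled.
    """
    # Assumption: the output is 2-D, so its batch axis is either 0 or 1.
    if batch_axis_out in (0, -2):
        batch_axis_out = 0
    elif batch_axis_out in (1, -1):
        batch_axis_out = 1
    else:
        raise ValueError("batch_axis_out must be one of 0, 1, -2, -1")

    def output_shape(input_shape):
        axis = batch_axis % len(input_shape)  # normalize negative axes
        batch_size = input_shape[axis]
        # Product of all non-batch dimensions (1 if there are none).
        features = int(np.prod([d for i, d in enumerate(input_shape) if i != axis]))
        if batch_axis_out == 0:
            return (batch_size, features)
        return (features, batch_size)

    def init_fn(rng, input_shape):
        # Flattening has no trainable parameters, so params is an empty tuple.
        return output_shape(input_shape), ()

    def apply_fn(params, inputs, **kwargs):
        # Move the batch axis to the front (or back), then merge the remaining
        # axes in their original order with a row-major reshape.
        dest = 0 if batch_axis_out == 0 else -1
        moved = np.moveaxis(inputs, batch_axis, dest)
        return moved.reshape(output_shape(inputs.shape))

    return init_fn, apply_fn

# Usage: a batch of 4 examples stored along the last input axis,
# returned with the batch along the columns of the output.
init_fn, apply_fn = flatten_layer(batch_axis=2, batch_axis_out=1)
x = np.arange(24).reshape(2, 3, 4)
out_shape, params = init_fn(None, x.shape)
y = apply_fn(params, x)
print(out_shape, y.shape)  # (6, 4) (6, 4)
```

Moving the batch axis before reshaping keeps each example's features contiguous, so column 0 of `y` holds exactly the entries of `x[:, :, 0]`.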