neural_tangents.stax.LayerNorm
- neural_tangents.stax.LayerNorm(axis=-1, eps=1e-12, batch_axis=0, channel_axis=-1)
Layer normalisation.
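As a reference point, the standard layer-normalisation formula is written out below. It assumes the textbook definition with no learnable gain or bias. Here $A$ denotes the set of positions selected by `axis`, and $\epsilon$ corresponds to `eps`. Neither symbol is taken from the library source.

```latex
\mathrm{LayerNorm}(x) = \frac{x - \mu_A}{\sqrt{\sigma_A^2 + \epsilon}},
\qquad
\mu_A = \frac{1}{|A|}\sum_{a \in A} x_a,
\qquad
\sigma_A^2 = \frac{1}{|A|}\sum_{a \in A} \left(x_a - \mu_A\right)^2
```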
- Parameters
  - axis (Union[int, Sequence[int]]) – axis or axes over which the mean and variance used for normalization are computed. Defaults to -1.
  - eps (float) – small positive constant added to the variance estimates to prevent division by zero. Defaults to 1e-12.
  - batch_axis (int) – batch dimension. Defaults to 0, the leading axis.
  - channel_axis (int) – channel / feature dimension. Defaults to -1, the trailing axis. In kernel_fn, the channel dimension is treated as infinitely wide.
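The following sketch shows how `axis` and `eps` interact on a finite input. It uses NumPy rather than JAX and assumes the standard formula with the biased (population) variance and no learnable scale or shift. `layer_norm` is a hypothetical helper written for illustration and is not part of neural_tangents.

```python
import numpy as np


def layer_norm(x, axis=-1, eps=1e-12):
    """Hypothetical helper: normalize `x` to zero mean and unit variance over `axis`.

    `axis` may be an int or a sequence of ints, mirroring the parameter above.
    No learnable scale/shift is applied (an assumption of this sketch).
    """
    # np.mean/np.var accept an int or a tuple of ints, so convert sequences.
    axis = axis if isinstance(axis, int) else tuple(axis)
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)  # biased (population) variance
    # eps is added to the variance estimate before the square root.
    return (x - mean) / np.sqrt(var + eps)


# Toy input laid out as (batch, spatial, channel); batch_axis=0, channel_axis=-1.
x = np.random.default_rng(0).normal(size=(4, 8, 16))

y_channel = layer_norm(x)                  # normalize over the trailing (channel) axis
y_joint = layer_norm(x, axis=(1, 2))       # normalize jointly over spatial and channel axes
print(np.allclose(y_channel.mean(axis=-1), 0.0), np.allclose(y_channel.var(axis=-1), 1.0))
```

Because `eps` is tiny relative to the per-example variance here, the normalized slices have variance that is numerically indistinguishable from 1. A larger `eps` would shrink the output variance toward `var / (var + eps)`.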
- Return type
  Tuple[InitFn, ApplyFn, LayerKernelFn]
- Returns
  (init_fn, apply_fn, kernel_fn).