neural_tangents.stax.LayerNorm

neural_tangents.stax.LayerNorm(axis=-1, eps=1e-12, batch_axis=0, channel_axis=-1)

Layer normalisation.

Parameters
  • axis (Union[int, Sequence[int]]) – dimensions over which to normalize.

  • eps (float) – small positive constant added to the variance estimates to prevent division by zero.

  • batch_axis (int) – batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
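The parameters above can be illustrated with a minimal NumPy sketch of parameter-free layer normalization. It assumes that `apply_fn` subtracts the mean over `axis` and divides by `sqrt(var + eps)`, with no learned scale or shift. The helper `layer_norm` is hypothetical and is not part of neural_tangents. It only mirrors the finite-width forward computation, not the infinite-width `kernel_fn`.

```python
import numpy as np

def layer_norm(x, axis=-1, eps=1e-12):
    """Hypothetical sketch: standardize `x` over `axis` using mean and sqrt(var + eps).

    Assumes no learned scale/shift parameters. `axis` may be an int or a
    sequence of ints, matching the documented Union[int, Sequence[int]] type.
    """
    if not isinstance(axis, int):
        axis = tuple(axis)  # NumPy reductions expect a tuple for multiple axes
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    # eps keeps the denominator positive even when the variance is exactly zero.
    return (x - mean) / np.sqrt(var + eps)

# A batch of 2 inputs (batch dimension 0) with 4 channels each (channel dimension -1).
x = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 10.0, 10.0, 10.0]])
y = layer_norm(x, axis=-1)
print(y)  # second row is all zeros
```

The constant second row shows the role of `eps`: its variance is exactly zero, so without `eps` the division would be 0/0 and produce NaN values.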

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn]

Returns

(init_fn, apply_fn, kernel_fn).