neural_tangents.stax.layer

neural_tangents.stax.layer(layer_fn)

A convenience decorator to be applied to all public layers.

Used, for example, by Relu.

Makes the kernel_fn of the layer work both with jax.numpy.ndarray inputs (when the layer is the first one applied to the inputs) and with Kernel objects (for intermediate layers). Also adds optional arguments to the kernel_fn that give more flexibility in specifying what is computed and which results are returned.
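To illustrate the dispatch described above, here is a deliberately simplified, self-contained toy. It is not the library's implementation. `ToyKernel`, `toy_layer`, and `Scale` are hypothetical names. The toy kernel holds only NNGP and NTK matrices, and raw inputs are turned into the linear kernel x1 x2^T / d (a toy choice). Plain NumPy stands in for jax.numpy. The wrapped kernel_fn accepts either raw arrays or a `ToyKernel`, and it takes an optional `get` argument.

```python
import dataclasses
import functools
from typing import Optional

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyKernel:
  """Hypothetical stand-in for `Kernel`: stores only the NNGP and NTK matrices."""
  nngp: np.ndarray
  ntk: np.ndarray


def _inputs_to_toy_kernel(x1: np.ndarray, x2: Optional[np.ndarray]) -> ToyKernel:
  # Toy choice: raw inputs define a linear kernel x1 x2^T / d with zero NTK
  # (no trainable parameters precede the first layer).
  x2 = x1 if x2 is None else x2
  nngp = x1 @ x2.T / x1.shape[-1]
  return ToyKernel(nngp=nngp, ntk=np.zeros_like(nngp))


def toy_layer(layer_fn):
  """Hypothetical, simplified analogue of the `layer` decorator."""

  @functools.wraps(layer_fn)
  def new_layer_fn(*args, **kwargs):
    init_fn, apply_fn, kernel_fn = layer_fn(*args, **kwargs)

    def new_kernel_fn(x1_or_kernel, x2=None, get=None):
      if isinstance(x1_or_kernel, ToyKernel):
        # Intermediate layer: the input is already a kernel.
        k = x1_or_kernel
      else:
        # First layer: build a kernel from the raw input arrays.
        k = _inputs_to_toy_kernel(np.asarray(x1_or_kernel),
                                  None if x2 is None else np.asarray(x2))
      out = kernel_fn(k)
      # Optional `get`: return the whole kernel, one field, or a tuple of fields.
      if get is None:
        return out
      if isinstance(get, str):
        return getattr(out, get)
      return tuple(getattr(out, g) for g in get)

    return init_fn, apply_fn, new_kernel_fn

  return new_layer_fn


@toy_layer
def Scale(c: float):
  """Hypothetical parameter-free layer computing c * x."""

  def init_fn(rng, input_shape):
    return input_shape, ()  # no trainable parameters

  def apply_fn(params, inputs, **kwargs):
    return c * inputs

  def kernel_fn(k: ToyKernel) -> ToyKernel:
    # Scaling the outputs by c scales both NNGP and NTK by c**2.
    return ToyKernel(nngp=c**2 * k.nngp, ntk=c**2 * k.ntk)

  return init_fn, apply_fn, kernel_fn


_, _, kernel_fn = Scale(2.0)
x1 = np.array([[1.0, 0.0], [0.0, 1.0]])
x2 = np.array([[1.0, 1.0]])
print(kernel_fn(x1, x2, get='nngp').tolist())  # [[2.0], [2.0]]
k = kernel_fn(x1, x2)                           # a ToyKernel, usable by the next layer
print(kernel_fn(k, get='nngp').tolist())        # [[8.0], [8.0]]
```

The inner `kernel_fn` of `Scale` only ever sees a kernel. Converting raw arrays and selecting outputs via `get` are handled once, in the decorator, instead of in every layer.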

Parameters:

layer_fn (Callable[..., tuple[InitFn, ApplyFn, LayerKernelFn]]) – Layer function returning a triple (init_fn, apply_fn, kernel_fn).

Return type:

Callable[..., tuple[InitFn, ApplyFn, AnalyticKernelFn]]

Returns:

A function with the same signature as layer_fn, whose kernel_fn now also accepts jax.numpy.ndarray inputs where needed, as well as the optional get, diagonal_batch, and diagonal_spatial arguments.