neural_tangents.stax.layer
- neural_tangents.stax.layer(layer_fn)
A convenience decorator to be added to all public layers, such as Relu.
Makes the kernel_fn of the layer work both with input jax.numpy.ndarray (when the layer is the first one applied to inputs) and with Kernel (for intermediate layers). Also adds optional arguments to the kernel_fn that allow specifying the computation and the returned results with more flexibility.
- Parameters:
layer_fn (Callable[..., tuple[InitFn, ApplyFn, LayerKernelFn]]) – Layer function returning the triple (init_fn, apply_fn, kernel_fn).
- Return type:
- Returns:
A function with the same signature as layer_fn, whose kernel_fn now also accepts jax.numpy.ndarray inputs (when needed) and the optional get, diagonal_batch, and diagonal_spatial arguments.