neural_tangents.stax.ConvLocal
- neural_tangents.stax.ConvLocal(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=None, dimension_numbers=None, parameterization='ntk', s=(1, 1))
General unshared convolution.
Also known as “locally connected networks” (LCNs), these are equivalent to convolutions except that they have separate (unshared) kernels at different spatial locations.
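To make the shared/unshared distinction concrete, here is a minimal pure-Python sketch (not neural_tangents code) of a 1D, single-channel, stride-1, “VALID” layer. The helper names conv1d and local_conv1d are hypothetical. As in jax.lax, both compute a cross-correlation, meaning the kernel is not flipped.

```python
# Hypothetical illustration: 1D, one input/output channel, stride 1, VALID padding.

def conv1d(x, kernel):
    """Shared-weight convolution: the same kernel slides over every position."""
    k = len(kernel)
    return [sum(kernel[j] * x[i + j] for j in range(k))
            for i in range(len(x) - k + 1)]


def local_conv1d(x, kernels):
    """Unshared (locally connected) convolution: kernels[i] is used only at output i."""
    k = len(kernels[0])
    n_out = len(x) - k + 1
    if len(kernels) != n_out:
        raise ValueError("need exactly one kernel per output position")
    return [sum(kernels[i][j] * x[i + j] for j in range(k))
            for i in range(n_out)]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
shared = [0.5, -1.0, 2.0]

# Tying every location's kernel to the same weights recovers ordinary convolution.
tied = [list(shared) for _ in range(len(x) - len(shared) + 1)]
assert local_conv1d(x, tied) == conv1d(x, shared)

# Untied kernels let each position respond to a different input offset.
untied = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(local_conv1d(x, untied))  # [1.0, 3.0, 5.0]

# Weight count: one kernel for convolution vs. one kernel per output position.
n_shared = len(shared)
n_unshared = sum(len(kern) for kern in untied)
print(n_shared, n_unshared)  # 3 9
```

Untying the kernels is what gives each location its own response. The cost is a weight count that grows with the number of output positions: 9 weights here versus 3.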
- Parameters:
out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in “ntk” parameterization.
filter_shape (Sequence[int]) – The shape of the filter. The length of the tuple should agree with the number of spatial dimensions in dimension_numbers.
strides (Optional[Sequence[int]]) – The stride of the convolution. The length of the tuple should agree with the number of spatial dimensions in dimension_numbers.
padding (str) – Specifies padding for the convolution. Can be one of “VALID”, “SAME”, or “CIRCULAR”. “CIRCULAR” uses periodic convolutions.
W_std (float) – Standard deviation of the weights.
b_std (Optional[float]) – Standard deviation of the biases. None means no bias.
dimension_numbers (Optional[tuple[str, str, str]]) – Specifies which axes should be convolved over. Should match the specification in jax.lax.conv_general_dilated.
parameterization (str) – Either “ntk” or “standard”. These parameterizations are the direct analogues for convolution of the corresponding parameterizations for Dense layers.
s (tuple[int, int]) – A tuple of integers, a direct convolutional analogue of the respective parameters for the Dense layer.
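Because the kernels are unshared, the number of weights scales with the number of output locations, and padding and strides determine that number. The sketch below is a minimal pure-Python rendering of that arithmetic. It makes three assumptions: lax-style output sizes for “VALID” and “SAME”, that “CIRCULAR” yields the same output size as “SAME” (wrap-around instead of zero padding), and that strides=None behaves like all-ones strides. The helpers output_size and local_weight_count are hypothetical and not part of neural_tangents. Biases are not counted.

```python
import math


def output_size(n, k, stride, padding):
    """Output length along one spatial axis for input length n and filter size k."""
    if padding == "VALID":
        return math.ceil((n - k + 1) / stride)
    if padding in ("SAME", "CIRCULAR"):
        # Assumption: CIRCULAR pads by the same amount as SAME, but wraps around
        # instead of padding with zeros, so the output size matches SAME.
        return math.ceil(n / stride)
    raise ValueError(f"unknown padding {padding!r}")


def local_weight_count(spatial_in, filter_shape, strides, padding, in_chan, out_chan):
    """Weights in an unshared conv: one full filter per output location (no biases)."""
    out_spatial = [output_size(n, k, s, padding)
                   for n, k, s in zip(spatial_in, filter_shape, strides)]
    n_locations = math.prod(out_spatial)
    per_filter = math.prod(filter_shape) * in_chan * out_chan
    return out_spatial, n_locations * per_filter


# 32x32 RGB input, 3x3 filter, 16 output channels.
print(local_weight_count((32, 32), (3, 3), (1, 1), "VALID", 3, 16))  # ([30, 30], 388800)
print(local_weight_count((32, 32), (3, 3), (2, 2), "SAME", 3, 16))   # ([16, 16], 110592)
```

A shared-weight convolution with the same filter_shape and channels would have 3 * 3 * 3 * 16 = 432 weights regardless of padding or strides. For the locally connected layer, that figure is multiplied by the number of output locations.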
- Return type:
- Returns:
(init_fn, apply_fn, kernel_fn).