neural_tangents.stax.Conv

neural_tangents.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=None, dimension_numbers=None, parameterization='ntk')

Layer construction function for a general convolution layer.

Based on jax.example_libraries.stax.GeneralConv.

Parameters
  • out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in the NTK parameterization.

  • filter_shape (Sequence[int]) – The shape of the filter. The length of the sequence should match the number of spatial dimensions in dimension_numbers.

  • strides (Optional[Sequence[int]]) – The stride of the convolution. The length of the sequence should match the number of spatial dimensions in dimension_numbers.

  • padding (str) – Specifies padding for the convolution. Can be one of “VALID”, “SAME”, or “CIRCULAR”. “CIRCULAR” uses periodic convolutions.

  • W_std (float) – The standard deviation of the weights.

  • b_std (Optional[float]) – The standard deviation of the biases.

  • dimension_numbers (Optional[Tuple[str, str, str]]) – Specifies which axes should be convolved over. Should match the specification in jax.lax.conv_general_dilated.

  • parameterization (str) – Either “ntk” or “standard”. These parameterizations are the direct analogues for convolution of the corresponding parameterizations for Dense layers.

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns

(init_fn, apply_fn, kernel_fn).