neural_tangents.stax.SumPool
- neural_tangents.stax.SumPool(window_shape, strides=None, padding='VALID', batch_axis=0, channel_axis=-1)
Layer construction function for a 2D sum pooling layer.
Based on jax.example_libraries.stax.SumPool.
- Parameters
  - window_shape (Sequence[int]) – The shape of the pooling window, i.e. the number of pixels along each spatial dimension over which pooling is performed.
  - strides (Optional[Sequence[int]]) – The stride of the pooling window. None corresponds to a stride of (1, …, 1).
  - padding (str) – One of 'VALID', 'SAME', or 'CIRCULAR'. 'CIRCULAR' uses periodic boundary conditions on the image.
  - batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.
  - channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, the channel size is considered to be infinite.
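As a rough illustration of what window_shape, strides, and padding control, the NumPy sketch below computes the finite-width forward pass of a 2D sum pool on an NHWC batch (batch_axis=0, channel_axis=-1). `sum_pool_2d` and `_same_pads` are hypothetical helpers, not part of neural_tangents. The padding split follows the usual SAME convention (extra padding on the high side), and the CIRCULAR branch assumes periodic padding by the same amounts SAME would use; the library's exact implementation may differ in detail.

```python
import numpy as np


def _same_pads(size, window, stride):
    """Low/high padding that SAME-style pooling adds along one spatial axis."""
    out = -(-size // stride)  # ceil(size / stride)
    total = max((out - 1) * stride + window - size, 0)
    return total // 2, total - total // 2


def sum_pool_2d(x, window_shape, strides=None, padding='VALID'):
    """Reference 2D sum pooling for an NHWC array.

    Hypothetical helper for illustration only; not the neural_tangents code.
    """
    x = np.asarray(x, dtype=float)
    strides = tuple(strides) if strides is not None else (1, 1)
    if padding in ('SAME', 'CIRCULAR'):
        pads = [(0, 0)]  # no padding on the batch axis
        pads += [_same_pads(n, w, s)
                 for n, w, s in zip(x.shape[1:3], window_shape, strides)]
        pads += [(0, 0)]  # no padding on the channel axis
        # SAME pads with zeros; CIRCULAR (assumed) wraps the image periodically.
        mode = 'constant' if padding == 'SAME' else 'wrap'
        x = np.pad(x, pads, mode=mode)
    elif padding != 'VALID':
        raise ValueError(f'Unsupported padding: {padding!r}')

    (wh, ww), (sh, sw) = tuple(window_shape), strides
    out_h = (x.shape[1] - wh) // sh + 1
    out_w = (x.shape[2] - ww) // sw + 1
    out = np.zeros((x.shape[0], out_h, out_w, x.shape[3]))
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, i * sh:i * sh + wh, j * sw:j * sw + ww, :]
            out[:, i, j, :] = patch.sum(axis=(1, 2))  # a sum, not an average
    return out


# One 4x4 single-channel image holding 0..15, pooled over non-overlapping 2x2 blocks.
x = np.arange(16.0).reshape(1, 4, 4, 1)
y = sum_pool_2d(x, (2, 2), strides=(2, 2))
print(y[0, :, :, 0])  # → [[10. 18.]
                      #    [42. 50.]]
```

Unlike average pooling, each output is the plain sum over its window, so with CIRCULAR padding and unit strides every pixel contributes to exactly window-size many outputs.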
- Return type
  Tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]
- Returns
  (init_fn, apply_fn, kernel_fn).