neural_tangents.stax.GlobalSumPool

neural_tangents.stax.GlobalSumPool(batch_axis=0, channel_axis=-1)

Layer construction function for a global sum pooling layer.

Sums over and removes (keepdims=False) all spatial dimensions, preserving the order of batch and channel axes.

Parameters
  • batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns

(init_fn, apply_fn, kernel_fn).
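
Example

The shape behaviour described above can be sketched in plain NumPy. This is an illustration only, not the library's implementation: the helper name global_sum_pool_reference is hypothetical, and the code assumes that every axis other than batch_axis and channel_axis counts as spatial.

```python
import numpy as np


def global_sum_pool_reference(x, batch_axis=0, channel_axis=-1):
    """Hypothetical reference helper; not part of neural_tangents.

    Sums over every axis except the batch and channel axes and drops
    the summed axes (keepdims=False).
    """
    # Normalize negative axis indices such as -1.
    batch_axis %= x.ndim
    channel_axis %= x.ndim
    # Assumption: all remaining axes are spatial.
    spatial_axes = tuple(
        axis for axis in range(x.ndim)
        if axis not in (batch_axis, channel_axis)
    )
    return np.sum(x, axis=spatial_axes, keepdims=False)


# NHWC input: batch of 2, spatial size 3 x 4, 5 channels.
x_nhwc = np.ones((2, 3, 4, 5))
out_nhwc = global_sum_pool_reference(x_nhwc)
print(out_nhwc.shape)  # (2, 5)

# NCHW input: the channel axis comes before the spatial axes.
x_nchw = np.ones((2, 5, 3, 4))
out_nchw = global_sum_pool_reference(x_nchw, batch_axis=0, channel_axis=1)
print(out_nchw.shape)  # (2, 5)

# Channel-first, batch-last input: the relative order is preserved.
x_cb = np.ones((5, 3, 4, 2))
out_cb = global_sum_pool_reference(x_cb, batch_axis=-1, channel_axis=0)
print(out_cb.shape)  # (5, 2)
```

In each case the output keeps only the batch and channel axes, in the same relative order as in the input, and every entry is the sum over the 3 x 4 spatial positions.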