neural_tangents.stax.Index
- neural_tangents.stax.Index(idx, batch_axis=0, channel_axis=-1)
Index into the array, mimicking numpy.ndarray indexing.
- Parameters:
idx (Union[int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...]]) – the index expression used when indexing an array as x[idx]. To create this object, use the helper object Slice, e.g. pass idx=stax.Slice[1:10, :, ::-1], which is equivalent to passing an explicit idx=(slice(1, 10, None), slice(None), slice(None, None, -1)).
batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.
channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
- Return type:
- Returns:
(init_fn, apply_fn, kernel_fn).
- Raises:
NotImplementedError – If the channel_axis (infinite width) is indexed (except for : or …) in the kernel regime (kernel_fn).
NotImplementedError – If the batch_axis is indexed with an integer (as opposed to a tuple or slice) in the kernel regime (kernel_fn). The library currently requires the batch_axis to always be present in the kernel regime, and indexing with an integer removes that axis.
ValueError – If init_fn is called on a shape with dummy axes (with sizes like -1 or None) that are indexed with non-trivial (not : or …) slices. To index an axis, its size must be specified.
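Both the idx argument and the first two error cases come down to ordinary Python and NumPy indexing semantics. The sketch below uses a hypothetical `_SliceHelper` class as a stand-in for stax.Slice (an assumption for illustration, not the library's implementation) together with plain NumPy arrays on a hypothetical NHWC batch. It checks the equivalence stated for idx and shows that an integer index removes an axis while a slice keeps it, which is why integer indexing of batch_axis is rejected in the kernel regime:

```python
import numpy as np


class _SliceHelper:
    """Hypothetical stand-in for stax.Slice: returns the index it is given."""

    def __getitem__(self, idx):
        # Python packs `obj[1:10, :, ::-1]` into a tuple of slice objects
        # and passes that tuple here unchanged.
        return idx


Slice = _SliceHelper()

# The bracket syntax yields a tuple equal to the explicit tuple of slices.
idx = Slice[1:10, :, ::-1]
print(idx == (slice(1, 10, None), slice(None), slice(None, None, -1)))  # True

# Hypothetical NHWC batch: 8 images, 5x5 pixels, 3 channels.
x = np.zeros((8, 5, 5, 3))

# Integer index on the batch axis: the axis disappears (rank 4 -> 3).
print(x[0].shape)  # (5, 5, 3)

# A slice selecting the same element keeps the batch axis with size 1.
print(x[0:1].shape)  # (1, 5, 5, 3)

# `...` expands to `:` for all remaining axes, so the trailing channel
# axis is left untouched, as the kernel regime requires.
print(x[::2, 0, ...].shape)  # (4, 5, 3)
```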
Example
>>> from neural_tangents import stax
>>>
>>> init_fn, apply_fn, kernel_fn = stax.serial(
...     stax.Conv(128, (3, 3)),
...     stax.Relu(),
...     # Select every other element from the batch (leading axis), cropped
...     # to the upper-left 4x4 corner.
...     stax.Index(idx=stax.Slice[::2, :4, :4]),
...     stax.Conv(128, (2, 2)),
...     stax.Relu(),
...     # Select the first row. Notice that the image becomes 1D.
...     stax.Index(idx=stax.Slice[:, 0, ...]),
...     stax.Conv(128, (2,)),
...     stax.GlobalAvgPool(),
...     stax.Dense(10)
... )
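To see what the two Index layers in the example do to activation shapes, the sketch below applies the same index expressions with NumPy to zero arrays. The activation shapes are hypothetical (the real ones depend on the input size and on how stax.Conv pads), and NumPy stands in for the layer's apply_fn purely for illustration:

```python
import numpy as np

# Hypothetical output of the first Relu: 8 images, 30x30 spatial, 128 channels (NHWC).
h1 = np.zeros((8, 30, 30, 128))

# Same expression as stax.Slice[::2, :4, :4]: every other batch element,
# cropped to the upper-left 4x4 corner. The channel axis is not mentioned,
# so it is kept in full.
h1_idx = h1[::2, :4, :4]
print(h1_idx.shape)  # (4, 4, 4, 128)

# Hypothetical output of the second Relu.
h2 = np.zeros((4, 3, 3, 128))

# Same expression as stax.Slice[:, 0, ...]: keep the batch, take row 0,
# keep width and channels. The integer drops the height axis, so each
# image becomes 1D.
h2_idx = h2[:, 0, ...]
print(h2_idx.shape)  # (4, 3, 128)
```

The second index uses an integer only on a spatial axis, which is why it is compatible with the kernel regime: the batch axis is kept by `:` and the channel axis by `...`.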