Stax – infinite networks (NNGP, NTK)
Closed-form NNGP and NTK library.
This library contains layer constructors mimicking those in
jax.example_libraries.stax, with a similar API apart from:

1) Instead of an (init_fn, apply_fn) tuple, layer constructors return a
triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a
Kernel to a new Kernel, and represents the change in the analytic NTK and
NNGP kernels (Kernel.nngp, Kernel.ntk). These functions are chained /
stacked together within the serial or parallel combinators, similarly to
init_fn and apply_fn.
2) In layers with random weights, the NTK parameterization is used by
default (https://arxiv.org/abs/1806.07572, page 3). The standard
parameterization (https://arxiv.org/abs/2001.07301) can be specified for
Conv and Dense layers via the keyword argument parameterization.
3) Some functionality may be missing (e.g. BatchNorm), and some may be
present only in our library (e.g. CIRCULAR padding, LayerNorm,
GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes, etc.).
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>> init_fn, apply_fn, kernel_fn = stax.serial(
...     stax.Conv(128, (3, 3)),
...     stax.Relu(),
...     stax.Conv(256, (3, 3)),
...     stax.Relu(),
...     stax.Conv(512, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>>
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
...                                                       y_train)
>>>
>>> # (5, 10) np.ndarray NNGP test prediction
>>> y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>>
>>> # (5, 10) np.ndarray NTK prediction
>>> y_test_ntk = predict_fn(x_test=x_test, get='ntk')
Combinators
Layers to combine multiple other layers into one.

serial    Combinator for composing layers in serial.
parallel  Combinator for composing layers in parallel.
Branching
Layers to split outputs into many, or to combine many into one.

FanOut       Layer construction function for a fan-out layer.
FanInConcat  Layer construction function for a fan-in concatenation layer.
FanInProd    Layer construction function for a fan-in product layer.
FanInSum     Layer construction function for a fan-in sum layer.
Linear parametric
Linear layers with trainable parameters.

Dense                Layer constructor function for a dense (fully-connected) layer.
Conv                 Layer construction function for a general convolution layer.
ConvLocal            Layer construction function for a general unshared convolution layer.
ConvTranspose        Layer construction function for a general transpose convolution layer.
GlobalSelfAttention  Layer construction function for (global) scaled dot-product self-attention.
Linear nonparametric
Linear layers without any trainable parameters.

Aggregate      Layer constructor for an aggregation operator (graph neural network).
AvgPool        Layer construction function for an average pooling layer.
Identity       Layer construction function for an identity layer.
DotGeneral     Layer constructor for a constant (non-trainable) rhs/lhs dot general.
Dropout        Dropout layer.
Flatten        Layer construction function for flattening all non-batch dimensions.
GlobalAvgPool  Layer construction function for a global average pooling layer.
GlobalSumPool  Layer construction function for a global sum pooling layer.
ImageResize    Image resize function mimicking jax.image.resize.
LayerNorm      Layer normalisation.
SumPool        Layer construction function for a 2D sum pooling layer.
Elementwise nonlinear
Pointwise nonlinear layers.

ABRelu                ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0).
Abs                   Absolute value nonlinearity.
Cos                   Affine transform of the cosine nonlinearity.
Elementwise           Elementwise application of a given function.
ElementwiseNumerical  Activation function using numerical integration.
Erf                   Affine transform of the error function nonlinearity.
Exp                   Elementwise natural exponent function.
ExpNormalized         Simulates the "Gaussian normalized kernel".
Gaussian              Elementwise Gaussian function.
Gelu                  Gelu function.
Hermite               Hermite polynomials.
LeakyRelu             Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0).
Rbf                   Dual activation function for normalized RBF or squared exponential kernel.
Relu                  ReLU nonlinearity.
Sigmoid_like          A sigmoid-like function.
Sign                  Sign function.
Sin                   Affine transform of the sine nonlinearity.
Helper enums
Enums for specifying layer properties. Strings can be used in their place.

Padding              Type of padding in pooling and convolutional layers.
PositionalEmbedding  Type of positional embeddings to use in a GlobalSelfAttention layer.
For developers
Classes and decorators helpful for constructing your own layers.

layer             A convenience decorator to be added to all public layers.
supports_masking  Returns a decorator that turns layers into layers supporting masking.
requires          Returns a decorator that augments a kernel_fn with specified requirements.
Bool              Helper trinary logic class.
Diagonal          Helps decide whether to allow the kernel to contain diagonal entries only.