Stax – infinite networks (NNGP, NTK)

Closed-form NNGP and NTK library.

This library contains layer constructors that mimic those in jax.example_libraries.stax, with a similar API except for the following differences:

1) Instead of an (init_fn, apply_fn) tuple, layer constructors return a triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a Kernel to a new Kernel and represents the change in the analytic NTK and NNGP kernels (Kernel.nngp, Kernel.ntk). These functions are chained / stacked together within the serial or parallel combinators, similarly to init_fn and apply_fn (see the sketch after this list).

2) In layers with random weights, the NTK parameterization is used by default (https://arxiv.org/abs/1806.07572, page 3). The standard parameterization (https://arxiv.org/abs/2001.07301) can be specified for Conv and Dense layers via the parameterization keyword argument.

3) Some functionality may be missing (e.g. BatchNorm), and some may be present only in our library (e.g. CIRCULAR padding, LayerNorm, GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes etc.).
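
The following sketch illustrates points 1) and 2) above: calling the kernel_fn returned by a constructor directly on data to obtain the analytic NNGP and NTK matrices, and selecting the parameterization via a keyword argument. The widths and hyperparameter values (W_std, b_std) here are illustrative only.

>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  key1, key2 = random.split(random.PRNGKey(0))
>>>  x1 = random.normal(key1, (4, 8))
>>>  x2 = random.normal(key2, (6, 8))
>>>
>>>  # 2) `parameterization` is a keyword argument of Dense / Conv.
>>>  _, _, kernel_fn = stax.serial(
>>>      stax.Dense(512, W_std=1.5, b_std=0.05, parameterization='standard'),
>>>      stax.Erf(),
>>>      stax.Dense(1)
>>>  )
>>>
>>>  # 1) kernel_fn maps inputs to the closed-form NNGP and NTK kernels.
>>>  k = kernel_fn(x1, x2, ('nngp', 'ntk'))  # k.nngp and k.ntk are (4, 6)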

Example

>>>  from jax import random
>>>  import neural_tangents as nt
>>>  from neural_tangents import stax
>>>
>>>  key1, key2 = random.split(random.PRNGKey(1), 2)
>>>  x_train = random.normal(key1, (20, 32, 32, 3))
>>>  y_train = random.uniform(key1, (20, 10))
>>>  x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Conv(128, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(256, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(512, (3, 3)),
>>>      stax.Flatten(),
>>>      stax.Dense(10)
>>>  )
>>>
>>>  predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>>                                                        y_train)
>>>
>>>  # (5, 10) np.ndarray NNGP test prediction
>>>  y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>>
>>>  # (5, 10) np.ndarray NTK prediction
>>>  y_test_ntk = predict_fn(x_test=x_test, get='ntk')

Combinators

Layers to combine multiple other layers into one.

serial(*layers)

Combinator for composing layers in serial.

parallel(*layers)

Combinator for composing layers in parallel.
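
For instance, since each constructor returns an (init_fn, apply_fn, kernel_fn) triple, blocks built with serial can themselves be composed. A small hypothetical sketch (layer widths are arbitrary):

>>>  from neural_tangents import stax
>>>
>>>  Block = stax.serial(stax.Dense(512), stax.Relu())
>>>  init_fn, apply_fn, kernel_fn = stax.serial(Block, Block, stax.Dense(1))

parallel is typically used between FanOut and a FanIn layer; see the branching sketch below.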

Branching

Layers to split outputs into many, or to combine many into one.

FanOut(num)

Layer construction function for a fan-out layer.

FanInConcat([axis])

Layer construction function for a fan-in concatenation layer.

FanInProd()

Layer construction function for a fan-in product layer.

FanInSum()

Layer construction function for a fan-in sum layer.
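
As an illustration, the branching layers above combine with serial and parallel to express skip connections. A hypothetical residual block (it assumes the block input already has 256 channels, so the two branch outputs can be summed):

>>>  from neural_tangents import stax
>>>
>>>  ResBlock = stax.serial(
>>>      stax.FanOut(2),                     # duplicate the input
>>>      stax.parallel(
>>>          stax.serial(stax.Dense(256),    # main branch
>>>                      stax.Relu(),
>>>                      stax.Dense(256)),
>>>          stax.Identity()                 # skip branch
>>>      ),
>>>      stax.FanInSum()                     # add the two branches
>>>  )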

Linear parametric

Linear layers with trainable parameters.

Dense(out_dim[, W_std, b_std, ...])

Layer constructor function for a dense (fully-connected) layer.

Conv(out_chan, filter_shape[, strides, ...])

Layer construction function for a general convolution layer.

ConvLocal(out_chan, filter_shape[, strides, ...])

Layer construction function for a general unshared convolution layer.

ConvTranspose(out_chan, filter_shape[, ...])

Layer construction function for a general transpose convolution layer.

GlobalSelfAttention(n_chan_out, n_chan_key, ...)

Layer construction function for (global) scaled dot-product self-attention.
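 
A small sketch combining some of the parametric layers above; the hyperparameters (channel count, strides, weight and bias standard deviations) are arbitrary, and 'CIRCULAR' padding is one of the options mentioned earlier:

>>>  from neural_tangents import stax
>>>
>>>  _, _, kernel_fn = stax.serial(
>>>      stax.Conv(128, (3, 3), strides=(2, 2), padding='CIRCULAR',
>>>                W_std=1.4, b_std=0.1),
>>>      stax.Relu(),
>>>      stax.Flatten(),
>>>      stax.Dense(10)
>>>  )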

Linear nonparametric

Linear layers without any trainable parameters.

Aggregate([aggregate_axis, batch_axis, ...])

Layer constructor for an aggregation operator (graph neural network).

AvgPool(window_shape[, strides, padding, ...])

Layer construction function for an average pooling layer.

Identity()

Layer construction function for an identity layer.

DotGeneral(*[, lhs, rhs, dimension_numbers, ...])

Layer constructor for a constant (non-trainable) rhs/lhs Dot General.

Dropout(rate[, mode])

Dropout layer.

Flatten([batch_axis, batch_axis_out])

Layer construction function for flattening all non-batch dimensions.

GlobalAvgPool([batch_axis, channel_axis])

Layer construction function for a global average pooling layer.

GlobalSumPool([batch_axis, channel_axis])

Layer construction function for a global sum pooling layer.

ImageResize(shape, method[, antialias, ...])

Image resize function mimicking jax.image.resize.

LayerNorm([axis, eps, batch_axis, channel_axis])

Layer normalisation.

SumPool(window_shape[, strides, padding, ...])

Layer construction function for a 2D sum pooling layer.
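
As a sketch of how the nonparametric layers above are typically used, here is a hypothetical CNN head where LayerNorm normalizes the channel axis and GlobalAvgPool collapses the spatial axes before the readout:

>>>  from neural_tangents import stax
>>>
>>>  _, _, kernel_fn = stax.serial(
>>>      stax.Conv(256, (3, 3), padding='SAME'),
>>>      stax.Relu(),
>>>      stax.LayerNorm(),
>>>      stax.GlobalAvgPool(),
>>>      stax.Dense(10)
>>>  )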

Elementwise nonlinear

Pointwise nonlinear layers.

ABRelu(a, b[, do_stabilize])

ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0).

Abs([do_stabilize])

Absolute value nonlinearity.

Cos([a, b, c])

Affine transform of Cos nonlinearity, i.e. a * cos(b * x + c).

Elementwise([fn, nngp_fn, d_nngp_fn])

Elementwise application of fn using provided nngp_fn.

ElementwiseNumerical(fn, deg[, df])

Activation function using numerical integration.

Erf([a, b, c])

Affine transform of Erf nonlinearity, i.e. a * Erf(b * x) + c.

Exp([a, b])

Elementwise natural exponent function a * np.exp(b * x).

ExpNormalized([gamma, shift, do_clip])

Simulates the "Gaussian normalized kernel".

Gaussian([a, b])

Elementwise Gaussian function a * np.exp(b * x**2).

Gelu([approximate])

Gelu function.

Hermite(degree)

Hermite polynomials.

LeakyRelu(alpha[, do_stabilize])

Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0).

Rbf([gamma])

Dual activation function for normalized RBF or squared exponential kernel.

Relu([do_stabilize])

ReLU nonlinearity.

Sigmoid_like()

A sigmoid-like function f(x) = 0.5 * erf(x / 2.4020563531719796) + 0.5.

Sign()

Sign function.

Sin([a, b, c])

Affine transform of Sin nonlinearity, i.e. a * sin(b * x + c).
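
When a nonlinearity has no closed-form dual kernel in the list above, ElementwiseNumerical can approximate it via Gaussian quadrature. A sketch with tanh (the quadrature degree 25 and the widths are arbitrary choices):

>>>  import jax.numpy as jnp
>>>  from neural_tangents import stax
>>>
>>>  _, _, kernel_fn = stax.serial(
>>>      stax.Dense(512),
>>>      stax.ElementwiseNumerical(fn=jnp.tanh, deg=25),
>>>      stax.Dense(1)
>>>  )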

Helper enums

Enums for specifying layer properties. Strings can be used in their place.

Padding(value)

Type of padding in pooling and convolutional layers.

PositionalEmbedding(value)

Type of positional embeddings to use in a GlobalSelfAttention layer.
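
For example, the following two layers should be equivalent, passing the padding either as a string or as the corresponding enum member:

>>>  from neural_tangents import stax
>>>
>>>  conv_str = stax.Conv(64, (3, 3), padding='CIRCULAR')
>>>  conv_enum = stax.Conv(64, (3, 3), padding=stax.Padding.CIRCULAR)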

For developers

Classes and decorators helpful for constructing your own layers.

layer(layer_fn)

A convenience decorator to be added to all public layers like Relu etc.

supports_masking(remask_kernel)

Returns a decorator that turns layers into layers supporting masking.

requires(**static_reqs)

Returns a decorator that augments kernel_fn with consistency checks.

Bool(value)

Helper trinary logic class.

Diagonal([input, output])

Helps decide whether to allow the kernel to contain diagonal entries only.
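
Below is a rough, hypothetical sketch of how these pieces might fit together for a trivial custom layer that rescales its input by a constant c. It is not a complete recipe: real layers typically also declare input requirements via requires and handle masking and batching details documented in the library. The kernel_fn below relies on the fact that scaling outputs by c scales all second-moment matrices (nngp, ntk, cov1, cov2) by c**2.

>>>  from neural_tangents import stax
>>>
>>>  @stax.layer
>>>  @stax.supports_masking(remask_kernel=False)
>>>  def ConstScale(c):
>>>      # No trainable parameters: pass the input shape through unchanged.
>>>      def init_fn(rng, input_shape):
>>>          return input_shape, ()
>>>
>>>      def apply_fn(params, inputs, **kwargs):
>>>          return c * inputs
>>>
>>>      # kernel_fn maps a Kernel to a Kernel by rescaling second moments.
>>>      def kernel_fn(k, **kwargs):
>>>          return k.replace(
>>>              nngp=c**2 * k.nngp,
>>>              cov1=c**2 * k.cov1,
>>>              ntk=None if k.ntk is None else c**2 * k.ntk,
>>>              cov2=None if k.cov2 is None else c**2 * k.cov2)
>>>
>>>      return init_fn, apply_fn, kernel_fn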