Neural Tangents Reference

Neural Tangents is a set of tools for constructing and training infinitely wide neural networks.

Stax – infinite networks (NNGP, NTK)

Closed-form NNGP and NTK library.

This library contains layer constructors mimicking those in jax.example_libraries.stax, with a similar API apart from:

1) Instead of an (init_fn, apply_fn) tuple, layer constructors return a triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a Kernel to a new Kernel and represents the change in the analytic NTK and NNGP kernels (Kernel.nngp, Kernel.ntk). These functions are chained / stacked together within the serial or parallel combinators, just like init_fn and apply_fn.

2) In layers with random weights, the NTK parameterization is used by default (https://arxiv.org/abs/1806.07572, page 3). The standard parameterization (https://arxiv.org/abs/2001.07301) can be specified for Conv and Dense layers via the parameterization keyword argument (see the sketch after this list).

3) Some functionality may be missing (e.g. BatchNorm), and some may be present only in our library (e.g. CIRCULAR padding, LayerNorm, GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes etc.).
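
A minimal sketch of point 2 above, selecting the standard parameterization per layer via the parameterization keyword argument (the 'standard' string value is assumed here; consult the Dense / Conv docstrings for the accepted values):

>>>  from neural_tangents import stax
>>>
>>>  # NTK parameterization (the default).
>>>  _, _, kernel_fn_ntk = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(10))
>>>
>>>  # Standard parameterization, selected per layer.
>>>  _, _, kernel_fn_std = stax.serial(
>>>      stax.Dense(512, parameterization='standard'),
>>>      stax.Relu(),
>>>      stax.Dense(10, parameterization='standard')
>>>  )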

Example

>>>  from jax import random
>>>  import neural_tangents as nt
>>>  from neural_tangents import stax
>>>
>>>  key1, key2 = random.split(random.PRNGKey(1), 2)
>>>  x_train = random.normal(key1, (20, 32, 32, 3))
>>>  y_train = random.uniform(key1, (20, 10))
>>>  x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Conv(128, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(256, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(512, (3, 3)),
>>>      stax.Flatten(),
>>>      stax.Dense(10)
>>>  )
>>>
>>>  predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>>                                                        y_train)
>>>
>>>  # (5, 10) np.ndarray NNGP test prediction
>>>  y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>>
>>>  # (5, 10) np.ndarray NTK prediction
>>>  y_test_ntk = predict_fn(x_test=x_test, get='ntk')
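
The kernel_fn returned above can also be called directly to obtain the analytic kernels themselves, without going through nt.predict. A short sketch continuing the example above; the third argument selects which kernel(s) to compute:

>>>  # Namedtuple of (20, 20) train-train `nngp` and `ntk` matrices.
>>>  k_train_train = kernel_fn(x_train, x_train, ('nngp', 'ntk'))
>>>
>>>  # (5, 20) test-train NTK matrix.
>>>  ntk_test_train = kernel_fn(x_test, x_train, 'ntk')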

Combinators

Layers to combine multiple other layers into one.

serial(*layers)

Combinator for composing layers in serial.

parallel(*layers)

Combinator for composing layers in parallel.

Branching

Layers to split one output into many, or to combine many outputs into one (see the residual-block sketch after this list).

FanOut(num)

Layer construction function for a fan-out layer.

FanInConcat([axis])

Layer construction function for a fan-in concatenation layer.

FanInProd()

Layer construction function for a fan-in product layer.

FanInSum()

Layer construction function for a fan-in sum layer.
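
A common use of the branching layers together with serial and parallel is a skip connection. A minimal sketch of this pattern, similar to the wide residual network example in the library's README (layer widths are illustrative):

>>>  from neural_tangents import stax
>>>
>>>  # A residual block: fan the input out into two branches, apply a small
>>>  # MLP in one branch and the identity in the other, then sum the branches.
>>>  ResBlock = stax.serial(
>>>      stax.FanOut(2),
>>>      stax.parallel(
>>>          stax.serial(stax.Relu(), stax.Dense(512), stax.Relu(), stax.Dense(512)),
>>>          stax.Identity()
>>>      ),
>>>      stax.FanInSum()
>>>  )
>>>
>>>  # The leading Dense(512) makes the identity branch match the MLP branch width.
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Dense(512), ResBlock, stax.Relu(), stax.Dense(10)
>>>  )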

Linear parametric

Linear layers with trainable parameters (a short sketch follows the list below).

Dense(out_dim[, W_std, b_std, ...])

Layer constructor function for a dense (fully-connected) layer.

Conv(out_chan, filter_shape[, strides, ...])

Layer construction function for a general convolution layer.

ConvLocal(out_chan, filter_shape[, strides, ...])

Layer construction function for a general unshared convolution layer.

ConvTranspose(out_chan, filter_shape[, ...])

Layer construction function for a general transpose convolution layer.

GlobalSelfAttention(n_chan_out, n_chan_key, ...)

Layer construction function for (global) scaled dot-product self-attention.
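
A minimal sketch of the parametric layers above; the W_std / b_std weight and bias standard deviations, strides, and padding shown are illustrative values only:

>>>  from neural_tangents import stax
>>>
>>>  # A strided convolution followed by a dense readout, with explicit
>>>  # weight / bias standard deviations.
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Conv(128, (3, 3), strides=(2, 2), padding='SAME', W_std=1.5, b_std=0.05),
>>>      stax.Relu(),
>>>      stax.Flatten(),
>>>      stax.Dense(10, W_std=1., b_std=0.05)
>>>  )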

Linear nonparametric

Linear layers without any trainable parameters (a short sketch follows the list below).

Aggregate([aggregate_axis, batch_axis, ...])

Layer constructor for an aggregation operator (as in graph neural networks).

AvgPool(window_shape[, strides, padding, ...])

Layer construction function for an average pooling layer.

Identity()

Layer construction function for an identity layer.

DotGeneral(*[, lhs, rhs, dimension_numbers, ...])

Layer constructor for a constant (non-trainable) rhs/lhs Dot General.

Dropout(rate[, mode])

Dropout layer.

Flatten([batch_axis, batch_axis_out])

Layer construction function for flattening all non-batch dimensions.

GlobalAvgPool([batch_axis, channel_axis])

Layer construction function for a global average pooling layer.

GlobalSumPool([batch_axis, channel_axis])

Layer construction function for a global sum pooling layer.

ImageResize(shape, method[, antialias, ...])

Image resize function mimicking jax.image.resize.

LayerNorm([axis, eps, batch_axis, channel_axis])

Layer normalisation.

SumPool(window_shape[, strides, padding, ...])

Layer construction function for a 2D sum pooling layer.
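
A minimal sketch combining some of the nonparametric layers above, using a global average pooling readout instead of flattening (the architecture is illustrative only):

>>>  from neural_tangents import stax
>>>
>>>  # A small CNN whose readout averages over spatial positions.
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Conv(128, (3, 3), padding='SAME'),
>>>      stax.Relu(),
>>>      stax.AvgPool((2, 2), strides=(2, 2)),
>>>      stax.Conv(128, (3, 3), padding='SAME'),
>>>      stax.Relu(),
>>>      stax.GlobalAvgPool(),
>>>      stax.Dense(10)
>>>  )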

Elementwise nonlinear

Pointwise nonlinear layers (a short sketch follows the list below).

ABRelu(a, b[, do_stabilize])

ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0).

Abs([do_stabilize])

Absolute value nonlinearity.

Cos([a, b, c])

Affine transform of Cos nonlinearity, i.e. a cos(b*x + c).

Elementwise([fn, nngp_fn, d_nngp_fn])

Elementwise application of fn using provided nngp_fn.

ElementwiseNumerical(fn, deg[, df])

Activation function using numerical integration.

Erf([a, b, c])

Affine transform of Erf nonlinearity, i.e. a * Erf(b * x) + c.

Exp([a, b])

Elementwise natural exponent function a * np.exp(b * x).

ExpNormalized([gamma, shift, do_clip])

Simulates the "Gaussian normalized kernel".

Gaussian([a, b])

Elementwise Gaussian function a * np.exp(b * x**2).

Gelu([approximate])

Gelu function.

Hermite(degree)

Hermite polynomials.

LeakyRelu(alpha[, do_stabilize])

Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0).

Rbf([gamma])

Dual activation function for normalized RBF or squared exponential kernel.

Relu([do_stabilize])

ReLU nonlinearity.

Sigmoid_like()

A sigmoid-like function f(x) = .5 * erf(x / 2.4020563531719796) + .5.

Sign()

Sign function.

Sin([a, b, c])

Affine transform of Sin nonlinearity, i.e. a sin(b*x + c).
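
A minimal sketch of the nonlinearities above, including ElementwiseNumerical for an activation without a built-in closed-form kernel (np.tanh and deg=25 are illustrative choices; the quadrature degree trades accuracy for compute):

>>>  from jax import numpy as np
>>>  from neural_tangents import stax
>>>
>>>  # Built-in closed-form nonlinearities.
>>>  erf_net = stax.serial(stax.Dense(512), stax.Erf(), stax.Dense(10))
>>>  gelu_net = stax.serial(stax.Dense(512), stax.Gelu(), stax.Dense(10))
>>>
>>>  # An activation handled via numerical quadrature instead of a closed form.
>>>  tanh_net = stax.serial(
>>>      stax.Dense(512),
>>>      stax.ElementwiseNumerical(np.tanh, deg=25),
>>>      stax.Dense(10)
>>>  )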

Helper enums

Enums for specifying layer properties. Strings can be used in their place (see the sketch after this list).

Padding(value)

Type of padding in pooling and convolutional layers.

PositionalEmbedding(value)

Type of positional embeddings to use in a GlobalSelfAttention layer.
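
A minimal sketch of the string / enum equivalence noted above; the two layers below are intended to be identical, with CIRCULAR padding used purely as an example:

>>>  from neural_tangents import stax
>>>
>>>  # Passing the padding as a string ...
>>>  conv_str = stax.Conv(128, (3, 3), padding='CIRCULAR')
>>>
>>>  # ... or as the corresponding enum member.
>>>  conv_enum = stax.Conv(128, (3, 3), padding=stax.Padding.CIRCULAR)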

For developers

Classes and decorators helpful for constructing your own layers.

layer(layer_fn)

A convenience decorator to be added to all public layers like Relu etc.

supports_masking(remask_kernel)

Returns a decorator that turns layers into layers supporting masking.

requires(**static_reqs)

Returns a decorator that augments kernel_fn with consistency checks.

Bool(value)

Helper trinary logic class.

Diagonal([input, output])

Helps decide whether to allow the kernel to contain diagonal entries only.

Empirical – finite NNGP and NTK

Compute empirical NNGP and NTK; approximate functions via Taylor series.

All functions in this module are applicable to any JAX functions of proper signatures (not only those from nt.stax).

NNGP and NTK are computed using nt.empirical_nngp_fn, nt.empirical_ntk_fn, or nt.empirical_kernel_fn (for both). The kernels have a very specific output shape convention that may be unexpected. Further, NTK has multiple implementations that may perform differently depending on the task. Please read individual functions’ docstrings.

Example

>>>  from jax import random
>>>  import neural_tangents as nt
>>>  from neural_tangents import stax
>>>
>>>  key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>>  x_train = random.normal(key1, (20, 32, 32, 3))
>>>  y_train = random.uniform(key1, (20, 10))
>>>  x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>>  # A narrow CNN.
>>>  init_fn, f, _ = stax.serial(
>>>      stax.Conv(32, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(32, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(32, (3, 3)),
>>>      stax.Flatten(),
>>>      stax.Dense(10)
>>>  )
>>>
>>>  _, params = init_fn(key3, x_train.shape)
>>>
>>>  # Default setting: reducing over logits; pass `vmap_axes=0` because the
>>>  # network is iid along the batch axis, no BatchNorm. Use default
>>>  # `implementation=1` since the network has few trainable parameters.
>>>  kernel_fn = nt.empirical_kernel_fn(
>>>      f, trace_axes=(-1,), vmap_axes=0, implementation=1)
>>>
>>>  # (5, 20) np.ndarray test-train NNGP/NTK
>>>  nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>>  ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>>
>>>  # Full kernel: not reducing over logits.
>>>  kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(), vmap_axes=0)
>>>
>>>  # (5, 20, 10, 10) np.ndarray test-train NNGP/NTK namedtuple.
>>>  k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>>  # A wide FCN with lots of parameters
>>>  init_fn, f, _ = stax.serial(
>>>      stax.Flatten(),
>>>      stax.Dense(1024),
>>>      stax.Relu(),
>>>      stax.Dense(1024),
>>>      stax.Relu(),
>>>      stax.Dense(10)
>>>  )
>>>
>>>  _, params = init_fn(key3, x_train.shape)
>>>
>>>  # Use implicit differentiation in NTK: `implementation=2` to reduce
>>>  # memory cost, since the network has many trainable parameters.
>>>  ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=2)
>>>
>>>  # (5, 5) np.ndarray test-test NTK
>>>  ntk_test_test = ntk_fn(x_test, None, params)
>>>
>>>  # Compute only output variances:
>>>  nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>>
>>>  # (20,) np.ndarray train-train diagonal NNGP
>>>  nngp_train_train_diag = nngp_fn(x_train, None, params)

Kernel functions

Finite-width NNGP and/or NTK kernel functions.

empirical_kernel_fn(f[, trace_axes, ...])

Returns a function that computes single draws from NNGP and NT kernels.

empirical_nngp_fn(f[, trace_axes, diagonal_axes])

Returns a function to draw a single sample of the NNGP of a given network f.

empirical_ntk_fn(f[, trace_axes, ...])

Returns a function to draw a single sample of the NTK of a given network f.

Linearization and Taylor expansion

Decorators to Taylor-expand around function parameters (a sketch follows the list below).

linearize(f, params)

Returns a function f_lin, the first-order Taylor approximation to f.

taylor_expand(f, params, degree)

Returns a function f_tayl, the Taylor approximation to f of order degree.
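
A minimal sketch, reusing f, params, and x_train from the empirical example above:

>>>  import neural_tangents as nt
>>>
>>>  # First-order (linearized) approximation of `f` around `params`;
>>>  # `f_lin` has the same signature as `f`.
>>>  f_lin = nt.linearize(f, params)
>>>  fx_lin = f_lin(params, x_train)
>>>
>>>  # Second-order Taylor approximation of `f` around `params`.
>>>  f_quad = nt.taylor_expand(f, params, 2)
>>>  fx_quad = f_quad(params, x_train)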

Predict – inference with NNGP and NTK or linearized networks

Functions to make predictions on the train/test set using NTK/NNGP.

Most functions in this module accept training data as inputs and return a new function predict_fn that computes predictions on the train set / given test set / timesteps.

WARNING: trace_axes parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, this is the same trace_axes used to compute the empirical kernel (utils/empirical.py; diagonal_axes must be ()), or channel_axis in the output of the top layer used to compute the closed-form kernel (stax.py; note that closed-form kernels currently only support a single channel_axis).

Prediction / inference functions

Functions to make train/test set predictions given NNGP/NTK kernels or the linearized function (a sketch follows the list below).

gradient_descent(loss, k_train_train, y_train)

Predicts the outcome of function space training using gradient descent.

gradient_descent_mse(k_train_train, y_train)

Predicts the outcome of function space gradient descent training on MSE.

gradient_descent_mse_ensemble(kernel_fn, ...)

Predicts the Gaussian embedding induced by gradient descent on MSE loss.

gp_inference(k_train_train, y_train[, ...])

Compute the mean and variance of the 'posterior' of NNGP/NTK/NTKGP.
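
A minimal sketch of finite-time inference with gradient_descent_mse, reusing f, params, x_train, y_train, and x_test from the empirical example above (the training time t = 1.0 is illustrative):

>>>  import neural_tangents as nt
>>>
>>>  # Empirical train-train and test-train NTKs.
>>>  ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0)
>>>  k_train_train = ntk_fn(x_train, None, params)
>>>  k_test_train = ntk_fn(x_test, x_train, params)
>>>
>>>  predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
>>>
>>>  # Initial outputs of the finite network.
>>>  fx_train_0 = f(params, x_train)
>>>  fx_test_0 = f(params, x_test)
>>>
>>>  # Train / test predictions after time t of gradient flow
>>>  # (t=None would give the infinite-time limit).
>>>  fx_train_t, fx_test_t = predict_fn(1.0, fx_train_0, fx_test_0, k_test_train)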

Utilities

max_learning_rate(ntk_train_train[, ...])

Computes the maximal feasible learning rate for infinite width NNs.
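
A minimal sketch, assuming a (train, train) NTK matrix such as k_train_train from the sketch above; the returned scalar is an estimate of the largest stable learning rate:

>>>  import neural_tangents as nt
>>>
>>>  # Maximal feasible learning rate for gradient descent on MSE loss,
>>>  # estimated from the train-train NTK.
>>>  lr = nt.predict.max_learning_rate(k_train_train)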

Helper classes

Dataclasses and namedtuples used to return predictions.

Gaussian(mean, covariance)

A (mean, covariance) convenience namedtuple.

ODEState([fx_train, fx_test, qx_train, qx_test])

ODE state dataclass holding outputs and auxiliary variables.

Batching – using multiple devices

Batch kernel computations serially or in parallel.

This module contains a decorator batch that can be applied to any kernel_fn of signature kernel_fn(x1, x2, *args, **kwargs). The decorated function performs the same computation by batching over x1 and x2 and concatenating the result, allowing you to both use multiple accelerators and stay within memory limits.

Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.

Example

>>>  from jax import numpy as np
>>>  import neural_tangents as nt
>>>  from neural_tangents import stax
>>>
>>>  # Define some kernel function.
>>>  _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1))
>>>
>>>  # Compute the kernel in batches, in parallel.
>>>  kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=5)
>>>
>>>  # Generate dummy input data.
>>>  x1, x2 = np.ones((40, 10)), np.ones((80, 10))
>>>  kernel_fn_batched(x1, x2) == kernel_fn(x1, x2)  # True!

neural_tangents.batch(kernel_fn, batch_size=0, device_count=-1, store_on_device=True)

Returns a function that computes a kernel in batches over all devices.

Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.

Parameters
  • kernel_fn (TypeVar(_KernelFn, bound= Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])) – A function that computes a kernel on two batches, kernel_fn(x1, x2, *args, **kwargs), where x1 and x2 are np.ndarrays of shapes (n1,) + input_shape and (n2,) + input_shape. The kernel function should return a PyTree.

  • batch_size (int) – specifies the size of each batch that gets processed per physical device. Because we parallelize the computation over columns it should be the case that x1.shape[0] is divisible by device_count * batch_size and x2.shape[0] is divisible by batch_size.

  • device_count (int) – specifies the number of physical devices to be used. If device_count == -1 all devices are used. If device_count == 0, no device parallelism is used (a single default device is used).

  • store_on_device (bool) – specifies whether the output should be kept on device or brought back to CPU RAM as it is computed. Defaults to True. Set to False to store and concatenate results using CPU RAM, allowing larger kernels to be computed.

Return type

TypeVar(_KernelFn, bound= Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])

Returns

A new function with the same signature as kernel_fn that computes the kernel by batching over the dataset in parallel with the specified batch_size using device_count devices.

Monte Carlo Sampling

Function to compute Monte Carlo NNGP and NTK estimates.

This module contains a function monte_carlo_kernel_fn that allows you to compute Monte Carlo estimates of the NNGP and NTK kernels of arbitrary functions. For more details on how individual samples are computed, refer to utils/empirical.py.

Note that the monte_carlo_kernel_fn accepts arguments like batch_size, device_count, and store_on_device, and is appropriately batched / parallelized. You don’t need to apply the nt.batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn function, as the resulting empirical kernel function is JITted internally.

neural_tangents.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples, batch_size=0, device_count=-1, store_on_device=True, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=1)

Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.

Note that the returned function is appropriately batched / parallelized. You don’t need to apply the nt.batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn function, as the resulting empirical kernel function is JITted internally.

Parameters
  • init_fn (InitFn) – a function initializing parameters of the neural network. From jax.example_libraries.stax: “takes an rng key and an input shape and returns an (output_shape, params) pair”.

  • apply_fn (ApplyFn) – a function computing the output of the neural network. From jax.example_libraries.stax: “takes params, inputs, and an rng key and applies the layer”.

  • key (KeyArray) – RNG (jax.random.PRNGKey) for sampling random networks. Must have shape (2,).

  • n_samples (Union[int, Iterable[int]]) – number of Monte Carlo samples. Can be either an integer or an iterable of integers at which the resulting generator will yield estimates. Example: use n_samples=[2**k for k in range(10)] for the generator to yield estimates using 1, 2, 4, …, 512 Monte Carlo samples.

  • batch_size (int) – an integer specifying the size of batches of x1 and x2 in which the kernel is computed. 0 means computing the whole kernel in one pass. Must divide x1.shape[0] and x2.shape[0].

  • device_count (int) – an integer specifying the number of devices (e.g. GPUs or TPU cores) across which the kernel is computed in parallel. -1 means use all available devices. 0 means compute on a single device sequentially. If not 0, must divide x1.shape[0].

  • store_on_device (bool) – a boolean, indicating whether to store the resulting kernel on the device (e.g. GPU or TPU), or in the CPU RAM, where larger kernels may fit.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows you to save space and compute if you are only interested in the respective trace, and can also improve approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d., and the respective covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite-width or infinite-n_samples limit. Also related to “contracting dimensions” in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows you to save space and compute if off-diagonal values along these axes are not needed, and can also improve approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of the covariance) along certain axes. Also related to “batch dimensions” in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).

  • vmap_axes (Union[List[int], Tuple[int, ...], int, None, Tuple[Union[List[int], Tuple[int, ...], int, None], Union[List[int], Tuple[int, ...], int, None], Dict[str, Union[List[int], Tuple[int, ...], int, None]]]]) – applicable only to NTK. A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) is equal to a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes, i.e. f can be evaluated as a vmap. This allows Jacobians to be evaluated much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None, to avoid wrong (and potentially silent) results.

  • implementation (int) –

    applicable only to NTK.

    1 or 2.

    1 directly instantiates Jacobians and computes their outer product.

    2 uses implicit differentiation to avoid instantiating whole Jacobians at once. The implicit kernel is derived by observing that: \(\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)\), i.e. a linear function \([J(X_1) J(X_2)^T]\) applied to an identity matrix \(I\). This allows the computation of the NTK to be phrased as: \(a(v) = J(X_2)^T v\), which is computed by a vector-Jacobian product; \(b(v) = J(X_1) a(v)\), which is computed by a Jacobian-vector product; and \(\Theta = \frac{d[b(v)]}{d[v^T]}(I)\), which is computed via a vmap of \(b(v)\) over the columns of the identity matrix \(I\).

    It is best to benchmark each method on your specific task. We suggest using 1 unless you run out of memory due to a large number of trainable parameters; otherwise, use 2.

Return type

MonteCarloKernelFn

Returns

If n_samples is an integer, returns a function of signature kernel_fn(x1, x2, get) that returns a Monte Carlo estimate of the kernel using n_samples samples. If n_samples is a collection of integers, kernel_fn(x1, x2, get) returns a generator that yields estimates using n samples for each n in n_samples.

Example

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>> init_fn, apply_fn, _ = stax.serial(
>>>     stax.Conv(128, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(256, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(512, (3, 3)),
>>>     stax.Flatten(),
>>>     stax.Dense(10)
>>> )
>>>
>>> n_samples = 200
>>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
>>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
>>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`.
>>>
>>> n_samples = [1, 10, 100, 1000]
>>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
>>>                                                n_samples)
>>> kernel_samples = kernel_fn_generator(x_train, x_test,
>>>                                      get=('nngp', 'ntk'))
>>> for n, kernel in zip(n_samples, kernel_samples):
>>>   print(n, kernel)
>>>   # `kernel` is a tuple of NNGP and NTK MC estimate using `n` samples.
