Neural Tangents Reference

Neural Tangents (neural_tangents, abbreviated nt) is a set of tools for constructing and training infinitely wide neural networks (a.k.a. NTK, NNGP).

nt.stax – infinite NNGP and NTK

Closed-form NNGP and NTK library.

This library contains layers mimicking those in jax.example_libraries.stax with a similar API, apart from:

1) Instead of (init_fn, apply_fn) tuple, layers return a triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a Kernel to a new Kernel, and represents the change in the analytic NTK and NNGP kernels (nngp, ntk). These functions are chained / stacked together within the serial or parallel combinators, similarly to init_fn and apply_fn. For details, please see “Neural Tangents: Fast and Easy Infinite Neural Networks in Python”.

2) In layers with random weights, NTK parameterization is used by default (see page 3 in “Neural Tangent Kernel: Convergence and Generalization in Neural Networks”). Standard parameterization can be specified for Conv and Dense layers by a keyword argument parameterization. For details, please see “On the infinite width limit of neural networks with a standard parameterization”. A minimal sketch of selecting the standard parameterization follows the example below.

3) Some functionality may be missing (e.g. jax.example_libraries.stax.BatchNorm), and some may be present only in our library (e.g. CIRCULAR padding, LayerNorm, GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes etc.).

Example

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>>     stax.Conv(128, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(256, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(512, (3, 3)),
>>>     stax.Flatten(),
>>>     stax.Dense(10)
>>> )
>>> #
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>>                                                       y_train)
>>> #
>>> # (5, 10) jnp.ndarray NNGP test prediction
>>> y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>> #
>>> # (5, 10) jnp.ndarray NTK prediction
>>> y_test_ntk = predict_fn(x_test=x_test, get='ntk')
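
The layers above use the NTK parameterization by default. As noted in (2), a minimal sketch of opting into the standard parameterization instead (the W_std, b_std, and parameterization keyword arguments are assumptions based on the Dense signature listed under “Linear parametric” below):

>>> from neural_tangents import stax
>>> #
>>> # NTK parameterization (default).
>>> _, _, kernel_fn_ntk = stax.serial(
>>>     stax.Dense(512, W_std=1.5, b_std=0.05),
>>>     stax.Relu(),
>>>     stax.Dense(1, W_std=1.5, b_std=0.05)
>>> )
>>> #
>>> # Standard parameterization, selected via a keyword argument.
>>> _, _, kernel_fn_std = stax.serial(
>>>     stax.Dense(512, W_std=1.5, b_std=0.05, parameterization='standard'),
>>>     stax.Relu(),
>>>     stax.Dense(1, W_std=1.5, b_std=0.05, parameterization='standard')
>>> )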

Combinators

Layers to combine multiple other layers into one.

parallel(*layers)

Combinator for composing layers in parallel.

repeat(layer, n)

Compose layer in a compiled loop n times.

serial(*layers)

Combinator for composing layers in serial.
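
A minimal sketch of these combinators together (widths are arbitrary; it is assumed that the block passed to repeat must preserve the activation shape inside the compiled loop):

>>> from neural_tangents import stax
>>> #
>>> # A single Dense + Relu block.
>>> block = stax.serial(stax.Dense(512), stax.Relu())
>>> #
>>> # Three copies of the block, written out explicitly or via `repeat`.
>>> explicit = stax.serial(stax.Dense(512), block, block, block, stax.Dense(1))
>>> looped = stax.serial(stax.Dense(512), stax.repeat(block, 3), stax.Dense(1))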

Branching

Layers to split one output into many, or combine many outputs into one.

FanInConcat([axis])

Fan-in concatenation.

FanInProd()

Fan-in product.

FanInSum()

Fan-in sum.

FanOut(num)

Fan-out.
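
A minimal sketch of a skip connection built from these branching layers together with the combinators above (layer widths are arbitrary; the leading Dense ensures the identity branch and the transformed branch have matching shapes at finite width):

>>> from neural_tangents import stax
>>> #
>>> # y = f(x) + x: fan the activations out into two branches, transform one,
>>> # pass the other through unchanged, and sum them back together.
>>> ResBlock = stax.serial(
>>>     stax.FanOut(2),
>>>     stax.parallel(
>>>         stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(512)),
>>>         stax.Identity()
>>>     ),
>>>     stax.FanInSum()
>>> )
>>> #
>>> _, _, kernel_fn = stax.serial(
>>>     stax.Dense(512),
>>>     ResBlock,
>>>     stax.Relu(),
>>>     stax.Dense(1)
>>> )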

Linear parametric

Linear layers with trainable parameters.

Conv(out_chan, filter_shape[, strides, ...])

General convolution.

ConvLocal(out_chan, filter_shape[, strides, ...])

General unshared convolution.

ConvTranspose(out_chan, filter_shape[, ...])

General transpose convolution.

Dense(out_dim[, W_std, b_std, batch_axis, ...])

Dense (fully-connected, matrix product).

GlobalSelfAttention(n_chan_out, n_chan_key, ...)

Global scaled dot-product self-attention.

Linear nonparametric

Linear layers without any trainable parameters.

Aggregate([aggregate_axis, batch_axis, ...])

Aggregation operator (graph neural network).

AvgPool(window_shape[, strides, padding, ...])

Average pooling.

DotGeneral(*[, lhs, rhs, dimension_numbers, ...])

Constant (non-trainable) rhs/lhs Dot General.

Dropout(rate[, mode])

Dropout.

Flatten([batch_axis, batch_axis_out])

Flattening all non-batch dimensions.

GlobalAvgPool([batch_axis, channel_axis])

Global average pooling.

GlobalSumPool([batch_axis, channel_axis])

Global sum pooling.

Identity()

Identity (no-op).

ImageResize(shape, method[, antialias, ...])

Image resize function mimicking jax.image.resize.

Index(idx[, batch_axis, channel_axis])

Index into the array mimicking numpy.ndarray indexing.

LayerNorm([axis, eps, batch_axis, channel_axis])

Layer normalisation.

SumPool(window_shape[, strides, padding, ...])

Sum pooling.
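
A minimal sketch combining several of these nonparametric layers in a small convolutional network (window sizes and widths are arbitrary):

>>> from neural_tangents import stax
>>> #
>>> _, _, kernel_fn = stax.serial(
>>>     stax.Conv(128, (3, 3), padding='SAME'),
>>>     stax.Relu(),
>>>     stax.AvgPool((2, 2), strides=(2, 2)),
>>>     stax.Conv(128, (3, 3), padding='SAME'),
>>>     stax.Relu(),
>>>     stax.GlobalAvgPool(),
>>>     stax.Dense(10)
>>> )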

Elementwise nonlinear

Pointwise nonlinear layers. For details, please see “Fast Neural Kernel Embeddings for General Activations”.

ABRelu(a, b[, do_stabilize])

ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0).

Abs([do_stabilize])

Absolute value nonlinearity.

Cos([a, b, c])

Affine transform of Cos nonlinearity, i.e. a cos(b*x + c).

Elementwise([fn, nngp_fn, d_nngp_fn])

Elementwise application of fn using provided nngp_fn.

ElementwiseNumerical(fn, deg[, df])

Activation function using numerical integration.

Erf([a, b, c])

Affine transform of Erf nonlinearity, i.e. a * Erf(b * x) + c.

Exp([a, b])

Elementwise natural exponent function a * jnp.exp(b * x).

ExpNormalized([gamma, shift, do_clip])

Simulates the "Gaussian normalized kernel".

Gabor()

Gabor function exp(-x^2) * sin(x).

Gaussian([a, b])

Elementwise Gaussian function a * jnp.exp(b * x**2).

Gelu([approximate])

Gelu function.

Hermite(degree)

Hermite polynomials.

LeakyRelu(alpha[, do_stabilize])

Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0).

Monomial(degree)

Monomials, i.e. x^degree.

Polynomial(coef)

Polynomials, i.e. coef[0] + coef[1] * x + ... + coef[n] * x**n.

Rbf([gamma])

Dual activation function for normalized RBF or squared exponential kernel.

RectifiedMonomial(degree)

Rectified monomials, i.e. (x >= 0) * x^degree.

Relu([do_stabilize])

ReLU nonlinearity.

Sigmoid_like()

A sigmoid-like function f(x) = .5 * erf(x / 2.4020563531719796) + .5.

Sign()

Sign function.

Sin([a, b, c])

Affine transform of Sin nonlinearity, i.e. a sin(b*x + c).
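
When a nonlinearity has no entry above, the generic Elementwise and ElementwiseNumerical layers can be used instead. A minimal sketch, assuming ElementwiseNumerical approximates the kernel of an arbitrary JAX-traceable fn numerically using deg quadrature points:

>>> from jax import numpy as jnp
>>> from neural_tangents import stax
>>> #
>>> # `jnp.tanh` has no closed-form entry above, so its kernel is
>>> # approximated numerically.
>>> _, _, kernel_fn = stax.serial(
>>>     stax.Dense(512),
>>>     stax.ElementwiseNumerical(fn=jnp.tanh, deg=25),
>>>     stax.Dense(1)
>>> )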

Helper classes

Utility classes for specifying layer properties. For enums, strings can be passed in their place.

AggregateImplementation(value)

Implementation of the Aggregate layer.

AttentionMechanism(value)

Type of nonlinearity to use in a GlobalSelfAttention layer.

Padding(value)

Type of padding in pooling and convolutional layers.

PositionalEmbedding(value)

Type of positional embeddings to use in a GlobalSelfAttention layer.

Slice

For developers

Classes and decorators helpful for constructing your own layers.

Bool(value)

Helper trinary logic class.

Diagonal([input, output])

Helps decide whether to allow the kernel to contain diagonal entries only.

layer(layer_fn)

A convenience decorator to be added to all public layers.

requires(**static_reqs)

Returns a decorator that augments kernel_fn with consistency checks.

supports_masking(remask_kernel)

Returns a decorator that turns layers into layers supporting masking.

nt.empirical – finite NNGP and NTK

Compute empirical NNGP and NTK; approximate functions via Taylor series.

All functions in this module are applicable to any JAX function with the proper signature (not only those from stax).

NNGP and NTK are computed using empirical_nngp_fn, empirical_ntk_fn, or empirical_kernel_fn (for both). The kernels have a very specific output shape convention that may be unexpected. Further, NTK has multiple implementations that may perform differently depending on the task. Please read individual functions’ docstrings.

For details, please see “Fast Finite Width Neural Tangent Kernel”.

Example

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> # A narrow CNN.
>>> init_fn, f, _ = stax.serial(
>>>     stax.Conv(32, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(32, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(32, (3, 3)),
>>>     stax.Flatten(),
>>>     stax.Dense(10)
>>> )
>>> #
>>> _, params = init_fn(key3, x_train.shape)
>>> #
>>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
>>> # network is iid along the batch axis, no BatchNorm. Use default
>>> # `implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION` (`1`).
>>> kernel_fn = nt.empirical_kernel_fn(
>>>     f, trace_axes=(-1,), vmap_axes=0,
>>>     implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
>>> #
>>> # (5, 20) jnp.ndarray test-train NNGP/NTK
>>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>> #
>>> # Full kernel: not reducing over logits. Use structured derivatives
>>> # `implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) for
>>> # typically faster computation and lower memory cost.
>>> kernel_fn = nt.empirical_kernel_fn(
>>>     f, trace_axes=(), vmap_axes=0,
>>>     implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
>>> #
>>> # (5, 20, 10, 10) jnp.ndarray test-train NNGP/NTK namedtuple.
>>> k_test_train = kernel_fn(x_test, x_train, None, params)
>>> #
>>> # A wide FCN with lots of parameters and many (`100`) outputs.
>>> init_fn, f, _ = stax.serial(
>>>     stax.Flatten(),
>>>     stax.Dense(1024),
>>>     stax.Relu(),
>>>     stax.Dense(1024),
>>>     stax.Relu(),
>>>     stax.Dense(100)
>>> )
>>> #
>>> _, params = init_fn(key3, x_train.shape)
>>> #
>>> # Use ntk-vector products
>>> # (`implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS`) since the
>>> # network has many parameters relative to the cost of forward pass,
>>> # large outputs.
>>> ntk_fn = nt.empirical_ntk_fn(
>>>     f, vmap_axes=0,
>>>     implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
>>> #
>>> # (5, 5) jnp.ndarray test-test NTK
>>> ntk_test_test = ntk_fn(x_test, None, params)
>>> #
>>> # Compute only output variances:
>>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>> #
>>> # (20,) jnp.ndarray train-train diagonal NNGP
>>> nngp_train_train_diag = nngp_fn(x_train, None, params)

Kernel functions

Finite-width NNGP and/or NTK kernel functions.

empirical_kernel_fn(f[, trace_axes, ...])

Returns a function that computes single draws from NNGP and NT kernels.

empirical_nngp_fn(f[, trace_axes, diagonal_axes])

Returns a function to draw a single sample of the NNGP of a given network f.

empirical_ntk_fn(f[, trace_axes, ...])

Returns a function to draw a single sample of the NTK of a given network f.

NTK implementation

An enum.IntEnum specifying NTK implementation method.

class neural_tangents.NtkImplementation(value)[source]

Implementation method of the underlying finite width NTK computation.

Below is a very brief summary of each method. For details, please see “Fast Finite Width Neural Tangent Kernel”.

AUTO

(or 0) evaluates FLOPs of all other methods at compilation time, and selects the fastest method. However, at this time it only works correctly on TPUs, and on CPU/GPU can return wrong results, which is why it is not the default. TODO(romann): revisit based on http://b/202218145.

JACOBIAN_CONTRACTION

(or 1) computes the NTK as the outer product of two Jacobians, each computed using reverse-mode Autodiff (vector-Jacobian products, VJPs). When JITted, the contraction is performed in a layerwise fashion, so that entire Jacobians aren’t necessarily instantiated in memory at once, and the memory usage of the method can be lower than the memory needed to instantiate the two Jacobians. This method is best suited for networks with small outputs (such as scalar outputs for binary classification or regression, as opposed to 1000 ImageNet classes), and an expensive forward pass relative to the number of parameters (such as CNNs, where the forward pass reuses a small filter bank many times). It is also the most reliable method, since its implementation is simplest, and reverse-mode Autodiff is most commonly used and well tested elsewhere. For this reason it is set as the default.

NTK_VECTOR_PRODUCTS

(or 2) computes the NTK as a sequence of NTK-vector products, similarly to how a Jacobian is computed as a sequence of Jacobian-vector products (JVPs) or vector-Jacobian products (VJPs). This amounts to using both forward (JVPs) and reverse (VJPs) mode Autodiff, and makes it possible to eliminate the Jacobian contraction at the expense of additional forward passes. Therefore this method is recommended for networks with a cheap forward pass relative to the number of parameters (e.g. fully-connected networks, where each parameter matrix is used only once in the forward pass), and networks with large outputs (e.g. 1000 ImageNet classes). Memory requirements of this method are the same as JACOBIAN_CONTRACTION (1). Due to its reliance on forward-mode Autodiff, this method is slightly more prone to JAX and XLA bugs than JACOBIAN_CONTRACTION (1), but overall is quite simple and reliable.

STRUCTURED_DERIVATIVES

(or 3) uses a custom JAX interpreter to compute the NTK more efficiently than other methods. It traverses the computational graph of a function in the same order as during reverse-mode Autodiff, but instead of computing VJPs, it directly computes MJJMPs, “matrix-Jacobian-Jacobian-matrix” products, which arise in the computation of an NTK. Each MJJMP computation relies on the structure in the Jacobians, hence the name. This method can be dramatically faster (up to several orders of magnitude) than other methods on fully-connected networks, and is usually faster or equivalent on CNNs, Transformers, and other architectures, but the exact speedup (e.g. from no speedup to 10X) depends on each specific setting. It can also use less memory than other methods. In our experience it consistently outperforms other methods in most settings. However, its implementation is significantly more complex (hence bug-prone), and it doesn’t yet support functions using more exotic JAX primitives (e.g. jax.checkpoint, parallel collectives such as jax.lax.psum, compiled loops like jax.lax.scan, etc.), which is why it is highly recommended to try, but not set as the default yet.
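
A minimal sketch of selecting an implementation (per the docstrings above, the enum and its integer value are interchangeable):

>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> _, f, _ = stax.serial(stax.Dense(32), stax.Relu(), stax.Dense(1))
>>> #
>>> # Equivalent ways of requesting structured derivatives.
>>> ntk_fn = nt.empirical_ntk_fn(
>>>     f, implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
>>> ntk_fn = nt.empirical_ntk_fn(f, implementation=3)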

NTK-vector products

A function to compute NTK-vector products without instantiating the NTK.

empirical_ntk_vp_fn(f, x1, x2, params, ...)

Returns an NTK-vector product function.
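
A minimal sketch, assuming the returned function maps a cotangent array of the shape of f(params, x2) to an array of the shape of f(params, x1):

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>> x_train = random.normal(key1, (20, 8))
>>> x_test = random.normal(key2, (5, 8))
>>> #
>>> init_fn, f, _ = stax.serial(stax.Dense(128), stax.Relu(), stax.Dense(10))
>>> _, params = init_fn(key3, x_train.shape)
>>> #
>>> # Contract the test-train NTK with a cotangent without materializing it.
>>> ntk_vp_fn = nt.empirical_ntk_vp_fn(f, x_test, x_train, params)
>>> cotangents = random.normal(key1, f(params, x_train).shape)  # (20, 10)
>>> vp = ntk_vp_fn(cotangents)                                  # (5, 10)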

Linearization and Taylor expansion

Decorators to Taylor-expand around function parameters.

linearize(f, params)

Returns a function f_lin, the first-order Taylor approximation to f.

taylor_expand(f, params, degree)

Returns a function f_tayl, the Taylor approximation to f of order degree.
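
A minimal sketch of linearize; f_lin has the same signature as f and agrees with it at the expansion point params:

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1))
>>> x = random.normal(key1, (3, 8))
>>> #
>>> init_fn, f, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
>>> _, params = init_fn(key2, x.shape)
>>> #
>>> # First-order Taylor expansion of `f` in its parameters around `params`.
>>> f_lin = nt.linearize(f, params)
>>> out = f_lin(params, x)  # equals f(params, x) at the expansion point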

nt.predict – inference w/ NNGP & NTK

Functions to make predictions on the train/test set using NTK/NNGP.

Most functions in this module accept training data as inputs and return a new function predict_fn that computes predictions on the train set / given test set / timesteps.

Warning

trace_axes parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, this is the same trace_axes used to compute the empirical kernel (utils/empirical.py; diagonal_axes must be ()), or channel_axis in the output of the top layer used to compute the closed-form kernel (stax.py; note that closed-form kernels currently only support a single channel_axis).

Prediction / inference functions

Functions to make train/test set predictions given NNGP/NTK kernels or the linearized function.

gp_inference(k_train_train, y_train[, ...])

Compute the mean and variance of the 'posterior' of NNGP/NTK/NTKGP.

gradient_descent(loss, k_train_train, y_train)

Predicts the outcome of function space training using gradient descent.

gradient_descent_mse(k_train_train, y_train)

Predicts the outcome of function space gradient descent training on MSE.

gradient_descent_mse_ensemble(kernel_fn, ...)

Predicts the Gaussian embedding induced by gradient descent on MSE loss.
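
A minimal sketch of the ensemble prediction workflow (the diag_reg keyword argument is an assumption, used here for numerical stability):

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1))
>>> x_train = random.normal(key1, (20, 8))
>>> y_train = random.uniform(key1, (20, 1))
>>> x_test = random.normal(key2, (5, 8))
>>> #
>>> _, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
>>> #
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(
>>>     kernel_fn, x_train, y_train, diag_reg=1e-4)
>>> #
>>> # Posterior mean and covariance of the NTK ensemble after training time
>>> # `t` (`t=None` corresponds to infinite training time).
>>> mean, cov = predict_fn(t=100., x_test=x_test, get='ntk', compute_cov=True)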

Utilities

max_learning_rate(ntk_train_train[, ...])

Computes the maximal feasible learning rate for infinite width NNs.

Helper classes

Dataclasses and namedtuples used to return predictions.

Gaussian(mean, covariance)

A (mean, covariance) convenience namedtuple.

ODEState([fx_train, fx_test, qx_train, qx_test])

ODE state dataclass holding outputs and auxiliary variables.

nt.batch – using multiple devices

Batch kernel computations serially or in parallel.

This module contains a decorator batch that can be applied to any kernel_fn of signature kernel_fn(x1, x2, *args, **kwargs). The decorated function performs the same computation by batching over x1 and x2 and concatenating the result, making it possible both to use multiple accelerators and to stay within memory limits.

Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.

Example

>>> from jax import numpy as jnp
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> # Define some kernel function.
>>> _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1))
>>> #
>>> # Compute the kernel in batches, in parallel.
>>> kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=5)
>>> #
>>> # Generate dummy input data.
>>> x1, x2 = jnp.ones((40, 10)), jnp.ones((80, 10))
>>> kernel_fn_batched(x1, x2) == kernel_fn(x1, x2)  # True!

neural_tangents.batch(kernel_fn, batch_size=0, device_count=-1, store_on_device=True)[source]

Returns a function that computes a kernel in batches over all devices.

Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.

Parameters:
  • kernel_fn (TypeVar(_KernelFn, bound= Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])) – A function that computes a kernel on two batches, kernel_fn(x1, x2, *args, **kwargs). Here x1 and x2 are jnp.ndarrays of shapes (n1,) + input_shape and (n2,) + input_shape. The kernel function should return a PyTree.

  • batch_size (int) – specifies the size of each batch that gets processed per physical device. Because we parallelize the computation over columns it should be the case that x1.shape[0] is divisible by device_count * batch_size and x2.shape[0] is divisible by batch_size.

  • device_count (int) – specifies the number of physical devices to be used. If device_count == -1 all devices are used. If device_count == 0, no device parallelism is used (a single default device is used).

  • store_on_device (bool) – specifies whether the output should be kept on device or brought back to CPU RAM as it is computed. Defaults to True. Set to False to store and concatenate results using CPU RAM, allowing to compute larger kernels.

Return type:

TypeVar(_KernelFn, bound= Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])

Returns:

A new function with the same signature as kernel_fn that computes the kernel by batching over the dataset in parallel with the specified batch_size using device_count devices.
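
The decorator applies to empirical kernel functions as well. A minimal sketch (widths and batch sizes are arbitrary) computing a large train-train NTK serially on a single device and accumulating it in CPU RAM:

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1))
>>> x_train = random.normal(key1, (400, 8))
>>> #
>>> init_fn, f, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
>>> _, params = init_fn(key2, x_train.shape)
>>> #
>>> ntk_fn = nt.batch(nt.empirical_ntk_fn(f, vmap_axes=0),
>>>                   batch_size=20, device_count=0, store_on_device=False)
>>> #
>>> # (400, 400) train-train NTK, assembled batch by batch.
>>> ntk_train_train = ntk_fn(x_train, None, params)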

nt.monte_carlo_kernel_fn – MC Sampling

Function to compute Monte Carlo NNGP and NTK estimates.

This module contains a function monte_carlo_kernel_fn that allows computing Monte Carlo estimates of NNGP and NTK kernels of arbitrary functions. For more details on how individual samples are computed, refer to utils/empirical.py.

Note that the monte_carlo_kernel_fn accepts arguments like batch_size, device_count, and store_on_device, and is appropriately batched / parallelized. You don’t need to apply the batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn function, as the resulting empirical kernel function is JITted internally.

neural_tangents.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples, batch_size=0, device_count=-1, store_on_device=True, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=NtkImplementation.JACOBIAN_CONTRACTION, _j_rules=True, _s_rules=True, _fwd=None)[source]

Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.

Note that the returned function is appropriately batched / parallelized. You don’t need to apply the nt.batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn function, as the resulting empirical kernel function is JITted internally.

Parameters:
  • init_fn (InitFn) – a function initializing parameters of the neural network. From jax.example_libraries.stax: “takes an rng key and an input shape and returns an (output_shape, params) pair”.

  • apply_fn (ApplyFn) – a function computing the output of the neural network. From jax.example_libraries.stax: “takes params, inputs, and an rng key and applies the layer”.

  • key (Array) – RNG (jax.random.PRNGKey) for sampling random networks. Must have shape (2,).

  • n_samples (Union[int, Iterable[int]]) – number of Monte Carlo samples. Can be either an integer or an iterable of integers at which the resulting generator will yield estimates. Example: use n_samples=[2**k for k in range(10)] for the generator to yield estimates using 1, 2, 4, …, 512 Monte Carlo samples.

  • batch_size (int) – an integer; the kernel is computed in batches of x1 and x2 of this size. 0 means computing the whole kernel. Must divide x1.shape[0] and x2.shape[0].

  • device_count (int) – an integer; the kernel is computed in parallel across this number of devices (e.g. GPUs or TPU cores). -1 means use all available devices. 0 means compute on a single device sequentially. If not 0, must divide x1.shape[0].

  • store_on_device (bool) – a boolean, indicating whether to store the resulting kernel on the device (e.g. GPU or TPU), or in the CPU RAM, where larger kernels may fit.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows to save space and compute if you are only interested in the respective trace, but also improve approximation accuracy if you know that covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such axis are i.i.d. and the respective covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows to save space and compute, if off-diagonal values along these axes are not needed, but also improve approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to “batch dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • vmap_axes (Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]) – applicable only to NTK. A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) is equal to a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes, i.e. f can be evaluated as a vmap. This allows evaluating Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None, to avoid wrong (and potentially silent) results.

  • implementation (Union[int, NtkImplementation]) – Applicable only to NTK, an NtkImplementation value (or an int 0, 1, 2, or 3). See the NtkImplementation docstring for details.

  • _j_rules (bool) – Internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow custom Jacobian rules for intermediary primitive dy/dw computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to False to use JVPs or VJPs, via JAX’s jax.jacfwd or jax.jacrev. Custom Jacobian rules (True) are expected to be no worse, and sometimes better, than automated alternatives, but in the case of a suboptimal implementation, setting it to False could improve performance.

  • _s_rules (bool) – Internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow efficient MJJMP rules for structured dy/dw primitive Jacobians. In practice should be set to True, and setting it to False can lead to dramatic deterioration of performance.

  • _fwd (Optional[bool]) – Internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow jax.jvp in intermediary primitive Jacobian dy/dw computations, False to always use jax.vjp. None to decide automatically based on input/output sizes. Applicable when _j_rules=False, or when a primitive does not have a Jacobian rule. Should be set to None for best performance.

Return type:

MonteCarloKernelFn

Returns:

If n_samples is an integer, returns a function of signature kernel_fn(x1, x2, get) that returns an MC estimation of the kernel using n_samples. If n_samples is a collection of integers, kernel_fn(x1, x2, get) returns a generator that yields estimates using n samples for n in n_samples.

Example

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> init_fn, apply_fn, _ = stax.serial(
>>>     stax.Conv(128, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(256, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(512, (3, 3)),
>>>     stax.Flatten(),
>>>     stax.Dense(10)
>>> )
>>> #
>>> n_samples = 200
>>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
>>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
>>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`.
>>> #
>>> n_samples = [1, 10, 100, 1000]
>>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
>>>                                                n_samples)
>>> kernel_samples = kernel_fn_generator(x_train, x_test,
>>>                                      get=('nngp', 'ntk'))
>>> for n, kernel in zip(n_samples, kernel_samples):
>>>   print(n, kernel)
>>>   # `kernel` is a tuple of NNGP and NTK MC estimate using `n` samples.

nt.experimental – prototypes

Warning

This module contains new highly-experimental prototypes. Please beware that they are not properly tested, not supported, and may suffer from sub-optimal performance. Use at your own risk!

Kernel functions

Finite-width NTK kernel function in Tensorflow. See the Python and Colab usage examples.

neural_tangents.experimental.empirical_ntk_fn_tf(f, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=NtkImplementation.JACOBIAN_CONTRACTION, _j_rules=True, _s_rules=True, _fwd=None)[source]

Returns a function to draw a single sample of the NTK of a given network f.

This function follows the API of neural_tangents.empirical_ntk_fn, but is applicable to Tensorflow tf.Module, tf.keras.Model, or tf.function, via a TF->JAX->TF roundtrip using tf2jax and jax2tf. Docstring below adapted from neural_tangents.empirical_ntk_fn.

Warning

This function is experimental and risks returning wrong results or performing slowly. It is intended to demonstrate the usage of neural_tangents.empirical_ntk_fn in Tensorflow, but has not been extensively tested. Specifically, it appears to have very long compile times (but OK runtime), is prone to triggering XLA errors, and does not distinguish between trainable and non-trainable parameters of the model.

TODO(romann): support division between trainable and non-trainable variables.

TODO(romann): investigate slow compile times.

Parameters:
  • f (Union[Module, PolymorphicFunction]) –

    tf.Module or tf.function whose NTK we are computing. Must satisfy the following:

    • if a tf.function, must have the signature of f(params, x).

    • if a tf.Module, must be either a tf.keras.Model, or be callable.

    • input signature (f.input_shape for tf.Module or tf.keras.Model, or f.input_signature for tf.function) must be known.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows to save space and compute if you are only interested in the respective trace, but also improve approximation accuracy if you know that covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such axis are i.i.d. and the respective covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows to save space and compute, if off-diagonal values along these axes are not needed, but also improve approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to “batch dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • vmap_axes (Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]) –

    A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f.call(x, **kwargs) is equal to a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes. In other words, it certifies that f can be evaluated as a vmap with out_axes=out_axes over x (along in_axes) and those arguments in **kwargs that are present in kwargs_axes.keys() (along kwargs_axes.values()).

    This allows us to evaluate Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example a very common use case is vmap_axes=0 for a neural network with leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None, to avoid wrong (and potentially silent) results.

  • implementation (Union[NtkImplementation, int]) – An NtkImplementation value (or an int 0, 1, 2, or 3). See the NtkImplementation docstring for details.

  • _j_rules (bool) – Internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow custom Jacobian rules for intermediary primitive dy/dw computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to False to use JVPs or VJPs, via JAX’s jax.jacfwd or jax.jacrev. Custom Jacobian rules (True) are expected to be no worse, and sometimes better, than automated alternatives, but in the case of a suboptimal implementation, setting it to False could improve performance.

  • _s_rules (bool) – Internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow efficient MJJMP rules for structured dy/dw primitive Jacobians. In practice should be set to True, and setting it to False can lead to dramatic deterioration of performance.

  • _fwd (Optional[bool]) – Internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow jax.jvp in intermediary primitive Jacobian dy/dw computations, False to always use jax.vjp. None to decide automatically based on input/output sizes. Applicable when _j_rules=False, or when a primitive does not have a Jacobian rule. Should be set to None for best performance.

Return type:

Callable[..., Any]

Returns:

A function ntk_fn that computes the empirical NTK.

Helper functions

A helper function to convert Tensorflow stateful models into a functional-style, stateless apply_fn(params, x) forward-pass function, and to extract the respective params.

neural_tangents.experimental.get_apply_fn_and_params(f)[source]

Converts a tf.Module into a forward-pass apply_fn and params.

Use this function to extract params to pass to the Tensorflow empirical NTK kernel function.

Warning

This function does not distinguish between trainable and non-trainable parameters of the model.

Parameters:

f (Module) – a tf.Module to convert to an apply_fn(params, x) function. Must have an input_shape attribute set (specifying the shape of x), and be callable or be a tf.keras.Model.

Returns:

A tuple of (apply_fn, params), where params is a PyTree[tf.Tensor].

Kernel dataclass

class neural_tangents.Kernel(nngp, ntk, cov1, cov2, x1_is_x2, is_gaussian, is_reversed, is_input, diagonal_batch, diagonal_spatial, shape1, shape2, batch_axis, channel_axis, mask1=None, mask2=None)[source]

Dataclass containing information about the NTK and NNGP of a model.

nngp

covariance between the first and second batches (NNGP). A jnp.ndarray of shape (batch_size_1, batch_size_2, height, [height,], width, [width,], …), where the exact shape depends on diagonal_spatial.

ntk

the neural tangent kernel (NTK). A jnp.ndarray of the same shape as nngp.

cov1

covariance of the first batch of inputs. A jnp.ndarray with shape (batch_size_1, [batch_size_1,] height, [height,], width, [width,], …) where the exact shape depends on diagonal_batch and diagonal_spatial.

cov2

optional covariance of the second batch of inputs. A jnp.ndarray with shape (batch_size_2, [batch_size_2,] height, [height,], width, [width,], …) where the exact shape depends on diagonal_batch and diagonal_spatial.

x1_is_x2

a boolean specifying whether x1 and x2 are the same.

is_gaussian

a boolean, specifying whether the output features or channels of the layer / NN function (returning this Kernel as the kernel_fn) are i.i.d. Gaussian with covariance nngp, conditioned on fixed inputs to the layer and i.i.d. Gaussian weights and biases of the layer. For example, passing an input through a CNN layer with i.i.d. Gaussian weights and biases produces i.i.d. Gaussian random variables along the channel dimension, while passing an input through a nonlinearity does not.

is_reversed

a boolean specifying whether the covariance matrices nngp, cov1, cov2, and ntk have the ordering of spatial dimensions reversed. Ignored unless diagonal_spatial is False. Used internally to avoid self-cancelling transpositions in a sequence of CNN layers that flip the order of kernel spatial dimensions.

is_input

a boolean specifying whether the current layer is the input layer, and it is used to avoid applying dropout to the input layer.

diagonal_batch

a boolean specifying whether cov1 and cov2 store only the diagonal of the sample-sample covariance (diagonal_batch == True, cov1.shape == (batch_size_1, …)), or the full covariance (diagonal_batch == False, cov1.shape == (batch_size_1, batch_size_1, …)). Defaults to True as no current layers require the full covariance.

diagonal_spatial

a boolean specifying whether all (cov1, ntk, etc.) covariance matrices store only the diagonals of the location-location covariances (diagonal_spatial == True, nngp.shape == (batch_size_1, batch_size_2, height, width, depth, …)), or the full covariance (diagonal_spatial == False, nngp.shape == (batch_size_1, batch_size_2, height, height, width, width, depth, depth, …)). Defaults to False, but is set to True if the output top-layer covariance depends only on the diagonals (e.g. when a CNN network has no pooling layers and Flatten on top).

shape1

a tuple specifying the shape of the random variable in the first batch of inputs. These have covariance cov1 and covariance with the second batch of inputs given by nngp.

shape2

a tuple specifying the shape of the random variable in the second batch of inputs. These have covariance cov2 and covariance with the first batch of inputs given by nngp.

batch_axis

the batch axis of the activations.

channel_axis

channel axis of the activations (taken to infinity).

mask1

an optional boolean jnp.ndarray with a shape broadcastable to shape1 (and the same number of dimensions). True stands for the input being masked at that position, while False means the input is visible. For example, if shape1 == (5, 32, 32, 3) (a batch of 5 NHWC CIFAR10 images), a mask1 of shape (5, 1, 32, 1) means different images can have different blocked columns (H and C dimensions are always either both blocked or unblocked). None means no masking.

mask2

same as mask1, but for the second batch of inputs.

asdict(*, dict_factory=<class 'dict'>)

Instance method alternative to dataclasses.asdict.

astuple(*, tuple_factory=<class 'tuple'>)

Instance method alternative to dataclasses.astuple.

dot_general(other1, other2, is_lhs, dimension_numbers)[source]

Covariances of jax.lax.dot_general of x1/2 with other1/2.

Return type:

Kernel

mask(mask1, mask2)[source]

Mask all covariance matrices according to mask1, mask2.

Return type:

Kernel

replace(**changes)

Instance method alternative to dataclasses.replace.

reverse()[source]

Reverse the order of spatial axes in the covariance matrices.

Return type:

Kernel

Returns:

A Kernel object with spatial axes order flipped in all covariance matrices. For example, if kernel.nngp has shape (batch_size_1, batch_size_2, H, H, W, W, D, D, …), then reverse(kernels).nngp has shape (batch_size_1, batch_size_2, …, D, D, W, W, H, H).

transpose(axes=None)[source]

Permute spatial dimensions of the Kernel according to axes.

Follows https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html

Note that axes apply only to spatial axes, batch axes are ignored and remain leading in all covariance arrays, and channel axes are not present in a Kernel object. If the covariance array is of shape (batch_size, X, X, Y, Y), and axes == (0, 1), resulting array is of shape (batch_size, Y, Y, X, X).

Return type:

Kernel
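
A minimal sketch of obtaining a Kernel from a stax kernel_fn and reading its fields (calling kernel_fn without a get argument is assumed to return the full Kernel dataclass with nngp and ntk populated, per AnalyticKernelFn below):

>>> from jax import random
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1))
>>> x1 = random.normal(key1, (4, 8))
>>> x2 = random.normal(key2, (6, 8))
>>> #
>>> _, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
>>> #
>>> kernel = kernel_fn(x1, x2)  # a `Kernel` dataclass
>>> kernel.nngp.shape           # (4, 6)
>>> kernel.ntk.shape            # (4, 6)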

Typing

Common Type Definitions.

class neural_tangents._src.utils.typing.AnalyticKernelFn(*args, **kwargs)[source]

A type alias for analytic kernel functions.

A kernel function that computes an analytic kernel. Takes either a Kernel or jax.numpy.ndarray inputs and a get argument that specifies what quantities should be computed by the kernel. Returns either a Kernel object or jax.numpy.ndarray-s for kernels specified by get.

class neural_tangents._src.utils.typing.ApplyFn(*args, **kwargs)[source]

A type alias for apply functions.

Apply functions do computations with finite-width neural networks. They are functions that take a PyTree of parameters and an array of inputs and produce an array of outputs.

neural_tangents._src.utils.typing.Axes

Axes specification, can be integers (axis=-1) or sequences (axis=(1, 3)).

alias of Union[int, Sequence[int]]

class neural_tangents._src.utils.typing.EmpiricalGetKernelFn(*args, **kwargs)[source]

A type alias for empirical kernel functions accepting a get argument.

A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.

Equivalent to EmpiricalKernelFn, but accepts a get argument, which can be for example get=("nngp", "ntk"), to compute both kernels together.

class neural_tangents._src.utils.typing.EmpiricalKernelFn(*args, **kwargs)[source]

A type alias for empirical kernel functions computing either NTK or NNGP.

A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.

Equivalent to EmpiricalGetKernelFn with get="nngp" or get="ntk".

class neural_tangents._src.utils.typing.InitFn(*args, **kwargs)[source]

A type alias for initialization functions.

Initialization functions construct parameters for neural networks given a random key and an input shape. Specifically, they produce a tuple giving the output shape and a PyTree of parameters.

neural_tangents._src.utils.typing.Kernels

Kernel inputs/outputs of FanOut, FanInSum, etc.

alias of Union[list[Kernel], tuple[Kernel, …]]

class neural_tangents._src.utils.typing.LayerKernelFn(*args, **kwargs)[source]

A type alias for pure kernel functions.

A pure kernel function takes a PyTree of Kernel object(s) and produces a PyTree of Kernel object(s). These functions are used to define new layer types.

class neural_tangents._src.utils.typing.MaskFn(*args, **kwargs)[source]

A type alias for masking functions.

Forward-propagate a mask in a layer of a finite-width network.

class neural_tangents._src.utils.typing.MonteCarloKernelFn(*args, **kwargs)[source]

A type alias for Monte Carlo kernel functions.

A kernel function that produces an estimate of an AnalyticKernel by Monte Carlo sampling given a PRNGKey.

neural_tangents._src.utils.typing.NTTree

Neural Tangents Tree.

Trees of kernels and arrays naturally emerge in certain neural network computations (for example, when neural networks have nested parallel layers).

Mimicking JAX, we use a lightweight tree structure called an NTTree. NTTree has internal nodes that are either lists or tuples and leaves which are either jax.numpy.ndarray or Kernel objects.

alias of Union[list[T], tuple[T, …], T]

neural_tangents._src.utils.typing.NTTrees

A list or tuple of NTTrees.

alias of Union[list[T], tuple[T, …]]

neural_tangents._src.utils.typing.PyTree = typing.Any

A PyTree, see JAX docs for details.

neural_tangents._src.utils.typing.Shapes

A shape - a tuple of integers, or an NTTree of such tuples.

alias of Union[list[tuple[int, …]], tuple[tuple[int, …], …], tuple[int, …]]

neural_tangents._src.utils.typing.VMapAxes

Specifies (input, output, kwargs) axes for vmap in empirical NTK.

alias of Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]
