Neural Tangents Reference
Neural Tangents is a set of tools for constructing and training infinitely wide neural networks.
Stax – infinite networks (NNGP, NTK)
Closed-form NNGP and NTK library.
This library contains layer constructors mimicking those in jax.example_libraries.stax, with a similar API apart from:
1) Instead of an (init_fn, apply_fn) tuple, layer constructors return a triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a Kernel to a new Kernel, and represents the change in the analytic NTK and NNGP kernels (Kernel.nngp, Kernel.ntk). These functions are chained / stacked together within the serial or parallel combinators, similarly to init_fn and apply_fn.
2) In layers with random weights, the NTK parameterization is used by default (https://arxiv.org/abs/1806.07572, page 3). The standard parameterization (https://arxiv.org/abs/2001.07301) can be specified for Conv and Dense layers via the parameterization keyword argument (see the brief sketch after this list).
3) Some functionality may be missing (e.g. BatchNorm), and some may be present only in our library (e.g. CIRCULAR padding, LayerNorm, GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes, etc.).
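For example, a minimal sketch (not one of the library's own examples) of requesting the standard parameterization mentioned in point 2; the W_std / b_std values here are arbitrary:
>>> from neural_tangents import stax
>>>
>>> # Dense layers in the standard (rather than the default NTK)
>>> # parameterization.
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>>     stax.Dense(512, W_std=1.5, b_std=0.05, parameterization='standard'),
>>>     stax.Relu(),
>>>     stax.Dense(10, W_std=1.5, b_std=0.05, parameterization='standard')
>>> )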
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>> stax.Conv(128, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(256, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(512, (3, 3)),
>>> stax.Flatten(),
>>> stax.Dense(10)
>>> )
>>>
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>> y_train)
>>>
>>> # (5, 10) np.ndarray NNGP test prediction
>>> y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>>
>>> # (5, 10) np.ndarray NTK prediction
>>> y_test_ntk = predict_fn(x_test=x_test, get='ntk')
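The closed-form kernels themselves can also be evaluated directly; a brief follow-up sketch using the variables defined above:
>>> # (5, 20) test-train kernels; `get` may be a single string such as 'ntk',
>>> # or a tuple to compute both NNGP and NTK at once.
>>> k_test_train = kernel_fn(x_test, x_train, ('nngp', 'ntk'))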
Combinators
Layers to combine multiple other layers into one.
serial – Combinator for composing layers in serial.
parallel – Combinator for composing layers in parallel.
Branching
Layers to split outputs into many, or combine many into one.
FanOut – Layer construction function for a fan-out layer.
FanInConcat – Layer construction function for a fan-in concatenation layer.
FanInProd – Layer construction function for a fan-in product layer.
FanInSum – Layer construction function for a fan-in sum layer.
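For instance, a minimal sketch (not one of the library's own examples) of a residual-style block built from these combinators and branching layers:
>>> from neural_tangents import stax
>>>
>>> # Fan the input out into a residual branch and a skip branch, then sum.
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>>     stax.Dense(512),                  # project inputs to 512 features
>>>     stax.FanOut(2),                   # duplicate into two branches
>>>     stax.parallel(
>>>         stax.serial(stax.Relu(), stax.Dense(512)),  # residual branch
>>>         stax.Identity()                             # skip branch
>>>     ),
>>>     stax.FanInSum()                   # sum the two branches
>>> )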
Linear parametric
Linear layers with trainable parameters.
Dense – Layer constructor function for a dense (fully-connected) layer.
Conv – Layer construction function for a general convolution layer.
ConvLocal – Layer construction function for a general unshared convolution layer.
ConvTranspose – Layer construction function for a general transpose convolution layer.
GlobalSelfAttention – Layer construction function for (global) scaled dot-product self-attention.
Linear nonparametric
Linear layers without any trainable parameters.
Aggregate – Layer constructor for the aggregation operator (graph neural networks).
AvgPool – Layer construction function for an average pooling layer.
Identity – Layer construction function for an identity layer.
DotGeneral – Layer constructor for a constant (non-trainable) rhs/lhs Dot General.
Dropout – Dropout layer.
Flatten – Layer construction function for flattening all non-batch dimensions.
GlobalAvgPool – Layer construction function for a global average pooling layer.
GlobalSumPool – Layer construction function for a global sum pooling layer.
ImageResize – Image resize function mimicking jax.image.resize.
LayerNorm – Layer normalisation.
SumPool – Layer construction function for a 2D sum pooling layer.
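A small sketch (not one of the library's own examples) combining several of these nonparametric layers in a convolutional tower:
>>> from neural_tangents import stax
>>>
>>> # Conv tower with average pooling, layer normalisation and a global
>>> # average pool before the readout.
>>> _, _, kernel_fn = stax.serial(
>>>     stax.Conv(64, (3, 3), padding='SAME'),
>>>     stax.Relu(),
>>>     stax.AvgPool((2, 2), strides=(2, 2)),
>>>     stax.LayerNorm(),
>>>     stax.Conv(64, (3, 3), padding='SAME'),
>>>     stax.Relu(),
>>>     stax.GlobalAvgPool(),
>>>     stax.Dense(10)
>>> )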
Elementwise nonlinear
Pointwise nonlinear layers.
ABRelu – ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0).
Abs – Absolute value nonlinearity.
Cos – Affine transform of the cosine nonlinearity.
Elementwise – Elementwise application of fn using the provided nngp_fn.
ElementwiseNumerical – Activation function using numerical integration.
Erf – Affine transform of the error function nonlinearity.
Exp – Elementwise natural exponent function.
ExpNormalized – Simulates the "Gaussian normalized kernel".
Gaussian – Elementwise Gaussian function.
Gelu – Gelu function.
Hermite – Hermite polynomials.
LeakyRelu – Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0).
Rbf – Dual activation function for normalized RBF or squared exponential kernel.
Relu – ReLU nonlinearity.
Sigmoid_like – A sigmoid-like function.
Sign – Sign function.
Sin – Affine transform of the sine nonlinearity.
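As an illustration, a sketch of supplying a custom nonlinearity together with its dual kernel via Elementwise; the fn / nngp_fn keyword names are assumed here, and the closed-form expectation is worked out by hand rather than taken from the library:
>>> import jax.numpy as jnp
>>> from neural_tangents import stax
>>>
>>> # For f(x) = sin(x) and jointly Gaussian preactivations with covariance
>>> # cov12 and variances var1, var2:
>>> #   E[sin(u1) sin(u2)] = exp(-(var1 + var2) / 2) * sinh(cov12).
>>> def sin_nngp_fn(cov12, var1, var2):
>>>     return jnp.exp(-(var1 + var2) / 2) * jnp.sinh(cov12)
>>>
>>> # Assumed keyword arguments `fn` and `nngp_fn` of `stax.Elementwise`.
>>> SinCustom = stax.Elementwise(fn=jnp.sin, nngp_fn=sin_nngp_fn)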
Helper enums
Enums for specifying layer properties. Strings can be used in their place.
Padding – Type of padding in pooling and convolutional layers.
PositionalEmbedding – Type of positional embeddings to use in a GlobalSelfAttention layer.
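For example, a brief sketch of passing the padding as a plain string rather than the enum member:
>>> from neural_tangents import stax
>>>
>>> # 'CIRCULAR' padding requested via a string; equivalent to passing the
>>> # corresponding `Padding` enum member.
>>> _, _, kernel_fn = stax.serial(
>>>     stax.Conv(128, (3, 3), padding='CIRCULAR'),
>>>     stax.Relu(),
>>>     stax.GlobalAvgPool()
>>> )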
For developers
Classes and decorators helpful for constructing your own layers.
layer – A convenience decorator to be added to all public layers like Relu etc.
supports_masking – Returns a decorator that turns layers into layers supporting masking.
requires – Returns a decorator that augments kernel_fn with consistency checks.
Bool – Helper trinary logic class.
Diagonal – Helps decide whether to allow the kernel to contain diagonal entries only.
Empirical – finite NNGP and NTK
Compute empirical NNGP and NTK; approximate functions via Taylor series.
All functions in this module are applicable to any JAX functions of proper signatures (not only those from nt.stax).
NNGP and NTK are computed using nt.empirical_nngp_fn, nt.empirical_ntk_fn, or nt.empirical_kernel_fn (for both). The kernels have a very specific output shape convention that may be unexpected. Further, NTK has multiple implementations that may perform differently depending on the task. Please read the individual functions' docstrings.
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>> # A narrow CNN.
>>> init_fn, f, _ = stax.serial(
>>> stax.Conv(32, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(32, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(32, (3, 3)),
>>> stax.Flatten(),
>>> stax.Dense(10)
>>> )
>>>
>>> _, params = init_fn(key3, x_train.shape)
>>>
>>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
>>> # network is iid along the batch axis, no BatchNorm. Use default
>>> # `implementation=1` since the network has few trainable parameters.
>>> kernel_fn = nt.empirical_kernel_fn(
>>> f, trace_axes=(-1,), vmap_axes=0, implementation=1)
>>>
>>> # (5, 20) np.ndarray test-train NNGP/NTK
>>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>>
>>> # Full kernel: not reducing over logits.
>>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(), vmap_axes=0)
>>>
>>> # (5, 20, 10, 10) np.ndarray test-train NNGP/NTK namedtuple.
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> # A wide FCN with lots of parameters
>>> init_fn, f, _ = stax.serial(
>>> stax.Flatten(),
>>> stax.Dense(1024),
>>> stax.Relu(),
>>> stax.Dense(1024),
>>> stax.Relu(),
>>> stax.Dense(10)
>>> )
>>>
>>> _, params = init_fn(key3, x_train.shape)
>>>
>>> # Use implicit differentiation in NTK: `implementation=2` to reduce
>>> # memory cost, since the network has many trainable parameters.
>>> ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=2)
>>>
>>> # (5, 5) np.ndarray test-test NTK
>>> ntk_test_test = ntk_fn(x_test, None, params)
>>>
>>> # Compute only output variances:
>>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>>
>>> # (20,) np.ndarray train-train diagonal NNGP
>>> nngp_train_train_diag = nngp_fn(x_train, None, params)
Kernel functions
Finite-width NNGP and/or NTK kernel functions.
empirical_kernel_fn – Returns a function that computes single draws from NNGP and NT kernels.
empirical_nngp_fn – Returns a function to draw a single sample of the NNGP of a given network.
empirical_ntk_fn – Returns a function to draw a single sample of the NTK of a given network.
Linearization and Taylor expansion
Decorators to Taylor-expand around function parameters.
linearize – Returns a function f_lin, the first-order Taylor approximation of the function f.
taylor_expand – Returns a function f_tayl, the Taylor approximation of the function f up to a given order.
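A brief sketch of linearization, reusing f, params, and x_test from the empirical example above:
>>> # First-order Taylor expansion of `f` in its parameters around `params`.
>>> f_lin = nt.linearize(f, params)
>>>
>>> # At `params` the linearized and the original network agree.
>>> fx_test_lin = f_lin(params, x_test)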
Predict – inference with NNGP and NTK or linearized networks
Functions to make predictions on the train/test set using NTK/NNGP.
Most functions in this module accept training data as inputs and return a new function predict_fn that computes predictions on the train set / given test set / timesteps.
WARNING: the trace_axes parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, this is the same trace_axes used to compute the empirical kernel (utils/empirical.py; diagonal_axes must be ()), or channel_axis in the output of the top layer used to compute the closed-form kernel (stax.py; note that closed-form kernels currently only support a single channel_axis).
Prediction / inference functions
Functions to make train/test set predictions given NNGP/NTK kernels or the linearized function.
gradient_descent – Predicts the outcome of function space training using gradient descent.
gradient_descent_mse – Predicts the outcome of function space gradient descent training on MSE.
gradient_descent_mse_ensemble – Predicts the Gaussian embedding induced by gradient descent on MSE loss.
gp_inference – Compute the mean and variance of the 'posterior' of NNGP/NTK/NTKGP.
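For instance, a sketch of obtaining a posterior mean and covariance from the ensemble predictor, reusing kernel_fn, x_train, y_train, and x_test from the first stax example (the compute_cov keyword is an assumption; check the docstring):
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>>                                                       y_train)
>>>
>>> # Assumed `compute_cov=True`: returns a `Gaussian(mean, covariance)`
>>> # posterior over the (5, 10) test outputs instead of just the mean.
>>> nngp_posterior = predict_fn(x_test=x_test, get='nngp', compute_cov=True)
>>> mean, cov = nngp_posterior.mean, nngp_posterior.covariance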
Utilities
max_learning_rate – Computes the maximal feasible learning rate for infinite width NNs.
Helper classes
Dataclasses and namedtuples used to return predictions.
Gaussian – A (mean, covariance) namedtuple.
ODEState – ODE state dataclass holding outputs and auxiliary variables.
Batching – using multiple devices
Batch kernel computations serially or in parallel.
This module contains a decorator batch that can be applied to any kernel_fn of signature kernel_fn(x1, x2, *args, **kwargs). The decorated function performs the same computation by batching over x1 and x2 and concatenating the result, allowing you both to use multiple accelerators and to stay within memory limits.
Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.
Example
>>> from jax import numpy as np
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> # Define some kernel function.
>>> _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1))
>>>
>>> # Compute the kernel in batches, in parallel.
>>> kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=5)
>>>
>>> # Generate dummy input data.
>>> x1, x2 = np.ones((40, 10)), np.ones((80, 10))
>>> kernel_fn_batched(x1, x2) == kernel_fn(x1, x2) # True!
- neural_tangents.batch(kernel_fn, batch_size=0, device_count=-1, store_on_device=True)
Returns a function that computes a kernel in batches over all devices.
Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.
- Parameters
kernel_fn (TypeVar(_KernelFn, bound=Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])) – A function that computes a kernel on two batches, kernel_fn(x1, x2, *args, **kwargs). Here x1 and x2 are np.ndarray's of shapes (n1,) + input_shape and (n2,) + input_shape. The kernel function should return a PyTree.
batch_size (int) – specifies the size of each batch that gets processed per physical device. Because we parallelize the computation over columns, it should be the case that x1.shape[0] is divisible by device_count * batch_size and x2.shape[0] is divisible by batch_size.
device_count (int) – specifies the number of physical devices to be used. If device_count == -1, all devices are used. If device_count == 0, no device parallelism is used (a single default device is used).
store_on_device (bool) – specifies whether the output should be kept on device or brought back to CPU RAM as it is computed. Defaults to True. Set to False to store and concatenate results using CPU RAM, allowing you to compute larger kernels.
- Return type
TypeVar(_KernelFn, bound=Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])
- Returns
A new function with the same signature as kernel_fn that computes the kernel by batching over the dataset in parallel with the specified batch_size using device_count devices.
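A brief sketch of the fully serial setting (reusing kernel_fn, x1, and x2 from the example above), trading speed for memory:
>>> # Single-device, serial batches of 10 rows; results are accumulated in
>>> # CPU RAM so that kernels too large for device memory still fit.
>>> kernel_fn_serial = nt.batch(
>>>     kernel_fn, batch_size=10, device_count=0, store_on_device=False)
>>> k = kernel_fn_serial(x1, x2)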
Monte Carlo Sampling
Function to compute Monte Carlo NNGP and NTK estimates.
This module contains a function monte_carlo_kernel_fn that allows computing Monte Carlo estimates of the NNGP and NTK kernels of arbitrary functions. For more details on how individual samples are computed, refer to utils/empirical.py.
Note that monte_carlo_kernel_fn accepts arguments like batch_size, device_count, and store_on_device, and is appropriately batched / parallelized. You don't need to apply the nt.batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn function, as the resulting empirical kernel function is JITted internally.
- neural_tangents.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples, batch_size=0, device_count=-1, store_on_device=True, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=1)
Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.
Note that the returned function is appropriately batched / parallelized. You don't need to apply the nt.batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn function, as the resulting empirical kernel function is JITted internally.
- Parameters
init_fn (InitFn) – a function initializing parameters of the neural network. From jax.example_libraries.stax: "takes an rng key and an input shape and returns an (output_shape, params) pair".
apply_fn (ApplyFn) – a function computing the output of the neural network. From jax.example_libraries.stax: "takes params, inputs, and an rng key and applies the layer".
key (KeyArray) – RNG (jax.random.PRNGKey) for sampling random networks. Must have shape (2,).
n_samples (Union[int, Iterable[int]]) – number of Monte Carlo samples. Can be either an integer or an iterable of integers at which the resulting generator will yield estimates. Example: use n_samples=[2**k for k in range(10)] for the generator to yield estimates using 1, 2, 4, …, 512 Monte Carlo samples.
batch_size (int) – an integer making the kernel computed in batches of x1 and x2 of this size. 0 means computing the whole kernel. Must divide x1.shape[0] and x2.shape[0].
device_count (int) – an integer making the kernel be computed in parallel across this number of devices (e.g. GPUs or TPU cores). -1 means use all available devices. 0 means compute on a single device sequentially. If not 0, must divide x1.shape[0].
store_on_device (bool) – a boolean, indicating whether to store the resulting kernel on the device (e.g. GPU or TPU), or in the CPU RAM, where larger kernels may fit.
trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows you to save space and compute if you are only interested in the respective trace, but also to improve approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d. and the respective covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to "contracting dimensions" in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows you to save space and compute if off-diagonal values along these axes are not needed, but also to improve approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to "batch dimensions" in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
vmap_axes (None, int, a list/tuple of ints, or a triple of such values, with a dict for kwargs_axes) – applicable only to NTK. A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) equals a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes, i.e. f can be evaluated as a vmap. This allows Jacobians to be evaluated much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None, to avoid wrong (and potentially silent) results.
implementation (int) – applicable only to NTK. 1 or 2. 1 directly instantiates Jacobians and computes their outer product. 2 uses implicit differentiation to avoid instantiating whole Jacobians at once. The implicit kernel is derived by observing that \(\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)\), i.e. a linear function \([J(X_1) J(X_2)^T]\) applied to an identity matrix \(I\). This allows the computation of the NTK to be phrased as: \(a(v) = J(X_2)^T v\), which is computed by a vector-Jacobian product; \(b(v) = J(X_1) a(v)\), which is computed by a Jacobian-vector product; and \(\Theta = [d\,b(v) / d\,v^T](I)\), which is computed via a vmap of \(b(v)\) over columns of the identity matrix \(I\). It is best to benchmark each method on your specific task. We suggest using 1 unless you get OOMs due to a large number of trainable parameters, otherwise 2.
- Return type
MonteCarloKernelFn
- Returns
If n_samples is an integer, returns a function of signature kernel_fn(x1, x2, get) that returns an MC estimation of the kernel using n_samples. If n_samples is a collection of integers, kernel_fn(x1, x2, get) returns a generator that yields estimates using n samples for n in n_samples.
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>> init_fn, apply_fn, _ = stax.serial(
>>>     stax.Conv(128, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(256, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(512, (3, 3)),
>>>     stax.Flatten(),
>>>     stax.Dense(10)
>>> )
>>>
>>> n_samples = 200
>>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
>>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
>>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`.
>>>
>>> n_samples = [1, 10, 100, 1000]
>>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
>>>                                                n_samples)
>>> kernel_samples = kernel_fn_generator(x_train, x_test,
>>>                                      get=('nngp', 'ntk'))
>>> for n, kernel in zip(n_samples, kernel_samples):
>>>     print(n, kernel)
>>>     # `kernel` is a tuple of NNGP and NTK MC estimate using `n` samples.