Neural Tangents Reference

Neural Tangents (neural_tangents – nt) is a set of tools for constructing and training infinitely wide neural networks (a.k.a. NTK, NNGP).
nt.stax – infinite NNGP and NTK

Closed-form NNGP and NTK library.

This library contains layers mimicking those in jax.example_libraries.stax, with a similar API apart from:

1) Instead of an (init_fn, apply_fn) tuple, layers return a triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a Kernel to a new Kernel, and represents the change in the analytic NTK and NNGP kernels (nngp, ntk). These functions are chained / stacked together within the serial or parallel combinators, similarly to init_fn and apply_fn. For details, please see “Neural Tangents: Fast and Easy Infinite Neural Networks in Python”.
2) In layers with random weights, the NTK parameterization is used by default (see page 3 in “Neural Tangent Kernel: Convergence and Generalization in Neural Networks”). The standard parameterization can be specified for Conv and Dense layers via the parameterization keyword argument (see the sketch after this list). For details, please see “On the infinite width limit of neural networks with a standard parameterization”.
3) Some functionality may be missing (e.g. jax.example_libraries.stax.BatchNorm), and some may be present only in our library (e.g. CIRCULAR padding, LayerNorm, GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes, etc.).
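For instance, a minimal sketch of selecting the standard parameterization (the widths 512 and 10 here are illustrative):

>>> from neural_tangents import stax
>>> #
>>> # Standard parameterization for Dense layers; the NTK parameterization
>>> # remains the default whenever the keyword is omitted.
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>> stax.Dense(512, parameterization='standard'),
>>> stax.Relu(),
>>> stax.Dense(10, parameterization='standard')
>>> )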
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>> stax.Conv(128, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(256, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(512, (3, 3)),
>>> stax.Flatten(),
>>> stax.Dense(10)
>>> )
>>> #
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>> y_train)
>>> #
>>> # (5, 10) jnp.ndarray NNGP test prediction
>>> y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>> #
>>> # (5, 10) jnp.ndarray NTK prediction
>>> y_test_ntk = predict_fn(x_test=x_test, get='ntk')
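The same kernel_fn can also be called directly to obtain the analytic kernels, without going through nt.predict (a minimal sketch continuing the example above; the get argument selects which kernels to compute):

>>> # (5, 20) jnp.ndarray analytic test-train NNGP and NTK.
>>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp')
>>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk')
>>> #
>>> # Both at once, as a namedtuple with `nngp` and `ntk` fields.
>>> k_test_train = kernel_fn(x_test, x_train, ('nngp', 'ntk'))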
Combinators

Layers to combine multiple other layers into one.

| parallel | Combinator for composing layers in parallel. |
| repeat | Compose a layer in serial a given number of times. |
| serial | Combinator for composing layers in serial. |
Branching

Layers to split outputs into many, or to combine many outputs into one.

| FanInConcat | Fan-in concatenation. |
| FanInProd | Fan-in product. |
| FanInSum | Fan-in sum. |
| FanOut | Fan-out. |
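For example, FanOut, parallel, and FanInSum can be combined to express a residual (skip) connection (a minimal sketch; the width 256 is illustrative and assumes the block's input also has 256 features so the two branches have matching shapes):

>>> from neural_tangents import stax
>>> #
>>> # A block computing x + Dense(Relu(x)): fan the input out into two
>>> # branches, transform one, pass the other through unchanged, then sum.
>>> ResBlock = stax.serial(
>>> stax.FanOut(2),
>>> stax.parallel(stax.serial(stax.Relu(), stax.Dense(256)),
>>> stax.Identity()),
>>> stax.FanInSum()
>>> )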
Linear parametric

Linear layers with trainable parameters.

| Conv | General convolution. |
| ConvLocal | General unshared convolution. |
| ConvTranspose | General transpose convolution. |
| Dense | Dense (fully-connected, matrix product). |
| GlobalSelfAttention | Global scaled dot-product self-attention. |
Linear nonparametric

Linear layers without any trainable parameters.

| Aggregate | Aggregation operator (graphical neural network). |
| AvgPool | Average pooling. |
| DotGeneral | Constant (non-trainable) rhs/lhs Dot General. |
| Dropout | Dropout. |
| Flatten | Flattening all non-batch dimensions. |
| GlobalAvgPool | Global average pooling. |
| GlobalSumPool | Global sum pooling. |
| Identity | Identity (no-op). |
| ImageResize | Image resize function mimicking jax.image.resize. |
| Index | Index into the array, mimicking NumPy indexing. |
| LayerNorm | Layer normalisation. |
| SumPool | Sum pooling. |
Elementwise nonlinear

Pointwise nonlinear layers. For details, please see “Fast Neural Kernel Embeddings for General Activations”.

| ABRelu | ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0). |
| Abs | Absolute value nonlinearity. |
| Cos | Affine transform of the cosine nonlinearity. |
| Elementwise | Elementwise application of an arbitrary scalar function. |
| ElementwiseNumerical | Activation function using numerical integration. |
| Erf | Affine transform of the error function (erf) nonlinearity. |
| Exp | Elementwise natural exponent function. |
| ExpNormalized | Simulates the "Gaussian normalized kernel". |
| Gabor | Gabor function. |
| Gaussian | Elementwise Gaussian function. |
| Gelu | Gelu function. |
| Hermite | Hermite polynomials. |
| LeakyRelu | Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0). |
| Monomial | Monomials, i.e. x^degree. |
| Polynomial | Polynomials, i.e. sum_i coef[i] * x^i. |
| Rbf | Dual activation function for normalized RBF or squared exponential kernel. |
| RectifiedMonomial | Rectified monomials. |
| Relu | ReLU nonlinearity. |
| Sigmoid_like | A sigmoid-like function. |
| Sign | Sign function. |
| Sin | Affine transform of the sine nonlinearity. |
Helper classes

Utility classes for specifying layer properties. For enums, strings can be passed in their place.

| AggregateImplementation | Implementation of the Aggregate layer. |
| AttentionMechanism | Type of nonlinearity to use in a GlobalSelfAttention layer. |
| Padding | Type of padding in pooling and convolutional layers. |
| PositionalEmbedding | Type of positional embeddings to use in a GlobalSelfAttention layer. |
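For example, a Padding value can be passed either as its string name or as the enum member (a minimal sketch; the number of channels and filter shape are illustrative):

>>> from neural_tangents import stax
>>> #
>>> # Passing the padding as a string...
>>> conv = stax.Conv(128, (3, 3), padding='CIRCULAR')
>>> # ...is equivalent to passing the enum member.
>>> conv = stax.Conv(128, (3, 3), padding=stax.Padding.CIRCULAR)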
For developers

Classes and decorators helpful for constructing your own layers.

| Bool | Helper trinary logic class. |
| Diagonal | Helps decide whether to allow the kernel to contain diagonal entries only. |
| layer | A convenience decorator to be added to all public layers. |
| requires | Returns a decorator that augments kernel_fn with consistency checks. |
| supports_masking | Returns a decorator that turns layers into layers supporting masking. |
nt.empirical – finite NNGP and NTK

Compute empirical NNGP and NTK; approximate functions via Taylor series.

All functions in this module are applicable to any JAX functions with proper signatures (not only those from stax).

NNGP and NTK are computed using empirical_nngp_fn, empirical_ntk_fn, or empirical_kernel_fn (for both). The kernels have a very specific output shape convention that may be unexpected. Further, the NTK has multiple implementations that may perform differently depending on the task. Please read the individual functions’ docstrings.

For details, please see “Fast Finite Width Neural Tangent Kernel”.
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> # A narrow CNN.
>>> init_fn, f, _ = stax.serial(
>>> stax.Conv(32, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(32, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(32, (3, 3)),
>>> stax.Flatten(),
>>> stax.Dense(10)
>>> )
>>> #
>>> _, params = init_fn(key3, x_train.shape)
>>> #
>>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
>>> # network is iid along the batch axis, no BatchNorm. Use default
>>> # `implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION` (`1`).
>>> kernel_fn = nt.empirical_kernel_fn(
>>> f, trace_axes=(-1,), vmap_axes=0,
>>> implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
>>> #
>>> # (5, 20) jnp.ndarray test-train NNGP/NTK
>>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>> #
>>> # Full kernel: not reducing over logits. Use structured derivatives
>>> # `implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) for
>>> # typically faster computation and lower memory cost.
>>> kernel_fn = nt.empirical_kernel_fn(
>>> f, trace_axes=(), vmap_axes=0,
>>> implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
>>> #
>>> # (5, 20, 10, 10) jnp.ndarray test-train NNGP/NTK namedtuple.
>>> k_test_train = kernel_fn(x_test, x_train, None, params)
>>> #
>>> # A wide FCN with lots of parameters and many (`100`) outputs.
>>> init_fn, f, _ = stax.serial(
>>> stax.Flatten(),
>>> stax.Dense(1024),
>>> stax.Relu(),
>>> stax.Dense(1024),
>>> stax.Relu(),
>>> stax.Dense(100)
>>> )
>>> #
>>> _, params = init_fn(key3, x_train.shape)
>>> #
>>> # Use ntk-vector products
>>> # (`implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS`) since the
>>> # network has many parameters relative to the cost of forward pass,
>>> # large outputs.
>>> ntk_fn = nt.empirical_ntk_fn(
>>> f, vmap_axes=0,
>>> implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
>>> #
>>> # (5, 5) jnp.ndarray test-test NTK
>>> ntk_test_test = ntk_fn(x_test, None, params)
>>> #
>>> # Compute only output variances:
>>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>> #
>>> # (20,) jnp.ndarray train-train diagonal NNGP
>>> nngp_train_train_diag = nngp_fn(x_train, None, params)
Kernel functions
Finite-width NNGP and/or NTK kernel functions.
| empirical_kernel_fn | Returns a function that computes single draws from the NNGP and NTK kernels. |
| empirical_nngp_fn | Returns a function to draw a single sample of the NNGP of a given network f. |
| empirical_ntk_fn | Returns a function to draw a single sample of the NTK of a given network f. |
NTK implementation

An enum.IntEnum specifying the NTK implementation method.
- class neural_tangents.NtkImplementation(value)[source]
Implementation method of the underlying finite width NTK computation.
Below is a very brief summary of each method. For details, please see “Fast Finite Width Neural Tangent Kernel”.
- AUTO (or 0) evaluates the FLOPs of all other methods at compilation time and selects the fastest method. However, at this time it only works correctly on TPUs, and on CPU/GPU it can return wrong results, which is why it is not the default. TODO(romann): revisit based on http://b/202218145.
- JACOBIAN_CONTRACTION (or 1) computes the NTK as the outer product of two Jacobians, each computed using reverse-mode Autodiff (vector-Jacobian products, VJPs). When JITted, the contraction is performed in a layerwise fashion, so that entire Jacobians aren’t necessarily instantiated in memory at once, and the memory usage of the method can be lower than the memory needed to instantiate the two Jacobians. This method is best suited for networks with small outputs (such as scalar outputs for binary classification or regression, as opposed to 1000 ImageNet classes) and an expensive forward pass relative to the number of parameters (such as CNNs, where the forward pass reuses a small filter bank many times). It is also the most reliable method, since its implementation is the simplest and reverse-mode Autodiff is the most commonly used and best tested elsewhere. For this reason it is set as the default.
- NTK_VECTOR_PRODUCTS (or 2) computes the NTK as a sequence of NTK-vector products, similarly to how a Jacobian is computed as a sequence of Jacobian-vector products (JVPs) or vector-Jacobian products (VJPs). This amounts to using both forward-mode (JVPs) and reverse-mode (VJPs) Autodiff, and allows eliminating the Jacobian contraction at the expense of additional forward passes. Therefore this method is recommended for networks with a cheap forward pass relative to the number of parameters (e.g. fully-connected networks, where each parameter matrix is used only once in the forward pass) and networks with large outputs (e.g. 1000 ImageNet classes). Memory requirements of this method are the same as those of JACOBIAN_CONTRACTION (1). Due to its reliance on forward-mode Autodiff, this method is slightly more prone to JAX and XLA bugs than JACOBIAN_CONTRACTION (1), but overall it is quite simple and reliable.
- STRUCTURED_DERIVATIVES (or 3) uses a custom JAX interpreter to compute the NTK more efficiently than the other methods. It traverses the computational graph of a function in the same order as reverse-mode Autodiff, but instead of computing VJPs, it directly computes MJJMPs, “matrix-Jacobian-Jacobian-matrix” products, which arise in the computation of an NTK. Each MJJMP computation relies on the structure in the Jacobians, hence the name. This method can be dramatically faster (up to several orders of magnitude) than other methods on fully-connected networks, and is usually faster than or on par with them on CNNs, Transformers, and other architectures, but the exact speedup (e.g. from no speedup to 10X) depends on the specific setting. It can also use less memory than other methods. In our experience it consistently outperforms other methods in most settings. However, its implementation is significantly more complex (hence bug-prone), and it doesn’t yet support functions using more exotic JAX primitives (e.g. jax.checkpoint, parallel collectives such as jax.lax.psum, compiled loops like jax.lax.scan, etc.), which is why it is highly recommended to try, but not yet set as the default.
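Either the enum member or its integer value can be passed as the implementation argument (a minimal sketch, reusing f from the nt.empirical example above):

>>> import neural_tangents as nt
>>> #
>>> # These two definitions are equivalent: the enum member and its integer
>>> # value select the same structured-derivatives implementation.
>>> ntk_fn = nt.empirical_ntk_fn(
>>> f, vmap_axes=0,
>>> implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
>>> ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=3)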
NTK-vector products
A function to compute NTK-vector products without instantiating the NTK.
| empirical_ntk_vp_fn | Returns an NTK-vector product function. |
Linearization and Taylor expansion
Decorators to Taylor-expand around function parameters.
| linearize | Returns a function f_lin, the first-order Taylor approximation to the function f. |
| taylor_expand | Returns a function f_tayl, the Taylor approximation to the function f of a given degree. |
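A minimal sketch of how these decorators are typically used (reusing f, params, and x_test from the nt.empirical example above; f_lin and f_tayl have the same signature as f):

>>> import neural_tangents as nt
>>> #
>>> # First-order Taylor (i.e. linear) expansion of `f` around `params`.
>>> f_lin = nt.linearize(f, params)
>>> out_lin = f_lin(params, x_test)
>>> #
>>> # Second-order Taylor expansion around the same parameters.
>>> f_quad = nt.taylor_expand(f, params, 2)
>>> out_quad = f_quad(params, x_test)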
nt.predict – inference w/ NNGP & NTK

Functions to make predictions on the train/test set using NTK/NNGP.

Most functions in this module accept training data as inputs and return a new function predict_fn that computes predictions on the train set / given test set / timesteps.

Warning

The trace_axes parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, this is the same trace_axes used to compute the empirical kernel (utils/empirical.py; diagonal_axes must be ()), or channel_axis in the output of the top layer used to compute the closed-form kernel (stax.py; note that closed-form kernels currently only support a single channel_axis).
Prediction / inference functions

Functions to make train/test set predictions given NNGP/NTK kernels or the linearized function.

| gp_inference | Compute the mean and variance of the 'posterior' of NNGP/NTK/NTKGP. |
| gradient_descent | Predicts the outcome of function space training using gradient descent. |
| gradient_descent_mse | Predicts the outcome of function space gradient descent training on MSE. |
| gradient_descent_mse_ensemble | Predicts the gaussian embedding induced by gradient descent on MSE loss. |

Utilities

| max_learning_rate | Computes the maximal feasible learning rate for infinite width NNs. |

Helper classes

Dataclasses and namedtuples used to return predictions.

| Gaussian | A (mean, covariance) namedtuple. |
| ODEState | ODE state dataclass holding outputs and auxiliary variables. |
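For example, gradient_descent_mse_ensemble returns such a Gaussian when asked for the predictive covariance (a minimal sketch, reusing kernel_fn, x_train, y_train, and x_test from the nt.stax example above):

>>> import neural_tangents as nt
>>> #
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>> y_train)
>>> #
>>> # A `Gaussian(mean, covariance)` namedtuple describing the NNGP posterior
>>> # on the test set at infinite training time (`t=None`).
>>> nngp_dist = predict_fn(x_test=x_test, get='nngp', compute_cov=True)
>>> mean, cov = nngp_dist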
nt.batch – using multiple devices

Batch kernel computations serially or in parallel.

This module contains a decorator batch that can be applied to any kernel_fn of signature kernel_fn(x1, x2, *args, **kwargs). The decorated function performs the same computation by batching over x1 and x2 and concatenating the result, allowing you both to use multiple accelerators and to stay within memory limits.

Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.
Example
>>> from jax import numpy as jnp
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> # Define some kernel function.
>>> _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1))
>>> #
>>> # Compute the kernel in batches, in parallel.
>>> kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=5)
>>> #
>>> # Generate dummy input data.
>>> x1, x2 = jnp.ones((40, 10)), jnp.ones((80, 10))
>>> kernel_fn_batched(x1, x2) == kernel_fn(x1, x2) # True!
- neural_tangents.batch(kernel_fn, batch_size=0, device_count=-1, store_on_device=True)[source]
Returns a function that computes a kernel in batches over all devices.
Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.

- Parameters:
  - kernel_fn (TypeVar(_KernelFn, bound=Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])) – A function that computes a kernel on two batches, kernel_fn(x1, x2, *args, **kwargs). Here x1 and x2 are jnp.ndarrays of shapes (n1,) + input_shape and (n2,) + input_shape. The kernel function should return a PyTree.
  - batch_size (int) – specifies the size of each batch that gets processed per physical device. Because we parallelize the computation over columns, it should be the case that x1.shape[0] is divisible by device_count * batch_size and x2.shape[0] is divisible by batch_size.
  - device_count (int) – specifies the number of physical devices to be used. If device_count == -1, all devices are used. If device_count == 0, no device parallelism is used (a single default device is used).
  - store_on_device (bool) – specifies whether the output should be kept on device or brought back to CPU RAM as it is computed. Defaults to True. Set to False to store and concatenate results using CPU RAM, allowing larger kernels to be computed.
- Return type:
  TypeVar(_KernelFn, bound=Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])
- Returns:
  A new function with the same signature as kernel_fn that computes the kernel by batching over the dataset in parallel with the specified batch_size using device_count devices.
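Because the decorated function keeps the signature kernel_fn(x1, x2, *args, **kwargs), extra arguments such as the get string and params of an empirical kernel function are forwarded unchanged (a minimal sketch, reusing f, params, x_train, and x_test from the nt.empirical example above):

>>> import neural_tangents as nt
>>> #
>>> # Batch an empirical kernel function: 5 rows of `x1` per batch, computed
>>> # sequentially on a single device (`device_count=0`).
>>> batched_kernel_fn = nt.batch(nt.empirical_kernel_fn(f, vmap_axes=0),
>>> batch_size=5, device_count=0)
>>> #
>>> # Extra arguments (`get`, `params`) are forwarded to the wrapped function.
>>> ntk_test_train = batched_kernel_fn(x_test, x_train, 'ntk', params)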
nt.monte_carlo_kernel_fn – MC Sampling

Function to compute Monte Carlo NNGP and NTK estimates.

This module contains a function monte_carlo_kernel_fn that allows computing Monte Carlo estimates of the NNGP and NTK kernels of arbitrary functions. For more details on how individual samples are computed, refer to utils/empirical.py.

Note that monte_carlo_kernel_fn accepts arguments like batch_size, device_count, and store_on_device, and is appropriately batched / parallelized. You don’t need to apply the batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn function, as the resulting empirical kernel function is JITted internally.
- neural_tangents.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples, batch_size=0, device_count=-1, store_on_device=True, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=NtkImplementation.JACOBIAN_CONTRACTION, _j_rules=True, _s_rules=True, _fwd=None)[source]
Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.
Note that the returned function is appropriately batched / parallelized. You don’t need to apply the nt.batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn function, as the resulting empirical kernel function is JITted internally.

- Parameters:
  - init_fn (InitFn) – a function initializing parameters of the neural network. From jax.example_libraries.stax: “takes an rng key and an input shape and returns an (output_shape, params) pair”.
  - apply_fn (ApplyFn) – a function computing the output of the neural network. From jax.example_libraries.stax: “takes params, inputs, and an rng key and applies the layer”.
  - key (Array) – RNG (jax.random.PRNGKey) for sampling random networks. Must have shape (2,).
  - n_samples (Union[int, Iterable[int]]) – number of Monte Carlo samples. Can be either an integer or an iterable of integers at which the resulting generator will yield estimates. Example: use n_samples=[2**k for k in range(10)] for the generator to yield estimates using 1, 2, 4, …, 512 Monte Carlo samples.
  - batch_size (int) – an integer specifying the size of the batches of x1 and x2 in which the kernel is computed. 0 means computing the whole kernel. Must divide x1.shape[0] and x2.shape[0].
  - device_count (int) – an integer specifying the number of devices (e.g. GPUs or TPU cores) across which the kernel is computed in parallel. -1 means use all available devices. 0 means compute on a single device sequentially. If not 0, must divide x1.shape[0].
  - store_on_device (bool) – a boolean indicating whether to store the resulting kernel on the device (e.g. GPU or TPU), or in CPU RAM, where larger kernels may fit.
  - trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows saving space and compute if you are only interested in the respective trace, but also improves approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d., and the respective covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
  - diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows saving space and compute if off-diagonal values along these axes are not needed, but also improves approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to “batch dimensions” in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
  - vmap_axes (Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]) – applicable only to NTK. A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) equals a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes, i.e. that f can be evaluated as a vmap. This allows evaluating Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None to avoid wrong (and potentially silent) results.
  - implementation (Union[int, NtkImplementation]) – applicable only to NTK, an NtkImplementation value (or an int 0, 1, 2, or 3). See the NtkImplementation docstring for details.
  - _j_rules (bool) – internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow custom Jacobian rules for intermediary primitive dy/dw computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to False to use JVPs or VJPs, via JAX’s jax.jacfwd or jax.jacrev. Custom Jacobian rules (True) are expected to be not worse, and sometimes better, than the automated alternatives, but in case of a suboptimal implementation setting it to False could improve performance.
  - _s_rules (bool) – internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow efficient MJJMP rules for structured dy/dw primitive Jacobians. In practice should be set to True, and setting it to False can lead to dramatic deterioration of performance.
  - _fwd (Optional[bool]) – internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow jax.jvp in intermediary primitive Jacobian dy/dw computations, False to always use jax.vjp. None to decide automatically based on input/output sizes. Applicable when _j_rules=False, or when a primitive does not have a Jacobian rule. Should be set to None for best performance.
- Return type:
  MonteCarloKernelFn
- Returns:
If n_samples is an integer, returns a function of signature kernel_fn(x1, x2, get) that returns an MC estimation of the kernel using n_samples. If n_samples is a collection of integers, kernel_fn(x1, x2, get) returns a generator that yields estimates using n samples for n in n_samples.
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> init_fn, apply_fn, _ = stax.serial(
>>> stax.Conv(128, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(256, (3, 3)),
>>> stax.Relu(),
>>> stax.Conv(512, (3, 3)),
>>> stax.Flatten(),
>>> stax.Dense(10)
>>> )
>>> #
>>> n_samples = 200
>>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
>>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
>>> # `kernel` is a tuple of NNGP and NTK MC estimate using `n_samples`.
>>> #
>>> n_samples = [1, 10, 100, 1000]
>>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
>>> n_samples)
>>> kernel_samples = kernel_fn_generator(x_train, x_test,
>>> get=('nngp', 'ntk'))
>>> for n, kernel in zip(n_samples, kernel_samples):
>>>     print(n, kernel)
>>>     # `kernel` is a tuple of NNGP and NTK MC estimate using `n` samples.
nt.experimental – prototypes
Warning
This module contains new highly-experimental prototypes. Please beware that they are not properly tested, not supported, and may suffer from sub-optimal performance. Use at your own risk!
Kernel functions
Finite-width NTK kernel function in Tensorflow. See the Python and Colab usage examples.
- neural_tangents.experimental.empirical_ntk_fn_tf(f, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=NtkImplementation.JACOBIAN_CONTRACTION, _j_rules=True, _s_rules=True, _fwd=None)[source]
Returns a function to draw a single sample of the NTK of a given network f.

This function follows the API of neural_tangents.empirical_ntk_fn, but is applicable to a Tensorflow tf.Module, tf.keras.Model, or tf.function, via a TF->JAX->TF roundtrip using tf2jax and jax2tf. The docstring below is adapted from neural_tangents.empirical_ntk_fn.

Warning

This function is experimental and risks returning wrong results or performing slowly. It is intended to demonstrate the usage of neural_tangents.empirical_ntk_fn in Tensorflow, but has not been extensively tested. Specifically, it appears to have very long compile times (but OK runtime), is prone to triggering XLA errors, and does not distinguish between trainable and non-trainable parameters of the model.

TODO(romann): support distinguishing between trainable and non-trainable variables.

TODO(romann): investigate slow compile times.
- Parameters:
  - f (Union[Module, PolymorphicFunction]) – the tf.Module or tf.function whose NTK we are computing. It must satisfy the following:
    - if a tf.function, it must have the signature f(params, x).
    - if a tf.Module, it must either be a tf.keras.Model or be callable.
    - the input signature (f.input_shape for a tf.Module or tf.keras.Model, or f.input_signature for a tf.function) must be known.
  - trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows saving space and compute if you are only interested in the respective trace, but also improves approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d., and the respective covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
  - diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows saving space and compute if off-diagonal values along these axes are not needed, but also improves approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to “batch dimensions” in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
  - vmap_axes (Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]) – a triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f.call(x, **kwargs) equals a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes. In other words, it certifies that f can be evaluated as a vmap with out_axes=out_axes over x (along in_axes) and those arguments in **kwargs that are present in kwargs_axes.keys() (along kwargs_axes.values()). This allows us to evaluate Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None to avoid wrong (and potentially silent) results.
  - implementation (Union[NtkImplementation, int]) – an NtkImplementation value (or an int 0, 1, 2, or 3). See the NtkImplementation docstring for details.
  - _j_rules (bool) – internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow custom Jacobian rules for intermediary primitive dy/dw computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to False to use JVPs or VJPs, via JAX’s jax.jacfwd or jax.jacrev. Custom Jacobian rules (True) are expected to be not worse, and sometimes better, than the automated alternatives, but in case of a suboptimal implementation setting it to False could improve performance.
  - _s_rules (bool) – internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow efficient MJJMP rules for structured dy/dw primitive Jacobians. In practice should be set to True, and setting it to False can lead to dramatic deterioration of performance.
  - _fwd (Optional[bool]) – internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow jax.jvp in intermediary primitive Jacobian dy/dw computations, False to always use jax.vjp. None to decide automatically based on input/output sizes. Applicable when _j_rules=False, or when a primitive does not have a Jacobian rule. Should be set to None for best performance.
- Return type:
- Returns:
A function ntk_fn that computes the empirical ntk.
Helper functions

A helper function to convert stateful Tensorflow models into a functional-style, stateless apply_fn(params, x) forward-pass function and to extract the respective params.
- neural_tangents.experimental.get_apply_fn_and_params(f)[source]
Converts a tf.Module into a forward-pass apply_fn and params.

Use this function to extract params to pass to the Tensorflow empirical NTK kernel function.
Warning
This function does not distinguish between trainable and non-trainable parameters of the model.
- Parameters:
  - f (Module) – a tf.Module to convert into an apply_fn(params, x) function. Must have an input_shape attribute set (specifying the shape of x), and be callable or be a tf.keras.Model.
- Returns:
  A tuple of (apply_fn, params), where params is a PyTree[tf.Tensor].
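A minimal sketch of how the two experimental functions might be combined (the Keras architecture and input shape are illustrative, and the call of the returned ntk_fn mirrors nt.empirical_ntk_fn here as an assumption; see the linked usage examples for the authoritative version):

>>> import numpy as np
>>> import tensorflow as tf
>>> from neural_tangents import experimental
>>> #
>>> # A small Keras model with a known input shape.
>>> f = tf.keras.Sequential([
>>> tf.keras.Input(shape=(32,)),
>>> tf.keras.layers.Dense(64, activation='relu'),
>>> tf.keras.layers.Dense(10)
>>> ])
>>> #
>>> # Extract a stateless forward pass and its parameters.
>>> apply_fn, params = experimental.get_apply_fn_and_params(f)
>>> #
>>> # Empirical NTK of the TF model, mirroring the nt.empirical_ntk_fn API.
>>> ntk_fn = experimental.empirical_ntk_fn_tf(f, vmap_axes=0)
>>> x1 = np.random.randn(4, 32).astype(np.float32)
>>> x2 = np.random.randn(8, 32).astype(np.float32)
>>> ntk = ntk_fn(x1, x2, params)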
Kernel dataclass
- class neural_tangents.Kernel(nngp, ntk, cov1, cov2, x1_is_x2, is_gaussian, is_reversed, is_input, diagonal_batch, diagonal_spatial, shape1, shape2, batch_axis, channel_axis, mask1=None, mask2=None)[source]
Dataclass containing information about the NTK and NNGP of a model.
- nngp
covariance between the first and second batches (NNGP). A jnp.ndarray of shape (batch_size_1, batch_size_2, height, [height,], width, [width,], …), where the exact shape depends on diagonal_spatial.
- ntk
the neural tangent kernel (NTK). jnp.ndarray of same shape as nngp.
- cov1
covariance of the first batch of inputs. A jnp.ndarray with shape (batch_size_1, [batch_size_1,] height, [height,], width, [width,], …) where exact shape depends on diagonal_batch and diagonal_spatial.
- cov2
optional covariance of the second batch of inputs. A jnp.ndarray with shape (batch_size_2, [batch_size_2,] height, [height,], width, [width,], …) where the exact shape depends on diagonal_batch and diagonal_spatial.
- x1_is_x2
a boolean specifying whether x1 and x2 are the same.
- is_gaussian
a boolean, specifying whether the output features or channels of the layer / NN function (returning this Kernel as the kernel_fn) are i.i.d. Gaussian with covariance nngp, conditioned on fixed inputs to the layer and i.i.d. Gaussian weights and biases of the layer. For example, passing an input through a CNN layer with i.i.d. Gaussian weights and biases produces i.i.d. Gaussian random variables along the channel dimension, while passing an input through a nonlinearity does not.
- is_reversed
a boolean specifying whether the covariance matrices nngp, cov1, cov2, and ntk have the ordering of spatial dimensions reversed. Ignored unless diagonal_spatial is False. Used internally to avoid self-cancelling transpositions in a sequence of CNN layers that flip the order of kernel spatial dimensions.
- is_input
a boolean specifying whether the current layer is the input layer, and it is used to avoid applying dropout to the input layer.
- diagonal_batch
a boolean specifying whether cov1 and cov2 store only the diagonal of the sample-sample covariance (diagonal_batch == True, cov1.shape == (batch_size_1, …)), or the full covariance (diagonal_batch == False, cov1.shape == (batch_size_1, batch_size_1, …)). Defaults to True as no current layers require the full covariance.
- diagonal_spatial
a boolean specifying whether all (cov1, ntk, etc.) covariance matrices store only the diagonals of the location-location covariances (diagonal_spatial == True, nngp.shape == (batch_size_1, batch_size_2, height, width, depth, …)), or the full covariance (diagonal_spatial == False, nngp.shape == (batch_size_1, batch_size_2, height, height, width, width, depth, depth, …)). Defaults to False, but is set to True if the output top-layer covariance depends only on the diagonals (e.g. when a CNN network has no pooling layers and Flatten on top).
- shape1
a tuple specifying the shape of the random variable in the first batch of inputs. These have covariance cov1 and covariance with the second batch of inputs given by nngp.
- shape2
a tuple specifying the shape of the random variable in the second batch of inputs. These have covariance cov2 and covariance with the first batch of inputs given by nngp.
- batch_axis
the batch axis of the activations.
- channel_axis
channel axis of the activations (taken to infinity).
- mask1
an optional boolean jnp.ndarray with a shape broadcastable to shape1 (and the same number of dimensions). True stands for the input being masked at that position, while False means the input is visible. For example, if shape1 == (5, 32, 32, 3) (a batch of 5 NHWC CIFAR10 images), a mask1 of shape (5, 1, 32, 1) means different images can have different blocked columns (H and C dimensions are always either both blocked or unblocked). None means no masking.
- mask2
same as mask1, but for the second batch of inputs.
- asdict(*, dict_factory=<class 'dict'>)
Instance method alternative to dataclasses.asdict.
- astuple(*, tuple_factory=<class 'tuple'>)
Instance method alternative to dataclasses.astuple.
- dot_general(other1, other2, is_lhs, dimension_numbers)[source]
Covariances of jax.lax.dot_general of x1/2 with other1/2.
- Return type:
  Kernel
- replace(**changes)
Instance method alternative to dataclasses.replace.
- reverse()[source]
Reverse the order of spatial axes in the covariance matrices.
- Return type:
  Kernel
- Returns:
A Kernel object with spatial axes order flipped in all covariance matrices. For example, if kernel.nngp has shape (batch_size_1, batch_size_2, H, H, W, W, D, D, …), then reverse(kernels).nngp has shape (batch_size_1, batch_size_2, …, D, D, W, W, H, H).
- transpose(axes=None)[source]
Permute spatial dimensions of the Kernel according to axes.
Follows https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html
Note that axes apply only to spatial axes: batch axes are ignored and remain leading in all covariance arrays, and channel axes are not present in a Kernel object. If the covariance array is of shape (batch_size, X, X, Y, Y), and axes == (0, 1), the resulting array is of shape (batch_size, Y, Y, X, X).
- Return type:
  Kernel
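A minimal sketch of how a Kernel object typically arises and is inspected (reusing kernel_fn, x_train, and x_test from the nt.stax example; calling an analytic kernel_fn without a get argument returns the full Kernel):

>>> # The full Kernel dataclass for the test-train pair of batches.
>>> kernel = kernel_fn(x_test, x_train)
>>> #
>>> # NNGP and NTK covariances between the two batches.
>>> nngp, ntk = kernel.nngp, kernel.ntk
>>> #
>>> # Dataclass helpers, e.g. producing a copy with a modified field.
>>> kernel_no_ntk = kernel.replace(ntk=None)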
Typing
Common Type Definitions.
- class neural_tangents._src.utils.typing.AnalyticKernelFn(*args, **kwargs)[source]
A type alias for analytic kernel functions.
A kernel function that computes an analytic kernel. Takes either a Kernel or jax.numpy.ndarray inputs and a get argument that specifies what quantities should be computed by the kernel. Returns either a Kernel object or jax.numpy.ndarray-s for the kernels specified by get.
- class neural_tangents._src.utils.typing.ApplyFn(*args, **kwargs)[source]
A type alias for apply functions.
Apply functions do computations with finite-width neural networks. They are functions that take a PyTree of parameters and an array of inputs and produce an array of outputs.
- neural_tangents._src.utils.typing.Axes
Axes specification; can be integers (axis=-1) or sequences (axis=(1, 3)).
- class neural_tangents._src.utils.typing.EmpiricalGetKernelFn(*args, **kwargs)[source]
A type alias for empirical kernel functions accepting a get argument.

A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.

Equivalent to EmpiricalKernelFn, but accepts a get argument, which can be, for example, get=("nngp", "ntk"), to compute both kernels together.
- class neural_tangents._src.utils.typing.EmpiricalKernelFn(*args, **kwargs)[source]
A type alias for empirical kernel functions computing either NTK or NNGP.
A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.
Equivalent to EmpiricalGetKernelFn with get="nngp" or get="ntk".
- class neural_tangents._src.utils.typing.InitFn(*args, **kwargs)[source]
A type alias for initialization functions.
Initialization functions construct parameters for neural networks given a random key and an input shape. Specifically, they produce a tuple giving the output shape and a PyTree of parameters.
- neural_tangents._src.utils.typing.Kernels
Kernel inputs/outputs of FanOut, FanInSum, etc.
- class neural_tangents._src.utils.typing.LayerKernelFn(*args, **kwargs)[source]
A type alias for pure kernel functions.
A pure kernel function takes a PyTree of Kernel object(s) and produces a PyTree of Kernel object(s). These functions are used to define new layer types.
- class neural_tangents._src.utils.typing.MaskFn(*args, **kwargs)[source]
A type alias for masking functions.
Forward-propagate a mask in a layer of a finite-width network.
- class neural_tangents._src.utils.typing.MonteCarloKernelFn(*args, **kwargs)[source]
A type alias for Monte Carlo kernel functions.
A kernel function that produces an estimate of an AnalyticKernel by Monte Carlo sampling, given a PRNGKey.
- neural_tangents._src.utils.typing.NTTree
Neural Tangents Tree.
Trees of kernels and arrays naturally emerge in certain neural network computations (for example, when neural networks have nested parallel layers).
Mimicking JAX, we use a lightweight tree structure called an NTTree. NTTree has internal nodes that are either lists or tuples, and leaves that are either jax.numpy.ndarray or Kernel objects.