nt.monte_carlo_kernel_fn - MC Sampling

Function to compute Monte Carlo NNGP and NTK estimates.

This module contains a function, monte_carlo_kernel_fn, that allows you to compute Monte Carlo estimates of the NNGP and NTK kernels of arbitrary functions. For more details on how individual samples are computed, refer to utils/empirical.py.
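As a rough illustration of what a Monte Carlo kernel estimate means here, the sketch below averages a per-draw empirical NNGP over independently sampled parameters. The per-draw NNGP is the outer product of network outputs, traced over the logit axis. This is a plain NumPy sketch, not the library's code. The one-hidden-layer ReLU network, its 1/sqrt(fan_in) weight scaling, and the helpers `init_params`, `apply`, `empirical_nngp` and `mc_nngp` are hypothetical stand-ins for `init_fn` / `apply_fn`. Normalizing the logit-axis trace by its size is also an assumption of this sketch.

```python
import numpy as np


def init_params(rng, in_dim, width, out_dim):
    # Hypothetical one-hidden-layer network; weights scaled by 1/sqrt(fan_in).
    w1 = rng.standard_normal((in_dim, width)) / np.sqrt(in_dim)
    w2 = rng.standard_normal((width, out_dim)) / np.sqrt(width)
    return w1, w2


def apply(params, x):
    # Forward pass: ReLU hidden layer followed by a linear readout.
    w1, w2 = params
    return np.maximum(x @ w1, 0.0) @ w2


def empirical_nngp(params, x1, x2):
    # Output covariance for ONE parameter draw, traced over the logit axis
    # and divided by its size (assumed normalization), shape (N1, N2).
    f1, f2 = apply(params, x1), apply(params, x2)
    return f1 @ f2.T / f1.shape[-1]


def mc_nngp(seed, x1, x2, n_samples, width=64, out_dim=10):
    # Monte Carlo estimate: mean of the empirical kernels over n_samples draws.
    rng = np.random.default_rng(seed)
    total = np.zeros((x1.shape[0], x2.shape[0]))
    for _ in range(n_samples):
        params = init_params(rng, x1.shape[1], width, out_dim)
        total += empirical_nngp(params, x1, x2)
    return total / n_samples


x = np.random.default_rng(0).standard_normal((4, 3))
k = mc_nngp(seed=1, x1=x, x2=x, n_samples=100)
print(k.shape)  # (4, 4)
```

Increasing `n_samples` reduces the variance of the estimate at a proportional cost in compute.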

Note that monte_carlo_kernel_fn accepts arguments such as batch_size, device_count, and store_on_device, and is appropriately batched / parallelized. You don’t need to apply the nt.batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn, since the resulting empirical kernel function is JITted internally.
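To make the batch_size requirement concrete, here is a hypothetical NumPy sketch that fills an (N1, N2) kernel block by block over batches of `x1` and `x2`. It is not the library's implementation, and `batched_kernel` and `linear_kernel` are illustrative names. Each block needs only its own slices of the inputs, which is why batching lowers peak memory. It is also why the batch size must evenly divide both leading dimensions.

```python
import numpy as np


def batched_kernel(kernel, x1, x2, batch_size):
    # batch_size == 0 means: compute the whole kernel in a single call.
    if batch_size == 0:
        return kernel(x1, x2)
    n1, n2 = x1.shape[0], x2.shape[0]
    if n1 % batch_size or n2 % batch_size:
        raise ValueError("batch_size must divide x1.shape[0] and x2.shape[0]")
    out = np.empty((n1, n2))
    # Fill the kernel one (batch_size x batch_size) block at a time.
    for i in range(0, n1, batch_size):
        for j in range(0, n2, batch_size):
            out[i:i + batch_size, j:j + batch_size] = kernel(
                x1[i:i + batch_size], x2[j:j + batch_size])
    return out


def linear_kernel(a, b):
    # Toy kernel used only for this illustration.
    return a @ b.T


rng = np.random.default_rng(0)
x1, x2 = rng.standard_normal((6, 3)), rng.standard_normal((4, 3))
full = batched_kernel(linear_kernel, x1, x2, batch_size=0)
blocked = batched_kernel(linear_kernel, x1, x2, batch_size=2)
print(np.allclose(full, blocked))  # True
```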

neural_tangents.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples, batch_size=0, device_count=-1, store_on_device=True, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=NtkImplementation.JACOBIAN_CONTRACTION, _j_rules=True, _s_rules=True, _fwd=None)

Return a Monte Carlo sampler of NTK and NNGP kernels of a given function.

Note that the returned function is appropriately batched / parallelized. You don’t need to apply the nt.batch or jax.jit decorators to it. Further, you do not need to apply jax.jit to the input apply_fn function, as the resulting empirical kernel function is JITted internally.

Parameters:
  • init_fn (InitFn) – a function initializing parameters of the neural network. From jax.example_libraries.stax: “takes an rng key and an input shape and returns an (output_shape, params) pair”.

  • apply_fn (ApplyFn) – a function computing the output of the neural network. From jax.example_libraries.stax: “takes params, inputs, and an rng key and applies the layer”.

  • key (Array) – RNG (jax.random.PRNGKey) for sampling random networks. Must have shape (2,).

  • n_samples (Union[int, Iterable[int]]) – number of Monte Carlo samples. Can be either an integer or an iterable of integers at which the resulting generator will yield estimates. Example: use n_samples=[2**k for k in range(10)] for the generator to yield estimates using 1, 2, 4, …, 512 Monte Carlo samples.

  • batch_size (int) – an integer that makes the kernel be computed in batches of x1 and x2 of this size. 0 means computing the whole kernel at once. Must divide x1.shape[0] and x2.shape[0].

  • device_count (int) – an integer making the kernel be computed in parallel across this number of devices (e.g. GPUs or TPU cores). -1 means use all available devices. 0 means compute on a single device sequentially. If not 0, must divide x1.shape[0].

  • store_on_device (bool) – a boolean, indicating whether to store the resulting kernel on the device (e.g. GPU or TPU), or in the CPU RAM, where larger kernels may fit.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows you to save space and compute if you are only interested in the respective trace. It can also improve approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis: activation slices along such an axis are i.i.d., so the covariance along that pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows you to save space and compute if off-diagonal values along these axes are not needed. It can also improve approximation accuracy if their limiting value is known theoretically, for example if they vanish in the limit of interest (such as infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of the covariance) along certain axes. Also related to “batch dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • vmap_axes (Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]) – applicable only to NTK. A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) equals a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes, i.e. f can be evaluated as a vmap. This allows Jacobians to be evaluated much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension for both inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements, or no concept of a batch axis at all, vmap_axes must be set to None to avoid wrong (and potentially silent) results.

  • implementation (Union[int, NtkImplementation]) – Applicable only to NTK, an NtkImplementation value (or an int 0, 1, 2, or 3). See the NtkImplementation docstring for details.

  • _j_rules (bool) – Internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow custom Jacobian rules for intermediary primitive dy/dw computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to False to use JVPs or VJPs, via JAX’s jax.jacfwd or jax.jacrev. Custom Jacobian rules (True) are expected to be no worse, and sometimes better, than automated alternatives, but in case of a suboptimal implementation, setting it to False could improve performance.

  • _s_rules (bool) – Internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow efficient MJJMP rules for structured dy/dw primitive Jacobians. In practice it should be set to True; setting it to False can lead to a dramatic deterioration in performance.

  • _fwd (Optional[bool]) – Internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow jax.jvp in intermediary primitive Jacobian dy/dw computations, False to always use jax.vjp, or None to decide automatically based on input/output sizes. Applicable when _j_rules=False, or when a primitive does not have a Jacobian rule. Should be set to None for best performance.
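The effect of trace_axes and diagonal_axes on the shape of a per-draw kernel can be sketched with `np.einsum`. Take network outputs `f1` of shape (N1, C) and `f2` of shape (N2, C) from one parameter draw. This is an illustration under assumptions, not the library's code. The full covariance keeps both logit axes, giving shape (N1, N2, C, C). Diagonalizing keeps only matching logit pairs, giving (N1, N2, C). Tracing contracts them, giving (N1, N2). Dividing the trace by C is a normalization chosen for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
n1, n2, c = 4, 3, 10
f1 = rng.standard_normal((n1, c))  # outputs for x1 under one parameter draw
f2 = rng.standard_normal((n2, c))  # outputs for x2 under the same draw

# No tracing or diagonalization: covariance between every pair of logits.
full = np.einsum('ia,jb->ijab', f1, f2)  # (N1, N2, C, C)

# diagonal_axes=(-1,): keep only the a == b entries of the logit pair.
diag = np.einsum('ia,ja->ija', f1, f2)  # (N1, N2, C)

# trace_axes=(-1,): contract the logit pair (normalized by C in this sketch).
trace = np.einsum('ia,ja->ij', f1, f2) / c  # (N1, N2)

print(full.shape, diag.shape, trace.shape)  # (4, 3, 10, 10) (4, 3, 10) (4, 3)
```

The traced result stores C * C times fewer numbers than the full covariance. This storage saving is the one the parameter descriptions refer to.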

Return type:

MonteCarloKernelFn

Returns:

If n_samples is an integer, returns a function of signature kernel_fn(x1, x2, get) that returns an MC estimate of the kernel using n_samples samples. If n_samples is a collection of integers, kernel_fn(x1, x2, get) returns a generator that yields estimates using n samples for each n in n_samples.
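One way to picture the generator case is as a running average, reported each time the number of accumulated samples reaches the next requested count. This sketch assumes that earlier samples are reused rather than redrawn, which may not match the library's internals. It is a hypothetical NumPy illustration: `mc_generator` and `sample_kernel` are illustrative names, a scalar stands in for a kernel matrix, and n_samples is assumed to be nondecreasing.

```python
import numpy as np


def mc_generator(sample_kernel, seed, n_samples):
    # Yield (n, mean of the first n samples) for each n in n_samples.
    # Assumes n_samples is nondecreasing, so earlier samples can be reused.
    rng = np.random.default_rng(seed)
    total, drawn = 0.0, 0
    for n in n_samples:
        while drawn < n:
            total = total + sample_kernel(rng)
            drawn += 1
        yield n, total / n


def sample_kernel(rng):
    # Toy per-sample "kernel": a squared standard normal, whose mean is 1.
    return rng.standard_normal() ** 2


estimates = dict(mc_generator(sample_kernel, seed=0,
                              n_samples=[1, 10, 100, 10000]))
for n, est in estimates.items():
    print(n, est)
```

Under this assumption, the estimate for the largest n costs no more than a single call with that n, while also providing the intermediate estimates for free.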

Example

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> init_fn, apply_fn, _ = stax.serial(
...     stax.Conv(128, (3, 3)),
...     stax.Relu(),
...     stax.Conv(256, (3, 3)),
...     stax.Relu(),
...     stax.Conv(512, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>> #
>>> n_samples = 200
>>> kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1, n_samples)
>>> kernel = kernel_fn(x_train, x_test, get=('nngp', 'ntk'))
>>> # `kernel` is a tuple of NNGP and NTK MC estimates using `n_samples`.
>>> #
>>> n_samples = [1, 10, 100, 1000]
>>> kernel_fn_generator = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key1,
...                                                n_samples)
>>> kernel_samples = kernel_fn_generator(x_train, x_test,
...                                      get=('nngp', 'ntk'))
>>> for n, kernel in zip(n_samples, kernel_samples):
...   print(n, kernel)
...   # `kernel` is a tuple of NNGP and NTK MC estimates using `n` samples.