Empirical – finite NNGP and NTK

Compute empirical NNGP and NTK; approximate functions via Taylor series.

All functions in this module apply to any JAX function with the proper signature (not only those built with nt.stax).
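Below is a minimal sketch of a function with the proper signature. It assumes the `f(params, x)` convention used in the example below (parameters first, inputs second) and shows a hypothetical two-layer network written directly in JAX, together with a hand-rolled empirical NNGP sample. The names `init_params`, `mlp_apply`, and `empirical_nngp` are illustrative only and are not part of Neural Tangents. The logit reduction is a plain sum, and this sketch does not try to reproduce any normalization the library applies.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in=8, width=64, d_out=3):
    # Hypothetical initializer: Gaussian weights scaled by 1/sqrt(fan_in), zero biases.
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "b1": jnp.zeros(width),
        "W2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width),
        "b2": jnp.zeros(d_out),
    }


def mlp_apply(params, x):
    # Hypothetical plain-JAX network with the f(params, x) signature.
    h = jnp.maximum(x @ params["W1"] + params["b1"], 0.0)  # ReLU hidden layer
    return h @ params["W2"] + params["b2"]                 # logits, shape (N, K)


def empirical_nngp(f, x1, x2, params):
    # One NNGP sample: inner products of outputs, summed over the logit axis.
    # Shape (N1, N2); illustrates the idea, not nt.empirical_nngp_fn itself.
    return f(params, x1) @ f(params, x2).T


kx1, kx2, kp = jax.random.split(jax.random.PRNGKey(0), 3)
x1 = jax.random.normal(kx1, (4, 8))
x2 = jax.random.normal(kx2, (5, 8))
params = init_params(kp)

nngp = empirical_nngp(mlp_apply, x1, x2, params)
print(nngp.shape)  # (4, 5)
```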

NNGP and NTK are computed using nt.empirical_nngp_fn, nt.empirical_ntk_fn, or nt.empirical_kernel_fn (for both). The kernels follow a specific output-shape convention, controlled by the trace_axes and diagonal_axes arguments, that may be unexpected. Further, the NTK has multiple implementations whose speed and memory use can differ depending on the task. Please read the individual functions' docstrings.
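The following minimal pure-JAX sketch makes the output-shape convention concrete. It assumes a hypothetical model `f(params, x)` with `K` logits. The sketch builds the full empirical NTK of shape `(N1, N2, K, K)` by contracting parameter Jacobians. It then shows the reduction analogous to `trace_axes=(-1,)`, which has shape `(N1, N2)`. Finally, for `x2 = x1`, it shows the batch diagonal analogous to `diagonal_axes=(0,)`. The helper `full_ntk` is illustrative rather than the library's code, and it does not attempt to reproduce any rescaling the library may apply to traced kernels.

```python
import jax
import jax.numpy as jnp


def f(params, x):
    # Hypothetical model: inputs (N, D) -> logits (N, K).
    return jnp.tanh(x @ params["W"]) @ params["V"]


def full_ntk(f, x1, x2, params):
    # Parameter Jacobians: one leaf per parameter, shape (N, K) + leaf.shape.
    j1 = jax.jacobian(f)(params, x1)
    j2 = jax.jacobian(f)(params, x2)

    def contract(a, b):
        # Flatten parameter dims, then sum over them: -> (N1, N2, K, K).
        a = a.reshape(a.shape[0], a.shape[1], -1)
        b = b.reshape(b.shape[0], b.shape[1], -1)
        return jnp.einsum("ikp,jlp->ijkl", a, b)

    # The NTK is the sum of the contributions from every parameter leaf.
    return sum(jax.tree_util.tree_leaves(jax.tree_util.tree_map(contract, j1, j2)))


kw, kv, kx1, kx2 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {"W": jax.random.normal(kw, (4, 16)) / 2.0,
          "V": jax.random.normal(kv, (16, 2)) / 4.0}
x1 = jax.random.normal(kx1, (3, 4))
x2 = jax.random.normal(kx2, (5, 4))

k_full = full_ntk(f, x1, x2, params)                 # (3, 5, 2, 2): no axes traced
k_traced = jnp.trace(k_full, axis1=2, axis2=3)        # (3, 5): logit axis traced
k_11 = jnp.trace(full_ntk(f, x1, x1, params), axis1=2, axis2=3)
k_diag = jnp.diagonal(k_11, axis1=0, axis2=1)        # (3,): batch diagonal only
print(k_full.shape, k_traced.shape, k_diag.shape)
```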

Example

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> key1, key2, key3, key4 = random.split(random.PRNGKey(1), 4)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key4, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>> # A narrow CNN.
>>> init_fn, f, _ = stax.serial(
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>>
>>> _, params = init_fn(key3, x_train.shape)
>>>
>>> # Default setting: reducing over logits. Pass `vmap_axes=0` because the
>>> # network is iid along the batch axis (no BatchNorm). Use the default
>>> # `implementation=1` since the network has few trainable parameters.
>>> kernel_fn = nt.empirical_kernel_fn(
...     f, trace_axes=(-1,), vmap_axes=0, implementation=1)
>>>
>>> # (5, 20) np.ndarray test-train NNGP/NTK
>>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>>
>>> # Full kernel: not reducing over logits.
>>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(), vmap_axes=0)
>>>
>>> # Namedtuple of (5, 20, 10, 10) np.ndarray test-train NNGP and NTK;
>>> # `get=None` requests both kernels.
>>> k_test_train = kernel_fn(x_test, x_train, None, params)
>>>
>>> # A wide FCN with lots of parameters.
>>> init_fn, f, _ = stax.serial(
...     stax.Flatten(),
...     stax.Dense(1024),
...     stax.Relu(),
...     stax.Dense(1024),
...     stax.Relu(),
...     stax.Dense(10)
... )
>>>
>>> _, params = init_fn(key3, x_train.shape)
>>>
>>> # Use implicit differentiation in the NTK (`implementation=2`) to reduce
>>> # memory cost, since the network has many trainable parameters.
>>> ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=2)
>>>
>>> # (5, 5) np.ndarray test-test NTK
>>> ntk_test_test = ntk_fn(x_test, None, params)
>>>
>>> # Compute only output variances:
>>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>>
>>> # (20,) np.ndarray train-train diagonal NNGP
>>> nngp_train_train_diag = nngp_fn(x_train, None, params)
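The comments in the example above contrast `implementation=1`, which instantiates Jacobians, with `implementation=2`, which avoids doing so. The rough pure-JAX sketch below illustrates that trade-off; it is not the library's actual code. It obtains the same NTK in two ways. The first contracts explicit Jacobians. The second composes a vector-Jacobian product with a Jacobian-vector product, Theta v = J1 (J2^T v), processing one output cotangent at a time. The model `f` and the helpers `ntk_jacobian_contraction` and `ntk_vector_products` are hypothetical.

```python
import jax
import jax.numpy as jnp


def f(params, x):
    # Hypothetical model: inputs (N, D) -> logits (N, K).
    return jnp.tanh(x @ params["W"]) @ params["V"]


def ntk_jacobian_contraction(f, x1, x2, params):
    # Instantiate both Jacobians, then contract over all parameters.
    j1 = jax.jacobian(f)(params, x1)
    j2 = jax.jacobian(f)(params, x2)

    def contract(a, b):
        a = a.reshape(a.shape[0], a.shape[1], -1)
        b = b.reshape(b.shape[0], b.shape[1], -1)
        return jnp.einsum("ikp,jlp->ijkl", a, b)

    return sum(jax.tree_util.tree_leaves(jax.tree_util.tree_map(contract, j1, j2)))


def ntk_vector_products(f, x1, x2, params):
    # Never calls jax.jacobian: builds each kernel column as J1 (J2^T v).
    out2, vjp2 = jax.vjp(lambda p: f(p, x2), params)
    n2, k = out2.shape

    def column(cotangent):
        (tangent,) = vjp2(cotangent)  # J2^T v, a params-shaped pytree
        _, col = jax.jvp(lambda p: f(p, x1), (params,), (tangent,))  # J1 (J2^T v)
        return col  # (N1, K)

    # One-hot cotangents over the N2 * K outputs of f(params, x2). lax.map
    # processes them sequentially, so only one tangent is built per step.
    basis = jnp.eye(n2 * k, dtype=out2.dtype).reshape(n2 * k, n2, k)
    cols = jax.lax.map(column, basis).reshape(n2, k, -1, k)  # (N2, K, N1, K)
    return jnp.transpose(cols, (2, 0, 3, 1))  # (N1, N2, K, K)


kw, kv, kx1, kx2 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {"W": jax.random.normal(kw, (4, 16)) / 2.0,
          "V": jax.random.normal(kv, (16, 2)) / 4.0}
x1 = jax.random.normal(kx1, (3, 4))
x2 = jax.random.normal(kx2, (5, 4))

k_a = ntk_jacobian_contraction(f, x1, x2, params)
k_b = ntk_vector_products(f, x1, x2, params)
print(jnp.max(jnp.abs(k_a - k_b)))  # both routes compute the same kernel
```

The sequential `lax.map` trades speed for peak memory. Batching the cotangents with `jax.vmap` would be faster but would materialize many parameter-sized tangents at once.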

Kernel functions

Finite-width NNGP and/or NTK kernel functions.

empirical_kernel_fn(f[, trace_axes, ...])

Returns a function that computes single draws from NNGP and NT kernels.

empirical_nngp_fn(f[, trace_axes, diagonal_axes])

Returns a function to draw a single sample of the NNGP of a given network f.

empirical_ntk_fn(f[, trace_axes, ...])

Returns a function to draw a single sample of the NTK of a given network f.
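Each kernel returned by these functions is a single sample: it depends on the particular parameter draw passed in. The following minimal pure-JAX illustration draws several parameter sets and compares the resulting NNGP samples. Averaging many such draws gives a Monte Carlo estimate of the expected kernel. The network `f` and initializer `init_params` are hypothetical and not part of Neural Tangents. The reduction over logits is normalized by the number of logits, a choice made only for this sketch.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in=5, width=32, d_out=2):
    # Hypothetical initializer: Gaussian weights scaled by 1/sqrt(fan_in).
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "W2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width),
    }


def f(params, x):
    # Hypothetical one-hidden-layer ReLU network.
    return jnp.maximum(x @ params["W1"], 0.0) @ params["W2"]


def nngp_sample(key, x1, x2):
    # One parameter draw -> one empirical NNGP sample, averaged over logits.
    params = init_params(key)
    out1, out2 = f(params, x1), f(params, x2)
    return out1 @ out2.T / out1.shape[-1]


x = jax.random.normal(jax.random.PRNGKey(0), (3, 5))
keys = jax.random.split(jax.random.PRNGKey(1), 100)
samples = jax.vmap(lambda k: nngp_sample(k, x, x))(keys)  # (100, 3, 3)

print(samples[0])            # one draw
print(samples[1])            # another draw, with different values
print(samples.mean(axis=0))  # Monte Carlo average over 100 draws
```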

Linearization and Taylor expansion

Decorators to Taylor-expand around function parameters.
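The first-order case can be sketched with `jax.jvp`, assuming the `f(params, x)` convention used elsewhere on this page: f_lin(p, x) = f(params, x) + J(params, x) (p - params). The helper `linearize_sketch` and the toy model `f` are hypothetical. They only illustrate the idea behind `linearize` and are not its implementation.

```python
import jax
import jax.numpy as jnp


def linearize_sketch(f, params):
    # Hypothetical helper (not nt.linearize): first-order Taylor expansion of f
    # in its parameters around `params`.
    def f_lin(p, x):
        dparams = jax.tree_util.tree_map(lambda a, b: a - b, p, params)
        # jvp returns f(params, x) and the directional derivative J (p - params).
        f0, df = jax.jvp(lambda q: f(q, x), (params,), (dparams,))
        return f0 + df

    return f_lin


def f(params, x):
    # Toy model: nonlinear in "a", linear in "b".
    return jnp.sin(x * params["a"]) + params["b"]


params = {"a": jnp.array(1.0), "b": jnp.array(0.5)}
f_lin = linearize_sketch(f, params)
x = jnp.array([0.1, 0.2, 0.3])

print(f_lin(params, x))  # agrees with f at the expansion point
print(f(params, x))
```

Because "b" enters `f` linearly, changing only "b" is reproduced exactly, while changes in "a" incur a second-order error.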

linearize(f, params)

Returns a function f_lin, the first-order Taylor approximation to f.

taylor_expand(f, params, degree)

Returns a function f_tayl, the Taylor approximation to f of order degree.
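A higher-order expansion can be sketched in pure JAX by restricting f to the line q(t) = params + t (p - params). The sketch differentiates g(t) = f(q(t), x) repeatedly at t = 0 with `jax.jacfwd` and sums g^(k)(0) / k! for k up to `degree`. This is only an illustration under that assumption, not how `taylor_expand` is implemented. The helper `taylor_expand_sketch` is hypothetical, and its cost grows quickly with `degree`.

```python
import jax
import jax.numpy as jnp


def taylor_expand_sketch(f, params, degree):
    # Hypothetical helper: Taylor polynomial of f in its parameters around `params`.
    def f_tayl(p, x):
        dp = jax.tree_util.tree_map(lambda a, b: a - b, p, params)

        def g(t):
            # f restricted to the line params + t * (p - params); g(1) = f(p, x).
            q = jax.tree_util.tree_map(lambda a, d: a + t * d, params, dp)
            return f(q, x)

        total = g(0.0)
        deriv, factorial = g, 1.0
        for k in range(1, degree + 1):
            deriv = jax.jacfwd(deriv)  # k-th derivative of g with respect to t
            factorial *= k
            total = total + deriv(0.0) / factorial
        return total

    return f_tayl


def f(params, x):
    # Toy model that is cubic in its single parameter.
    return params["a"] ** 3 * x


params = {"a": jnp.array(1.0)}
p = {"a": jnp.array(2.0)}
x = jnp.array([1.0, 2.0])

for degree in (1, 2, 3):
    print(degree, taylor_expand_sketch(f, params, degree)(p, x))
# Degrees 1, 2, 3 give 4 * x, 7 * x, and 8 * x; degree 3 is exact here.
```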