Empirical – finite NNGP and NTK
Compute empirical NNGP and NTK; approximate functions via Taylor series.
All functions in this module are applicable to any JAX functions of proper
signatures (not only those from nt.stax).
NNGP and NTK are computed using nt.empirical_nngp_fn, nt.empirical_ntk_fn,
or nt.empirical_kernel_fn (for both). The kernels have a very specific output
shape convention that may be unexpected. Further, NTK has multiple
implementations that may perform differently depending on the task. Please read
individual functions’ docstrings.
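The output shape convention can be illustrated without the library. The following is a minimal sketch in plain JAX, not the library's implementation: it uses random arrays as stand-ins for network outputs (the names f_test and f_train are hypothetical) and assumes that reducing over logits means averaging the contracted logit axis. The individual docstrings remain the authoritative description of the convention.

```python
import jax.numpy as jnp
from jax import random

# Stand-in network outputs (hypothetical): 5 test and 20 train inputs, 10 logits.
key1, key2 = random.split(random.PRNGKey(0))
f_test = random.normal(key1, (5, 10))
f_train = random.normal(key2, (20, 10))
n_logits = f_test.shape[-1]

# Full kernel: keep both logit axes, giving shape (5, 20, 10, 10).
nngp_full = jnp.einsum('ia,jb->ijab', f_test, f_train)

# Traced kernel: contract the logit axis and average (assumed normalization),
# giving shape (5, 20).
nngp_traced = jnp.einsum('ia,ja->ij', f_test, f_train) / n_logits

# The traced kernel is the normalized trace of the full kernel over its
# two trailing logit axes.
trace_of_full = jnp.trace(nngp_full, axis1=-2, axis2=-1) / n_logits
```

Batch axes come first and logit axes last in this sketch, which matches the (5, 20) and (5, 20, 10, 10) shapes quoted in the example comments.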
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>> # A narrow CNN.
>>> init_fn, f, _ = stax.serial(
>>>     stax.Conv(32, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(32, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(32, (3, 3)),
>>>     stax.Flatten(),
>>>     stax.Dense(10)
>>> )
>>>
>>> _, params = init_fn(key3, x_train.shape)
>>>
>>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
>>> # network is iid along the batch axis, no BatchNorm. Use default
>>> # `implementation=1` since the network has few trainable parameters.
>>> kernel_fn = nt.empirical_kernel_fn(
>>>     f, trace_axes=(-1,), vmap_axes=0, implementation=1)
>>>
>>> # (5, 20) np.ndarray test-train NNGP/NTK
>>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>>
>>> # Full kernel: not reducing over logits.
>>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(), vmap_axes=0)
>>>
>>> # (5, 20, 10, 10) np.ndarray test-train NNGP/NTK namedtuple.
>>> k_test_train = kernel_fn(x_test, x_train, None, params)
>>>
>>> # A wide FCN with lots of parameters
>>> init_fn, f, _ = stax.serial(
>>>     stax.Flatten(),
>>>     stax.Dense(1024),
>>>     stax.Relu(),
>>>     stax.Dense(1024),
>>>     stax.Relu(),
>>>     stax.Dense(10)
>>> )
>>>
>>> _, params = init_fn(key3, x_train.shape)
>>>
>>> # Use implicit differentiation in NTK: `implementation=2` to reduce
>>> # memory cost, since the network has many trainable parameters.
>>> ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=2)
>>>
>>> # (5, 5) np.ndarray test-test NTK
>>> ntk_test_test = ntk_fn(x_test, None, params)
>>>
>>> # Compute only output variances:
>>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>>
>>> # (20,) np.ndarray train-train diagonal NNGP
>>> nngp_train_train_diag = nngp_fn(x_train, None, params)
Kernel functions
Finite-width NNGP and/or NTK kernel functions.
nt.empirical_kernel_fn | Returns a function that computes single draws from NNGP and NT kernels.
nt.empirical_nngp_fn | Returns a function to draw a single sample of the NNGP of a given network.
nt.empirical_ntk_fn | Returns a function to draw a single sample of the NTK of a given network.
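To make "a single sample" concrete, here is a minimal sketch in plain JAX of what a finite-width NNGP and NTK look like for one draw of parameters. It is an illustration under stated assumptions, not the library's implementation: the network mlp and the helper f_flat are hypothetical, the NTK is formed by direct Jacobian contraction (only one of several possible strategies), and averaging over logits is assumed for the traced kernels.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.flatten_util import ravel_pytree


def mlp(params, x):
  """A tiny two-layer network used only for illustration (hypothetical)."""
  (w1, b1), (w2, b2) = params
  h = jnp.tanh(x @ w1 + b1)
  return h @ w2 + b2


k_w1, k_w2, k_x1, k_x2 = random.split(random.PRNGKey(0), 4)
params = [
    (random.normal(k_w1, (3, 16)) / jnp.sqrt(3.0), jnp.zeros(16)),
    (random.normal(k_w2, (16, 2)) / jnp.sqrt(16.0), jnp.zeros(2)),
]
x1 = random.normal(k_x1, (4, 3))
x2 = random.normal(k_x2, (6, 3))

# Flatten the parameter pytree so the Jacobian has a single parameter axis.
flat_params, unravel = ravel_pytree(params)


def f_flat(p, x):
  return mlp(unravel(p), x)


# Jacobians w.r.t. the flattened parameters: shape (batch, logits, n_params).
j1 = jax.jacobian(f_flat)(flat_params, x1)
j2 = jax.jacobian(f_flat)(flat_params, x2)
n_logits = j1.shape[1]

# Full NTK, shape (4, 6, 2, 2): contract only the parameter axis.
ntk_full = jnp.einsum('iap,jbp->ijab', j1, j2)

# NTK averaged over logits (assumed normalization), shape (4, 6).
ntk_traced = jnp.einsum('iap,jap->ij', j1, j2) / n_logits

# NNGP sample from the same parameter draw, shape (4, 6).
nngp_traced = mlp(params, x1) @ mlp(params, x2).T / n_logits
```

Materializing the full Jacobian costs memory proportional to batch size times logits times parameter count, which is one reason a choice of implementation can matter for networks with many parameters.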
Linearization and Taylor expansion
Decorators to Taylor-expand around function parameters.
|
Returns a function |
|
Returns a function |
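The idea of Taylor-expanding a function around its parameters can be sketched in plain JAX with a Jacobian-vector product. This is a first-order illustration only, under the assumption that the function takes (params, x); the helper name linearize_sketch and the toy model f are hypothetical and are not the library's decorators.

```python
import jax
import jax.numpy as jnp


def f(params, x):
  """Toy model (hypothetical): tanh of a scaled and shifted input."""
  w, b = params
  return jnp.tanh(x * w + b)


def linearize_sketch(f, params0):
  """Returns f_lin(params, x), the first-order Taylor expansion of f in params around params0."""
  def f_lin(params, x):
    # Displacement from the expansion point, with the same pytree structure.
    dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
    # jvp returns f(params0, x) and the directional derivative along dparams.
    f0, df = jax.jvp(lambda p: f(p, x), (params0,), (dparams,))
    return f0 + df
  return f_lin


params0 = (jnp.array(0.5), jnp.array(-0.1))
x = jnp.linspace(-1.0, 1.0, 5)
f_lin = linearize_sketch(f, params0)

# At the expansion point the linearization agrees with f.
at_center = f_lin(params0, x)

# A small step away, the discrepancy is second order in the step size.
params1 = (jnp.array(0.5 + 1e-2), jnp.array(-0.1 + 1e-2))
error = jnp.max(jnp.abs(f_lin(params1, x) - f(params1, x)))
```

Higher-order expansions follow the same pattern with additional derivative terms; this sketch stops at first order to keep the mechanics visible.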