neural_tangents.empirical_ntk_vp_fn
- neural_tangents.empirical_ntk_vp_fn(f, x1, x2, params, **apply_fn_kwargs)
Returns an NTK-vector product function.
The function computes the NTK-vector product without instantiating the NTK. Its runtime is equivalent to (N1 + N2) forward passes through f, where N1 and N2 are the numbers of inputs in x1 and x2, and its memory cost is equivalent to evaluating a vector-Jacobian product of f. For details, see Section L of “Fast Finite Width Neural Tangent Kernel”.
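The empirical NTK is Θ(x1, x2) = J(x1) J(x2)ᵀ, where J(x) is the Jacobian of f(params, x) with respect to params, so Θ·v can be computed as J(x1)(J(x2)ᵀ v) without ever forming Θ. The sketch below shows one standard way to do this in JAX: a VJP over x2 followed by a JVP over x1. It is written for illustration and is not taken from the library's source. The two-layer network mlp, the helper names ntk_vp_sketch and explicit_ntk_vp, and all shapes are assumptions; the explicit-Jacobian version exists only to check the result on a tiny problem.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def ntk_vp_sketch(f, x1, x2, params, cotangents):
    """Matrix-free NTK-vector product: J(x1) @ (J(x2)^T @ cotangents)."""
    # Step 1 (VJP): pull the cotangents back through params -> f(params, x2).
    # The result J(x2)^T v has the same PyTree structure as `params`.
    _, vjp_fn = jax.vjp(lambda p: f(p, x2), params)
    (param_tangents,) = vjp_fn(cotangents)
    # Step 2 (JVP): push that result forward through params -> f(params, x1).
    # The result J(x1) J(x2)^T v has the shape of f(params, x1).
    _, ntk_vp = jax.jvp(lambda p: f(p, x1), (params,), (param_tangents,))
    return ntk_vp


def explicit_ntk_vp(f, x1, x2, params, cotangents):
    """Reference check: build the full NTK from flattened Jacobians, then multiply."""
    flat_params, unravel = ravel_pytree(params)
    g = lambda p, x: f(unravel(p), x).reshape(-1)
    j1 = jax.jacobian(g)(flat_params, x1)  # shape (N1 * O, P)
    j2 = jax.jacobian(g)(flat_params, x2)  # shape (N2 * O, P)
    ntk = j1 @ j2.T                        # shape (N1 * O, N2 * O)
    out_shape = f(params, x1).shape
    return (ntk @ cotangents.reshape(-1)).reshape(out_shape)


def mlp(params, x):
    """Hypothetical two-layer network standing in for `f`."""
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2


k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
params = (jax.random.normal(k1, (3, 8)), jax.random.normal(k2, (8, 2)))
x1 = jax.random.normal(k3, (4, 3))  # N1 = 4 inputs
x2 = jax.random.normal(k4, (5, 3))  # N2 = 5 inputs
cotangents = jax.random.normal(k5, mlp(params, x2).shape)  # shape of f(params, x2)

fast = ntk_vp_sketch(mlp, x1, x2, params, cotangents)
slow = explicit_ntk_vp(mlp, x1, x2, params, cotangents)
print(fast.shape, jnp.allclose(fast, slow, rtol=1e-4, atol=1e-4))
```

The VJP touches only x2 and the JVP touches only x1, which is why the cost scales with N1 + N2 rather than with the N1·N2 entries of the kernel.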
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> k1, k2, k3, k4 = random.split(random.PRNGKey(1), 4)
>>> x1 = random.normal(k1, (20, 32, 32, 3))
>>> x2 = random.normal(k2, (10, 32, 32, 3))
>>> #
>>> # Define a forward-pass function `f`.
>>> init_fn, f, _ = stax.serial(
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>> #
>>> # Initialize parameters.
>>> _, params = init_fn(k3, x1.shape)
>>> #
>>> # NTK-vp function. Can/should be JITted.
>>> ntk_vp_fn = nt.empirical_ntk_vp_fn(f, x1, x2, params)
>>> #
>>> # Cotangent vector with the shape of `f(params, x2)`.
>>> cotangents = random.normal(k4, f(params, x2).shape)
>>> #
>>> # NTK-vp output
>>> ntk_vp = ntk_vp_fn(cotangents)
>>> #
>>> # Output has same shape as `f(params, x1)`.
>>> assert ntk_vp.shape == f(params, x1).shape
- Parameters:
  - f (ApplyFn) – forward-pass function of signature f(params, x).
  - x1 (Any) – first batch of inputs.
  - x2 (Optional[Any]) – second batch of inputs. x2=None means x2=x1.
  - params (Any) – a PyTree of parameters with respect to which the neural tangent kernel is computed.
  - **apply_fn_kwargs – keyword arguments passed to f. The split_kwargs function splits apply_fn_kwargs into apply_fn_kwargs1 and apply_fn_kwargs2, which are then passed to f. In particular, the rng key in apply_fn_kwargs yields two different rng keys (if x1 != x2) or the same key twice (if x1 == x2). See the _read_key function for more details.
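The rng handling described for **apply_fn_kwargs can be illustrated with the minimal sketch below. split_rng_kwargs is a hypothetical helper written for this page, not the library's split_kwargs or _read_key. It assumes the key is passed under the name rng and, as a simplification, treats x2=None or x2 being the very same object as x1 as the x1 == x2 case.

```python
from jax import random


def split_rng_kwargs(x1, x2, **apply_fn_kwargs):
    """Hypothetical helper mirroring the documented rng behavior.

    Returns (kwargs1, kwargs2) intended for f(params, x1, **kwargs1) and
    f(params, x2, **kwargs2). This is NOT the library's split_kwargs.
    """
    kwargs1 = dict(apply_fn_kwargs)
    kwargs2 = dict(apply_fn_kwargs)
    rng = apply_fn_kwargs.get('rng')
    if rng is not None and x2 is not None and x2 is not x1:
        # Different inputs: derive two independent keys from the given one.
        kwargs1['rng'], kwargs2['rng'] = random.split(rng)
    # Otherwise (x2 is None or x2 is x1), both calls share the same key.
    return kwargs1, kwargs2
```

Sharing the key when x1 == x2 keeps any stochastic layers (e.g. dropout) identical on both sides, so the kernel of a batch with itself is built from a single draw of the randomness.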
- Return type:
- Returns:
An NTK-vector product function that accepts a PyTree of cotangents with the same shape and structure as f(params, x2), and returns the NTK-vector product with the same shape and structure as f(params, x1).