neural_tangents.empirical_ntk_fn

neural_tangents.empirical_ntk_fn(f, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=1)

Returns a function to draw a single sample of the NTK of a given network f.

The Neural Tangent Kernel is defined as \(J(X_1) J(X_2)^T\) where \(J\) is the Jacobian \(df/dparams\) of shape full_output_shape + params.shape.
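The definition above can be illustrated with plain JAX, independently of neural_tangents. The following is a minimal sketch, not the library's implementation: the two-layer network f, init_params, and empirical_ntk_explicit are hypothetical names introduced here, and the parameter PyTree is flattened with jax.flatten_util.ravel_pytree so that params.shape becomes a single axis of size P. With no tracing or diagonalization, the result has shape zip(f(x1).shape, f(x2).shape), i.e. (N1, N2, K, K).

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_params(key, d_in=3, width=8, d_out=2):
    # Hypothetical two-layer network parameters (a PyTree of arrays).
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "W2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width),
    }


def f(params, x):
    # Maps inputs of shape (N, d_in) to outputs of shape (N, d_out).
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


def empirical_ntk_explicit(f, params, x1, x2):
    # Flatten the parameter PyTree into one vector of length P.
    flat, unravel = ravel_pytree(params)
    f_flat = lambda p, x: f(unravel(p), x)
    # Jacobians df/dparams of shape f(x).shape + (P,), i.e. (N, K, P).
    j1 = jax.jacobian(f_flat)(flat, x1)
    j2 = jax.jacobian(f_flat)(flat, x2)
    # Theta[i, j, a, b] = sum_p J(X_1)[i, a, p] * J(X_2)[j, b, p].
    return jnp.einsum("iap,jbp->ijab", j1, j2)


params = init_params(jax.random.PRNGKey(0))
x1 = jax.random.normal(jax.random.PRNGKey(1), (4, 3))
x2 = jax.random.normal(jax.random.PRNGKey(2), (5, 3))
theta = empirical_ntk_explicit(f, params, x1, x2)
print(theta.shape)  # (4, 5, 2, 2)
```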

For best performance: 1) pass x2=None if x1 == x2; 2) prefer square batches (i.e. x1.shape == x2.shape); 3) make sure to set vmap_axes correctly; 4) try different implementation values.

WARNING: The resulting kernel shape is nearly zip(f(x1).shape, f(x2).shape), subject to the trace_axes and diagonal_axes parameters, which make certain assumptions about the outputs f(x) that may only be true in the infinite width / infinite number of samples limit, or may not apply to your architecture. For the most precise results in the context of linearized training dynamics of a specific finite-width network, set both trace_axes=() and diagonal_axes=() to obtain a kernel of exactly shape zip(f(x1).shape, f(x2).shape).
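To make the effect of trace_axes and diagonal_axes on the kernel shape concrete, the sketch below (plain JAX, hypothetical toy network, not the library's code) computes three variants from the same Jacobians: the full kernel (trace_axes=(), diagonal_axes=()), a trace over the logit axis pair (loosely mirroring the default trace_axes=(-1,)), and a diagonal over that pair (loosely mirroring diagonal_axes=(-1,)). It takes a plain trace (a sum), following the wording of this page; whether the library additionally normalizes the trace is not assumed here. The traced and diagonal variants never form the K x K logit block, which is where the savings come from; in XLA terms the logit index is a contracting dimension in the first case and a batch dimension in the second.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(params, x):
    # Hypothetical toy network: (N, 3) inputs -> (N, 2) logits.
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "W1": jax.random.normal(k1, (3, 8)) / jnp.sqrt(3.0),
    "W2": jax.random.normal(k2, (8, 2)) / jnp.sqrt(8.0),
}
x1 = jax.random.normal(k3, (4, 3))
x2 = jax.random.normal(k4, (5, 3))

flat, unravel = ravel_pytree(params)
jac = jax.jacobian(lambda p, x: f(unravel(p), x))
j1, j2 = jac(flat, x1), jac(flat, x2)  # shapes (N1, K, P) and (N2, K, P)

# Full kernel over the logit pair (a, b): shape (N1, N2, K, K).
full = jnp.einsum("iap,jbp->ijab", j1, j2)
# Trace over the logit pair: set a == b and sum over it, shape (N1, N2).
traced = jnp.einsum("iap,jap->ij", j1, j2)
# Diagonal over the logit pair: set a == b but keep it, shape (N1, N2, K).
diag = jnp.einsum("iap,jap->ija", j1, j2)

print(full.shape, traced.shape, diag.shape)  # (4, 5, 2, 2) (4, 5) (4, 5, 2)
```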

For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernel has terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you.
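A minimal sketch of this per-output treatment, assuming a hypothetical two-headed network that returns a tuple: jax.jacobian returns one Jacobian per output, and one kernel is formed per output. The cross-covariance term between the two heads, which would be an einsum over j_a and j_b, is deliberately never computed. This only illustrates the behavior described above; it is not the library's code.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(params, x):
    # Hypothetical two-headed network returning a tuple of outputs.
    h = jnp.tanh(x @ params["W1"])
    return h @ params["Wa"], h @ params["Wb"]


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "W1": jax.random.normal(k1, (3, 8)) / jnp.sqrt(3.0),
    "Wa": jax.random.normal(k2, (8, 2)) / jnp.sqrt(8.0),
    "Wb": jax.random.normal(k3, (8, 3)) / jnp.sqrt(8.0),
}
x = jax.random.normal(k4, (4, 3))

flat, unravel = ravel_pytree(params)
# jax.jacobian of a tuple-valued function returns a tuple of Jacobians:
# j_a has shape (4, 2, P) and j_b has shape (4, 3, P).
j_a, j_b = jax.jacobian(lambda p: f(unravel(p), x))(flat)

# One kernel per output (here with x1 = x2 = x); no cross-term is formed.
kernels = [jnp.einsum("iap,jbp->ijab", j, j) for j in (j_a, j_b)]
print([k.shape for k in kernels])  # [(4, 4, 2, 2), (4, 4, 3, 3)]
```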

Parameters
  • f (ApplyFn) – the function whose NTK we are computing. f should have the signature f(params, inputs[, rng]) and should return an np.ndarray of outputs.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This saves space and compute if you are only interested in the respective trace, and can also improve approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d., so the covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This saves space and compute if off-diagonal values along these axes are not needed, and can also improve approximation accuracy if their limiting value is known theoretically, for instance if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to “batch dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • vmap_axes (Union[List[int], Tuple[int, ...], int, None, Tuple[Union[List[int], Tuple[int, ...], int, None], Union[List[int], Tuple[int, ...], int, None], Dict[str, Union[List[int], Tuple[int, ...], int, None]]]]) –

    A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) equals a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes. In other words, it certifies that f can be evaluated as a vmap with out_axes=out_axes over x (along in_axes) and those arguments in **kwargs that are present in kwargs_axes.keys() (along kwargs_axes.values()).

    For example, if _, f, _ = nt.stax.Aggregate(), f is called via f(params, x, pattern=pattern). By default, inputs x, patterns pattern, and outputs of f are all batched along the leading 0 dimension, and each output f(params, x, pattern=pattern)[i] only depends on the inputs x[i] and pattern[i]. In this case, we can pass vmap_axes=(0, 0, dict(pattern=0)) to specify along which dimensions inputs, outputs, and keyword arguments are batched, respectively.

    This allows us to evaluate Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None to avoid wrong (and potentially silent) results.

  • implementation (int) –

    1 or 2.

    1 directly instantiates Jacobians and computes their outer product.

    2 uses implicit differentiation to avoid instantiating whole Jacobians at once. The implicit kernel is derived by observing that \(\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)\), i.e. a linear function \([J(X_1) J(X_2)^T]\) applied to an identity matrix \(I\). This allows the computation of the NTK to be phrased as: \(a(v) = J(X_2)^T v\), which is computed by a vector-Jacobian product; \(b(v) = J(X_1) a(v)\), which is computed by a Jacobian-vector product; and \(\Theta = [\partial b(v) / \partial v^T](I)\), which is computed via a vmap of \(b(v)\) over columns of the identity matrix \(I\).

    It is best to benchmark each method on your specific task. We suggest using 1 unless you run out of memory (OOM) due to a large number of trainable parameters; in that case, use 2.
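The two implementation strategies can be compared on a toy network with plain JAX. In this sketch, ntk_explicit follows the description of implementation 1 (instantiate both Jacobians, then contract them), and ntk_implicit follows the description of implementation 2 (a vector-Jacobian product for a(v), a Jacobian-vector product for b(v), and a vmap of b over the columns of I). The function names and the network are hypothetical, and this is not the library's internal code.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(params, x):
    # Hypothetical toy network: (N, 3) inputs -> (N, 2) outputs.
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


def ntk_explicit(f, params, x1, x2):
    # implementation=1 style: instantiate both Jacobians, then contract them.
    flat, unravel = ravel_pytree(params)
    jac = jax.jacobian(lambda p, x: f(unravel(p), x))
    return jnp.einsum("iap,jbp->ijab", jac(flat, x1), jac(flat, x2))


def ntk_implicit(f, params, x1, x2):
    # implementation=2 style: only products with J(X_1) and J(X_2)^T are used.
    flat, unravel = ravel_pytree(params)
    f1 = lambda p: f(unravel(p), x1).reshape(-1)  # flattened outputs, length M1
    f2 = lambda p: f(unravel(p), x2).reshape(-1)  # flattened outputs, length M2
    y2, vjp_fn = jax.vjp(f2, flat)

    def b(v):
        (a,) = vjp_fn(v)  # a(v) = J(X_2)^T v, a vector-Jacobian product
        _, jv = jax.jvp(f1, (flat,), (a,))  # b(v) = J(X_1) a(v), a JVP
        return jv

    # b is linear in v, so mapping it over the columns e_k of I recovers
    # J(X_1) J(X_2)^T; row k of the result is b(e_k), hence the transpose.
    theta = jax.vmap(b)(jnp.eye(y2.shape[0])).T  # shape (M1, M2)
    n1, ka = f(params, x1).shape
    n2, kb = f(params, x2).shape
    return theta.reshape(n1, ka, n2, kb).transpose(0, 2, 1, 3)


keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "W1": jax.random.normal(keys[0], (3, 8)) / jnp.sqrt(3.0),
    "W2": jax.random.normal(keys[1], (8, 2)) / jnp.sqrt(8.0),
}
x1 = jax.random.normal(keys[2], (4, 3))
x2 = jax.random.normal(keys[3], (5, 3))
print(jnp.allclose(ntk_implicit(f, params, x1, x2),
                   ntk_explicit(f, params, x1, x2), atol=1e-5))  # True
```

Both functions return a kernel of shape (N1, N2, K, K). Replacing jax.vmap(b) with jax.lax.map(b, ...) would evaluate the columns of I sequentially, trading speed for a lower peak memory footprint.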
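The vmap_axes contract described above can be checked on a toy example with plain JAX. The sketch below does not call neural_tangents; it compares the full-batch Jacobian with per-example Jacobians obtained through jax.vmap, which is roughly what vmap_axes=0 licenses. The networks and helper names are hypothetical. For a network without batch interaction the two agree; for a BatchNorm-like network that centers activations over the batch, the per-example version silently returns a different (here, all-zero) Jacobian, which is why such networks need vmap_axes=None.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f_independent(params, x):
    # Each output row depends only on the matching input row.
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


def f_batch_coupled(params, x):
    # Hypothetical BatchNorm-like layer: centering over the batch couples rows.
    h = jnp.tanh(x @ params["W1"])
    return (h - h.mean(axis=0)) @ params["W2"]


def full_jacobian(f, flat, unravel, x):
    # Jacobian of the whole batch at once, shape (N, K, P).
    return jax.jacobian(lambda p: f(unravel(p), x))(flat)


def per_example_jacobian(f, flat, unravel, x):
    # Jacobian of f on one example (a batch of size 1), mapped over the batch
    # with jax.vmap; this is only valid if examples do not interact.
    def single(p, xi):
        return f(unravel(p), xi[None])[0]
    return jax.vmap(jax.jacobian(single), in_axes=(None, 0))(flat, x)


keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W1": jax.random.normal(keys[0], (3, 8)) / jnp.sqrt(3.0),
    "W2": jax.random.normal(keys[1], (8, 2)) / jnp.sqrt(8.0),
}
x = jax.random.normal(keys[2], (4, 3))
flat, unravel = ravel_pytree(params)

ok = jnp.allclose(full_jacobian(f_independent, flat, unravel, x),
                  per_example_jacobian(f_independent, flat, unravel, x), atol=1e-5)
bad = jnp.allclose(full_jacobian(f_batch_coupled, flat, unravel, x),
                   per_example_jacobian(f_batch_coupled, flat, unravel, x), atol=1e-5)
print(bool(ok), bool(bad))  # True False
```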

Return type

EmpiricalKernelFn

Returns

A function ntk_fn that computes the empirical NTK.