neural_tangents.predict.gp_inference

neural_tangents.predict.gp_inference(k_train_train, y_train, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))[source]

Compute the mean and variance of the ‘posterior’ of NNGP/NTK/NTKGP.

  • NNGP – the exact posterior of an infinitely wide Bayesian NN.

  • NTK – the exact distribution of an infinite ensemble of infinitely wide NNs trained with gradient flow for infinite time.

  • NTKGP – the posterior of a GP (Gaussian process) with the NTK covariance (see https://arxiv.org/abs/2007.05864 for how this can also correspond to infinite ensembles of infinitely wide NNs).

Note that the first invocation of the returned predict_fn will be slow and will allocate a lot of memory for its whole lifetime, as a Cholesky factorization of k_train_train.nngp or k_train_train.ntk (or both) is performed and cached for future invocations.
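
For example, a minimal sketch of the intended usage (the toy data, the stax architecture, the diag_reg value, and all variable names are illustrative assumptions, not part of this API):

    from jax import random

    import neural_tangents as nt
    from neural_tangents import stax

    # Toy data (shapes and values are placeholders).
    key1, key2, key3 = random.split(random.PRNGKey(1), 3)
    x_train = random.normal(key1, (20, 3))
    y_train = random.normal(key2, (20, 1))
    x_test = random.normal(key3, (5, 3))

    # Infinite-width kernels of a simple fully-connected architecture.
    _, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

    # Kernel objects carrying both `nngp` and `ntk` entries.
    k_train_train = kernel_fn(x_train, x_train)
    k_test_train = kernel_fn(x_test, x_train)
    k_test_test = kernel_fn(x_test, x_test)

    predict_fn = nt.predict.gp_inference(k_train_train, y_train, diag_reg=1e-4)

    # First call: performs and caches the Cholesky factorization (slow).
    nngp_mean, nngp_cov = predict_fn('nngp', k_test_train, k_test_test)

    # Later calls reuse the cached factorization (fast).
    ntk_mean, ntk_cov = predict_fn('ntk', k_test_train, k_test_test)
    ntkgp_mean, ntkgp_cov = predict_fn('ntkgp', k_test_train, k_test_test)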

Parameters
  • k_train_train – train-train kernel. Can be (a) np.ndarray, (b) Kernel namedtuple, (c) Kernel object. Must contain the nngp and/or ntk kernels needed for the arguments provided to the returned predict_fn function. For example, if in future predict_fn invocations you request the posterior NTK covariance on the test set (even NTK only), k_train_train must contain both ntk and nngp kernels.

  • y_train (ndarray) – train targets.

  • diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization.

  • diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units; False for the effective regularization to be diag_reg * np.mean(np.trace(k_train_train)), i.e. scaled relative to the magnitude of k_train_train (see the sketch after this list).

  • trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train, k_test_train[, and k_test_test] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target the linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.
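
As a sketch of the two diag_reg scaling modes (this paraphrases the formula documented above; it is not the library's internal code, and the helper name is hypothetical):

    import jax.numpy as np

    def diagonal_shift(k_train_train, diag_reg, diag_reg_absolute_scale):
        # The effective multiple of the identity added to k_train_train
        # before the Cholesky factorization.
        if diag_reg_absolute_scale:
            return diag_reg  # absolute units
        # Relative units: scaled by the kernel's mean trace.
        return diag_reg * np.mean(np.trace(k_train_train))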

Returns

A function of signature predict_fn(get, k_test_train, k_test_test) computing the ‘posterior’ Gaussian distribution (mean, or mean and covariance) on a given test set.
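
Continuing the sketch above: supplying k_test_test yields both mean and covariance, while omitting it (presumably it defaults to None) yields the mean only:

    # Mean and covariance of the NNGP posterior on the test set.
    mean, cov = predict_fn('nngp', k_test_train, k_test_test)

    # Mean only (presumed: omitting k_test_test skips the covariance).
    mean_only = predict_fn('nngp', k_test_train)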