nt.predict: inference with NNGP & NTK
Functions to make predictions on the train/test set using NTK/NNGP.
Most functions in this module accept training data as inputs and return a new function, predict_fn, that computes predictions on the train set or a given test set and, where applicable, at given training timesteps.
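To make the "training data in, predict_fn out" pattern concrete, here is a minimal JAX sketch of an exact GP posterior (mean and covariance), the kind of computation the 'posterior' function under Prediction / inference functions performs. It is not the library's code: make_predict_fn and rbf_kernel are hypothetical names, and a squared-exponential kernel stands in for an NNGP/NTK kernel. The expensive train-train factorization happens once in the outer function, and every call to predict_fn reuses it.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf_kernel(x1, x2, length_scale=1.0):
    # Hypothetical stand-in for an NNGP/NTK kernel: squared-exponential kernel.
    sq_dists = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dists / length_scale**2)


def make_predict_fn(kernel_fn, x_train, y_train, diag_reg=1e-6):
    # Hypothetical factory: consumes training data, returns predict_fn.
    k_train_train = kernel_fn(x_train, x_train)
    n = k_train_train.shape[0]
    # Factor the regularized train-train kernel once; reused by every call.
    factor = cho_factor(k_train_train + diag_reg * jnp.eye(n), lower=True)
    alpha = cho_solve(factor, y_train)

    def predict_fn(x_test=None, compute_cov=False):
        x = x_train if x_test is None else x_test  # None -> predict on the train set
        k_test_train = kernel_fn(x, x_train)
        mean = k_test_train @ alpha
        if not compute_cov:
            return mean
        # Posterior covariance: K(x, x) - K(x, X) (K(X, X) + reg I)^{-1} K(X, x).
        cov = kernel_fn(x, x) - k_test_train @ cho_solve(factor, k_test_train.T)
        return mean, cov

    return predict_fn


x_train = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y_train = jnp.sin(x_train)
predict_fn = make_predict_fn(rbf_kernel, x_train, y_train)
train_mean = predict_fn()  # nearly interpolates y_train
far_mean, far_cov = predict_fn(jnp.array([[100.0]]), compute_cov=True)
```

Far from the training data the sketch falls back to the prior: mean near 0 and variance near K(x, x) = 1.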
Warning
The trace_axes parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, this is the same trace_axes used to compute the empirical kernel (utils/empirical.py; diagonal_axes must be ()), or channel_axis in the output of the top layer used to compute the closed-form kernel (stax.py; note that closed-form kernels currently only support a single channel_axis).
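The warning matters because trace_axes changes both the shape and the meaning of the kernel. A minimal JAX sketch with a toy linear model, assuming the traced kernel is a plain trace over the output-axis pair (f and empirical_ntk are hypothetical helpers, not the code in utils/empirical.py): without tracing, the empirical NTK has shape (n1, n2, k, k); with the output axis traced, it collapses to (n1, n2), and a prediction function must then apply that one kernel identically to every output channel.

```python
import jax
import jax.numpy as jnp


def f(w, x):
    # Toy linear "network": inputs (n, d) -> outputs (n, k); output axis is -1.
    return x @ w


def empirical_ntk(w, x1, x2, trace_output_axis=True):
    # Jacobians w.r.t. parameters: output shape + param shape = (n, k, d, k).
    j1 = jax.jacobian(f)(w, x1)
    j2 = jax.jacobian(f)(w, x2)
    # Full kernel: contract over all parameter axes -> (n1, n2, k, k).
    full = jnp.einsum('iapq,jbpq->ijab', j1, j2)
    if trace_output_axis:
        # Keep only the trace over the output-axis pair -> (n1, n2).
        return jnp.trace(full, axis1=2, axis2=3)
    return full


w = jnp.zeros((3, 2))  # d = 3 inputs, k = 2 outputs
x1 = jnp.arange(6.0).reshape(2, 3)
x2 = jnp.ones((4, 3))
full = empirical_ntk(w, x1, x2, trace_output_axis=False)
traced = empirical_ntk(w, x1, x2, trace_output_axis=True)
```

For this linear model the full kernel equals (x1 @ x2.T) times the identity over outputs, so the traced kernel (k times x1 @ x2.T) loses nothing. That "constant times identity along the traced axes" structure is the situation in which tracing is safe.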
Prediction / inference functions
Functions to make train/test set predictions given NNGP/NTK kernels or the linearized function.
Compute the mean and variance of the 'posterior' of NNGP/NTK/NTKGP.
Predicts the outcome of function-space training using gradient descent.
Predicts the outcome of function-space gradient descent training on MSE.
Predicts the Gaussian embedding induced by gradient descent on MSE loss.
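For MSE loss, function-space gradient flow df/dt = -η K (f - y) has a closed form, which is why the MSE predictors need no ODE solver. A minimal JAX sketch of the mean prediction f_t(x) = K(x, X) K^{-1} (I - e^{-η K t}) y, assuming the unnormalized loss 0.5 * ||f(X) - y||^2, zero initial outputs f_0 = 0, and a positive-definite train-train kernel. gd_mse_mean is a hypothetical name, and the library's own loss normalization and handling of f_0 are not assumed here.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def gd_mse_mean(k_train_train, k_test_train, y_train, t, learning_rate=1.0):
    # Hypothetical helper. Assumes f_0 = 0, loss 0.5 * ||f(X) - y||^2,
    # positive-definite k_train_train, and y_train of shape (n_train, n_outputs).
    evals, evecs = jnp.linalg.eigh(k_train_train)
    # Per-eigenvalue factor (1 - exp(-lr * lambda * t)) / lambda; expm1 stays
    # accurate when lr * lambda * t is small.
    scale = -jnp.expm1(-learning_rate * evals * t) / evals
    # Apply K^{-1} (I - exp(-lr * K * t)) to y in the eigenbasis of K.
    coeffs = evecs @ (scale[:, None] * (evecs.T @ y_train))
    return k_test_train @ coeffs


k_train_train = jnp.array([[2.0, 0.5], [0.5, 1.0]])
k_test_train = jnp.array([[0.3, 0.7]])
y_train = jnp.array([[1.0], [-1.0]])
t = 0.5

# On the train set the formula reduces to (I - exp(-lr * K * t)) y.
f_train = gd_mse_mean(k_train_train, k_train_train, y_train, t)
f_train_expected = (jnp.eye(2) - expm(-t * k_train_train)) @ y_train

# As t -> infinity it reduces to the kernel regression (posterior) mean.
f_test_converged = gd_mse_mean(k_train_train, k_test_train, y_train, 1e4)
f_test_regression = k_test_train @ jnp.linalg.solve(k_train_train, y_train)
```

The large-t limit ties the MSE predictors back to the 'posterior' mean computed without regularization, while finite t gives the prediction partway through training.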
Utilities
Computes the maximal feasible learning rate for infinite-width NNs.
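Why a maximal learning rate exists: in the linearized regime, one discrete step of function-space gradient descent on 0.5 * ||f(X) - y||^2 maps the training residual e to (I - ηK) e, which shrinks only if |1 - ηλ| < 1 for every eigenvalue λ of K, i.e. η < 2 / λ_max. The sketch below checks this numerically in JAX. max_learning_rate_sketch and run_gd are hypothetical, and the library's own utility may additionally normalize by the training-set size or account for momentum.

```python
import jax.numpy as jnp


def max_learning_rate_sketch(k_train_train, eps=1e-12):
    # eigvalsh returns eigenvalues of a symmetric matrix in ascending order.
    max_eval = jnp.linalg.eigvalsh(k_train_train)[-1]
    return 2.0 / (max_eval + eps)


def run_gd(k_train_train, y_train, learning_rate, steps):
    # Discrete function-space GD on 0.5 * ||f - y||^2 from f_0 = 0:
    # f <- f - lr * K (f - y).
    f = jnp.zeros_like(y_train)
    for _ in range(steps):
        f = f - learning_rate * k_train_train @ (f - y_train)
    return f


k_train_train = jnp.array([[2.0, 0.5], [0.5, 1.0]])
y_train = jnp.array([[1.0], [-1.0]])
lr_max = max_learning_rate_sketch(k_train_train)

f_stable = run_gd(k_train_train, y_train, 0.9 * lr_max, steps=200)    # converges to y
f_unstable = run_gd(k_train_train, y_train, 1.1 * lr_max, steps=200)  # diverges
```

Just below the bound the top-eigenvalue component of the residual shrinks by a factor of about 0.8 per step; just above it grows by about 1.2 per step.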
Helper classes
Dataclasses and namedtuples used to return predictions.
A
ODE state dataclass holding outputs and auxiliary variables.