neural_tangents.linearize

neural_tangents.linearize(f, params)

Returns a function f_lin, the first-order Taylor approximation of f about params.
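Written out as an equation, with the symbols θ₀ for params and θ for new_params introduced here only for illustration, a first-order Taylor approximation has the standard form:

```latex
f_{\mathrm{lin}}(\theta, x) = f(\theta_0, x)
  + \left.\frac{\partial f(\theta, x)}{\partial \theta}\right|_{\theta = \theta_0} (\theta - \theta_0)
```

When the parameters form a PyTree, the Jacobian-vector product in the second term is summed over every leaf of θ − θ₀.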

Example

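>>> import jax.numpy as np
>>> from neural_tangents import linearize
>>> # Assumes f, params, new_params and x are already defined.
>>> # Compute the MSE of the first-order Taylor series of a function.
>>> f_lin = linearize(f, params)
>>> mse = np.mean((f(new_params, x) - f_lin(new_params, x)) ** 2)

Parameters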
  • f (ApplyFn) – A function that we would like to linearize. It should have the signature f(params, *args, **kwargs) where params is a PyTree and f should return a PyTree.

  • params (Any) – Initial parameters of f, about which the Taylor series is taken. This can be any structure compatible with JAX tree operations (a PyTree).
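As an illustration of the expected signature, the sketch below defines a function with the form f(params, x), where params is a dict of arrays (a PyTree). The network, the initializer `init_params`, and the layer sizes are hypothetical choices made for this example; any function with this calling convention should be a candidate for linearization.

```python
import jax.numpy as jnp
from jax import random, tree_util


# Hypothetical ApplyFn: a tiny two-layer network with the signature
# f(params, *args, **kwargs). `params` is a PyTree (a dict of arrays)
# and the returned array is itself a single-leaf PyTree.
def f(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


# Hypothetical initializer; the layer sizes are arbitrary choices.
def init_params(key, d_in=3, d_hidden=8, d_out=1):
    k1, k2 = random.split(key)
    return {
        "w1": random.normal(k1, (d_in, d_hidden)) / jnp.sqrt(d_in),
        "b1": jnp.zeros(d_hidden),
        "w2": random.normal(k2, (d_hidden, d_out)) / jnp.sqrt(d_hidden),
        "b2": jnp.zeros(d_out),
    }


params = init_params(random.PRNGKey(0))
x = jnp.ones((5, 3))

print(f(params, x).shape)  # (5, 1)
# JAX tree operations see four array leaves in `params`.
print(len(tree_util.tree_leaves(params)))  # 4
```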

Return type

ApplyFn

Returns

A function f_lin(new_params, *args, **kwargs) with the same signature as f. Here f_lin implements the first-order Taylor series of f about params.
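To make the returned function concrete, the following is a minimal sketch of how such an f_lin could be built on `jax.jvp`, assuming it only needs to evaluate f(params) plus the Jacobian of f at params applied to (new_params - params). `linearize_sketch` is a hypothetical name, not the library's source; in practice `neural_tangents.linearize` itself should be used. The toy model `f` and the parameter values are also illustrative.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def linearize_sketch(f, params):
    """Hypothetical stand-in for a first-order Taylor expansion of `f` about
    `params`, built on jax.jvp. Not the library's own source."""

    def f_lin(new_params, *args, **kwargs):
        # Fix the non-parameter arguments so jvp differentiates w.r.t. params only.
        def f_params(p):
            return f(p, *args, **kwargs)

        # Tangent direction: displacement of new_params from the expansion point.
        dparams = tree_util.tree_map(lambda a, b: a - b, new_params, params)
        # One forward-mode pass yields f(params) and J(params) @ dparams.
        f_0, df = jax.jvp(f_params, (params,), (dparams,))
        # Zeroth-order term plus first-order term, leaf by leaf.
        return tree_util.tree_map(lambda a, b: a + b, f_0, df)

    return f_lin


# Usage on a hypothetical model with scalar parameters.
def f(params, x):
    return jnp.sin(params["w"] * x) + params["b"]


params = {"w": jnp.array(1.0), "b": jnp.array(0.0)}
new_params = {"w": jnp.array(1.01), "b": jnp.array(0.1)}
x = jnp.linspace(-1.0, 1.0, 5)

f_lin = linearize_sketch(f, params)
# At the expansion point the approximation is exact.
print(bool(jnp.allclose(f_lin(params, x), f(params, x))))  # True
# Nearby, the error is second order in the parameter change.
mse = jnp.mean((f(new_params, x) - f_lin(new_params, x)) ** 2)
print(float(mse))
```

Forward-mode `jax.jvp` is a natural fit here because f_lin only needs a single Jacobian-vector product per call, so the full Jacobian is never materialized.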