Typing
Common Type Definitions.
- class neural_tangents._src.utils.typing.AnalyticKernelFn(*args, **kwargs)
A type alias for analytic kernel functions.
A kernel function that computes an analytic kernel. Takes either `Kernel` or `jax.numpy.ndarray` inputs and a `get` argument that specifies which quantities the kernel should compute. Returns either a `Kernel` object or `jax.numpy.ndarray`-s for the kernels specified by `get`.
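As a minimal sketch, assuming nested Python lists in place of `jax.numpy.ndarray` and array inputs only (no `Kernel` objects), here is a hypothetical analytic kernel function for a single fully connected layer f(x) = (W_std / sqrt(d)) W x + b_std b with standard Gaussian W and b. The name `dense_kernel_fn` and its parameters are illustrative and not part of Neural Tangents. It shows how `get` selects what comes back: a string yields one kernel, a tuple yields a tuple of kernels.

```python
from typing import List, Sequence, Tuple, Union

Matrix = List[List[float]]


def dense_kernel_fn(
    x1: Matrix,
    x2: Matrix,
    get: Union[str, Sequence[str]] = ("nngp", "ntk"),
    W_std: float = 1.0,
    b_std: float = 0.0,
) -> Union[Matrix, Tuple[Matrix, ...]]:
    """Hypothetical analytic kernel fn for one Dense layer (NTK parameterization)."""
    d = len(x1[0])  # input dimension
    # NNGP: E[f(x) f(x')] = W_std^2 * <x, x'> / d + b_std^2.
    nngp = [
        [W_std**2 * sum(a * b for a, b in zip(r1, r2)) / d + b_std**2 for r2 in x2]
        for r1 in x1
    ]
    # For a single layer, the gradients w.r.t. W and b reproduce the same sum,
    # so the NTK equals the NNGP here.
    ntk = [row[:] for row in nngp]
    kernels = {"nngp": nngp, "ntk": ntk}
    if isinstance(get, str):
        return kernels[get]
    return tuple(kernels[g] for g in get)


nngp, ntk = dense_kernel_fn([[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0]])
```

Passing `get="nngp"` instead returns just the NNGP matrix.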
- class neural_tangents._src.utils.typing.ApplyFn(*args, **kwargs)
A type alias for apply functions.
Apply functions do computations with finite-width neural networks. They are functions that take a PyTree of parameters and an array of inputs and produce an array of outputs.
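A minimal sketch of an apply function, assuming a hypothetical dense layer whose parameter PyTree is a `(W, b)` tuple and using nested Python lists instead of `jax.numpy` arrays. It maps a batch of inputs to a batch of outputs, y = x W + b; the name `dense_apply_fn` is illustrative only.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Params = Tuple[Matrix, List[float]]  # hypothetical PyTree: (weights, biases)


def dense_apply_fn(params: Params, inputs: Matrix) -> Matrix:
    """Hypothetical apply fn: outputs[n][j] = sum_i inputs[n][i] * W[i][j] + b[j]."""
    W, b = params
    return [
        [sum(x_i * W[i][j] for i, x_i in enumerate(row)) + b[j] for j in range(len(b))]
        for row in inputs
    ]


params = ([[1.0, 2.0], [3.0, 4.0]], [0.5, -0.5])
outputs = dense_apply_fn(params, [[1.0, 1.0]])  # [[4.5, 5.5]]
```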
- neural_tangents._src.utils.typing.Axes
Axes specification; can be an integer (`axis=-1`) or a sequence of integers (`axis=(1, 3)`).
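Since an `Axes` value may be either a single integer or a sequence, code consuming it typically normalizes it first. A minimal sketch, assuming a hypothetical helper `canonicalize_axes` (not part of the library) that maps negative indices to non-negative ones and returns a sorted, de-duplicated tuple:

```python
from typing import Sequence, Tuple, Union

AxesLike = Union[int, Sequence[int]]  # mirrors the int-or-sequence shape of `Axes`


def canonicalize_axes(axes: AxesLike, ndim: int) -> Tuple[int, ...]:
    """Hypothetical helper: normalize an int or sequence of axes to a sorted tuple."""
    if isinstance(axes, int):
        axes = (axes,)
    normalized = set()
    for axis in axes:
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} is out of range for ndim={ndim}")
        normalized.add(axis % ndim)  # e.g. -1 -> ndim - 1
    return tuple(sorted(normalized))


print(canonicalize_axes(-1, 4))      # (3,)
print(canonicalize_axes((1, 3), 4))  # (1, 3)
```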
- class neural_tangents._src.utils.typing.EmpiricalGetKernelFn(*args, **kwargs)
A type alias for empirical kernel functions accepting a `get` argument.
A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.
Equivalent to `EmpiricalKernelFn`, but accepts a `get` argument, which can be, for example, `get=("nngp", "ntk")` to compute both kernels together.
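A minimal sketch of a `get`-style empirical kernel function for a single instantiation of a hypothetical linear model f(x) = x W + b with parameters `(W, b)`, using nested lists in place of arrays. The argument order `(x1, x2, get, params)`, the names `linear_apply` and `empirical_get_kernel_fn`, and the choice to average both kernels over the output dimension are assumptions of this sketch. Only the kernels named in `get` are computed.

```python
from typing import List, Sequence, Tuple, Union

Matrix = List[List[float]]
Params = Tuple[Matrix, List[float]]  # hypothetical (W, b)


def linear_apply(params: Params, x: Matrix) -> Matrix:
    """Hypothetical model: f(x)[j] = sum_i x[i] * W[i][j] + b[j]."""
    W, b = params
    return [[sum(xi * W[i][j] for i, xi in enumerate(row)) + b[j]
             for j in range(len(b))] for row in x]


def empirical_get_kernel_fn(x1: Matrix, x2: Matrix,
                            get: Union[str, Sequence[str]], params: Params):
    """Hypothetical empirical kernels of linear_apply, averaged over outputs."""
    k = len(params[1])  # number of outputs

    def nngp() -> Matrix:
        f1, f2 = linear_apply(params, x1), linear_apply(params, x2)
        return [[sum(a * c for a, c in zip(r1, r2)) / k for r2 in f2] for r1 in f1]

    def ntk() -> Matrix:
        # Output j depends only on (W[:, j], b[j]) with gradient (x, 1), so every
        # diagonal output block is <x, x'> + 1 and off-diagonal blocks vanish.
        return [[sum(a * c for a, c in zip(r1, r2)) + 1.0 for r2 in x2] for r1 in x1]

    compute = {"nngp": nngp, "ntk": ntk}
    if isinstance(get, str):
        return compute[get]()
    return tuple(compute[g]() for g in get)


params = ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
nngp, ntk = empirical_get_kernel_fn([[1.0, 2.0]], [[3.0, 4.0]], ("nngp", "ntk"), params)
```

Computing kernels lazily through small closures means a call with `get="ntk"` never evaluates the forward pass needed for the NNGP.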
- class neural_tangents._src.utils.typing.EmpiricalKernelFn(*args, **kwargs)
A type alias for empirical kernel functions computing either NTK or NNGP.
A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.
Equivalent to `EmpiricalGetKernelFn` with `get="nngp"` or `get="ntk"`.
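Because this alias is an `EmpiricalGetKernelFn` with `get` fixed, one can be derived from the other by closing over `get`. A minimal sketch with a hypothetical scalar model f(x) = w * x (so `params` is just `w`), whose NNGP is f(x) f(x') and whose NTK is (df/dw)(x) (df/dw)(x') = x x'. The names `toy_get_kernel_fn` and `fix_get` are illustrative only.

```python
from typing import Callable, List, Sequence, Union


def toy_get_kernel_fn(x1: List[float], x2: List[float],
                      get: Union[str, Sequence[str]], params: float):
    """Hypothetical empirical kernels of the scalar model f(x) = w * x."""
    w = params
    kernels = {
        "nngp": [[(w * a) * (w * c) for c in x2] for a in x1],  # f(x) f(x')
        "ntk": [[a * c for c in x2] for a in x1],  # df/dw = x
    }
    if isinstance(get, str):
        return kernels[get]
    return tuple(kernels[g] for g in get)


def fix_get(get_kernel_fn: Callable, get: str) -> Callable:
    """Turn a get-style kernel fn into one that always computes `get`."""
    def kernel_fn(x1, x2, params):
        return get_kernel_fn(x1, x2, get, params)
    return kernel_fn


ntk_fn = fix_get(toy_get_kernel_fn, "ntk")
nngp_fn = fix_get(toy_get_kernel_fn, "nngp")
print(ntk_fn([1.0, 2.0], [3.0], 2.0))  # [[3.0], [6.0]]
```

An explicit wrapper is used instead of `functools.partial(toy_get_kernel_fn, get="ntk")`: with that partial, calling it as `(x1, x2, params)` would bind `params` positionally to `get` and raise a `TypeError`.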
- class neural_tangents._src.utils.typing.InitFn(*args, **kwargs)
A type alias for initialization functions.
Initialization functions construct parameters for neural networks given a random key and an input shape. Specifically, they produce a tuple giving the output shape and a PyTree of parameters.
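A minimal sketch of an initialization function for a hypothetical dense layer, assuming Python's `random.Random` as a stand-in for a JAX `PRNGKey` and nested lists for parameters. It returns the `(output_shape, params)` tuple described above, with a `(W, b)` parameter PyTree of standard Gaussian entries; `make_dense_init_fn` is illustrative only.

```python
import random
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]
Params = Tuple[List[List[float]], List[float]]  # hypothetical (W, b) PyTree


def make_dense_init_fn(out_dim: int) -> Callable[[random.Random, Shape], Tuple[Shape, Params]]:
    """Hypothetical factory for a dense layer's init fn."""

    def init_fn(rng: random.Random, input_shape: Shape) -> Tuple[Shape, Params]:
        in_dim = input_shape[-1]
        W = [[rng.gauss(0.0, 1.0) for _ in range(out_dim)] for _ in range(in_dim)]
        b = [rng.gauss(0.0, 1.0) for _ in range(out_dim)]
        # Only the last (feature) dimension changes; batch dims pass through.
        output_shape = input_shape[:-1] + (out_dim,)
        return output_shape, (W, b)

    return init_fn


init_fn = make_dense_init_fn(3)
output_shape, (W, b) = init_fn(random.Random(0), (5, 4))  # output_shape == (5, 3)
```

Re-seeding with the same value reproduces the same parameters, mirroring how a fixed `PRNGKey` makes JAX initialization deterministic.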
- neural_tangents._src.utils.typing.Kernels
Kernel inputs/outputs of `FanOut`, `FanInSum`, etc.
- class neural_tangents._src.utils.typing.LayerKernelFn(*args, **kwargs)
A type alias for pure kernel functions.
A pure kernel function takes a PyTree of Kernel object(s) and produces a PyTree of Kernel object(s). These functions are used to define new layer types.
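As a minimal sketch, assuming a hypothetical immutable `ToyKernel` record holding only `nngp` and `ntk` matrices (the library's `Kernel` object carries more) and a single kernel rather than a general PyTree, the pure kernel function of a dense layer in the NTK parameterization applies K' = W_std^2 K + b_std^2 and Theta' = K' + W_std^2 Theta. It returns a new object and leaves its input untouched; `dense_layer_kernel_fn` is illustrative only.

```python
from dataclasses import dataclass, replace
from typing import List

Matrix = List[List[float]]


@dataclass(frozen=True)
class ToyKernel:
    """Hypothetical stand-in for a kernel object: just NNGP and NTK matrices."""
    nngp: Matrix
    ntk: Matrix


def dense_layer_kernel_fn(k: ToyKernel, W_std: float = 1.0, b_std: float = 0.0) -> ToyKernel:
    """Pure kernel fn of a dense layer (NTK parameterization); returns a new kernel."""
    # K' = W_std^2 * K + b_std^2
    nngp = [[W_std**2 * v + b_std**2 for v in row] for row in k.nngp]
    # Theta' = K' + W_std^2 * Theta
    ntk = [[n + W_std**2 * t for n, t in zip(n_row, t_row)]
           for n_row, t_row in zip(nngp, k.ntk)]
    return replace(k, nngp=nngp, ntk=ntk)


# Raw inputs have no parameter dependence, so their NTK starts at zero.
k0 = ToyKernel(nngp=[[1.0, 0.5], [0.5, 1.0]], ntk=[[0.0, 0.0], [0.0, 0.0]])
k1 = dense_layer_kernel_fn(k0, W_std=2.0, b_std=1.0)
k2 = dense_layer_kernel_fn(k1, W_std=2.0, b_std=1.0)
```

Composing such functions layer by layer is what lets a network's kernel be computed without ever instantiating its weights.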
- class neural_tangents._src.utils.typing.MaskFn(*args, **kwargs)
A type alias for masking functions.
Forward-propagate a mask in a layer of a finite-width network.
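A minimal sketch, assuming the convention that `True` marks a masked-out entry and a hypothetical `(mask, input_shape)` signature: an elementwise layer passes the mask through unchanged, while a layer that pools over the last axis masks an output only when every input it reduces over is masked. Both functions are illustrative, not the library's implementations.

```python
from typing import List, Tuple

Mask = List[List[bool]]  # [batch][position]; True marks a masked-out entry


def elementwise_mask_fn(mask: Mask, input_shape: Tuple[int, ...]) -> Mask:
    """Hypothetical mask fn of an elementwise layer: shape and mask are unchanged."""
    return mask


def last_axis_pool_mask_fn(mask: Mask, input_shape: Tuple[int, ...]) -> List[bool]:
    """Hypothetical mask fn of pooling over the last axis."""
    # An output is masked only if every entry it pools over is masked.
    return [all(row) for row in mask]


m = [[True, False], [True, True]]
print(last_axis_pool_mask_fn(m, (2, 2)))  # [False, True]
```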
- class neural_tangents._src.utils.typing.MonteCarloKernelFn(*args, **kwargs)
A type alias for Monte Carlo kernel functions.
A kernel function that produces an estimate of an analytic kernel (see `AnalyticKernelFn`) by Monte Carlo sampling, given a `PRNGKey`.
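A minimal sketch, assuming an integer seed for `random.Random` in place of a JAX `PRNGKey` and a hypothetical scalar model f(x) = w x with w drawn from N(0, 1). Averaging the empirical NNGP f(x) f(x') over sampled weights estimates the analytic NNGP E[w^2] x x' = x x'. The function `monte_carlo_nngp_fn` is illustrative only.

```python
import random
from typing import List


def monte_carlo_nngp_fn(x1: List[float], x2: List[float], key: int,
                        n_samples: int = 1000) -> List[List[float]]:
    """Hypothetical Monte Carlo NNGP estimate for f(x) = w * x, w ~ N(0, 1)."""
    rng = random.Random(key)  # stand-in for a PRNGKey
    acc = [[0.0 for _ in x2] for _ in x1]
    for _ in range(n_samples):
        w = rng.gauss(0.0, 1.0)  # one network instantiation
        f1 = [w * a for a in x1]
        f2 = [w * b for b in x2]
        for i, a in enumerate(f1):
            for j, b in enumerate(f2):
                acc[i][j] += a * b  # empirical NNGP of this sample
    return [[v / n_samples for v in row] for row in acc]


est = monte_carlo_nngp_fn([2.0], [3.0], key=0, n_samples=20000)  # close to [[6.0]]
```

The estimate converges to the analytic value as `n_samples` grows, and the same key always reproduces the same estimate.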
- neural_tangents._src.utils.typing.NTTree
Neural Tangents Tree.
Trees of kernels and arrays naturally emerge in certain neural network computations (for example, when neural networks have nested parallel layers).
Mimicking JAX, we use a lightweight tree structure called an `NTTree`. An `NTTree` has internal nodes that are either lists or tuples, and leaves that are either `jax.numpy.ndarray` or `Kernel` objects.
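A minimal sketch of traversing such a tree, assuming a hypothetical helper `nt_tree_map` that recurses into lists and tuples (preserving which is which) and applies a function to every other value as a leaf. Floats stand in for array or `Kernel` leaves here; JAX's own `jax.tree_util.tree_map` plays a similar role for general PyTrees.

```python
from typing import Any, Callable


def nt_tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Hypothetical map over an NTTree: lists/tuples are nodes, all else is a leaf."""
    if isinstance(tree, list):
        return [nt_tree_map(fn, child) for child in tree]
    if isinstance(tree, tuple):
        return tuple(nt_tree_map(fn, child) for child in tree)
    return fn(tree)


# e.g. a tree of per-branch outputs from nested parallel layers
doubled = nt_tree_map(lambda leaf: 2 * leaf, [1.0, (2.0, 3.0), [4.0]])
print(doubled)  # [2.0, (4.0, 6.0), [8.0]]
```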