neural_tangents.stax.ElementwiseNumerical
- neural_tangents.stax.ElementwiseNumerical(fn, deg, df=None)
Activation function using numerical integration.
Supports general activation functions using Gauss-Hermite quadrature.
For details, please see “Fast Neural Kernel Embeddings for General Activations”.
See also
examples/elementwise_numerical.py
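To make the quadrature idea concrete, here is a minimal NumPy sketch of a 1D Gauss-Hermite rule approximating a Gaussian expectation E[fn(z)] with z ~ N(0, 1). It illustrates the general technique only and is not the library's implementation; the helper name `gaussian_expectation` is hypothetical.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def gaussian_expectation(fn, deg):
    """Approximate E[fn(z)], z ~ N(0, 1), with a deg-point Gauss-Hermite rule."""
    # Nodes and weights for the weight function exp(-x**2 / 2).
    nodes, weights = hermegauss(deg)
    # The weights sum to sqrt(2 * pi); dividing normalizes to the N(0, 1) density.
    return np.sum(weights * fn(nodes)) / np.sqrt(2.0 * np.pi)


# Polynomials of degree <= 2 * deg - 1 are integrated exactly (up to rounding).
second_moment = gaussian_expectation(lambda z: z ** 2, deg=5)  # true value: 1
fourth_moment = gaussian_expectation(lambda z: z ** 4, deg=5)  # true value: 3

# For a smooth activation, compare a moderate rule against a much finer one.
tanh_sq_25 = gaussian_expectation(lambda z: np.tanh(z) ** 2, deg=25)
tanh_sq_100 = gaussian_expectation(lambda z: np.tanh(z) ** 2, deg=100)
print(second_moment, fourth_moment, abs(tanh_sq_25 - tanh_sq_100))
```

The small gap between the 25-point and 100-point estimates for tanh is in line with the guidance that a moderate deg is a reasonable starting point for smooth activations.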
- Parameters:
fn – activation function.
deg (int) – number of sample points and weights for quadrature. It must be >= 1. For smooth activations, deg=25 is a good starting point. Quadrature is not recommended for non-smooth activation functions (e.g. ReLU, Abs); for now, use nt.monte_carlo_kernel_fn instead. Because the integration is bivariate, compute time and memory scale as O(deg**2), so higher precision comes at quadratic cost. See eq. (13) in https://mathworld.wolfram.com/Hermite-GaussQuadrature.html for error estimates in the case of 1D Gauss-Hermite quadrature.
df (Optional[Callable[[float], float]]) – optional derivative of the activation function (fn). If not provided, it is computed by jax.grad. Providing an analytic derivative can speed up the NTK computations.
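The O(deg**2) cost comes from integrating over a pair of correlated Gaussian variables. The following NumPy sketch shows one way such a bivariate expectation E[fn(u) fn(v)] can be evaluated on a tensor-product grid of deg x deg nodes. It assumes unit variances and correlation rho; the helper name `bivariate_expectation` is hypothetical, and the library's internal scheme may differ.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def bivariate_expectation(fn, rho, deg):
    """Approximate E[fn(u) * fn(v)] for unit-variance Gaussians with correlation rho."""
    nodes, weights = hermegauss(deg)
    weights = weights / np.sqrt(2.0 * np.pi)  # normalize so the weights sum to 1

    # Tensor-product grid: deg * deg node pairs, hence O(deg**2) time and memory.
    z1, z2 = np.meshgrid(nodes, nodes, indexing="ij")
    grid_weights = np.outer(weights, weights)

    # Build correlated (u, v) from independent standard normals (z1, z2).
    u = z1
    v = rho * z1 + np.sqrt(1.0 - rho ** 2) * z2
    return np.sum(grid_weights * fn(u) * fn(v)), z1.size


# Sanity checks against closed forms:
#   E[u * v] = rho   and   E[u**2 * v**2] = 1 + 2 * rho**2.
linear_value, n_points = bivariate_expectation(lambda x: x, rho=0.3, deg=5)
square_value, _ = bivariate_expectation(lambda x: x ** 2, rho=0.5, deg=5)
print(linear_value, square_value, n_points)
```

Doubling deg quadruples the number of grid points, which is why increasing precision becomes expensive quickly.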
- Return type:
- Returns:
(init_fn, apply_fn, kernel_fn).