neural_tangents.predict.max_learning_rate
neural_tangents.predict.max_learning_rate(ntk_train_train, y_train_size=None, momentum=0.0, eps=1e-12)
Computes the maximal feasible learning rate for infinite width NNs.
The network is assumed to be trained using mini-/full-batch GD + momentum with mean squared loss. The loss is assumed to have the form 1/(2 * batch_size * output_size) * ||f(train_x) - train_y||^2. For vanilla SGD (i.e. momentum = 0) the maximal feasible learning rate is the largest eta such that the operator (I - eta / (batch_size * output_size) * NTK) is a contraction, which is 2 * batch_size * output_size / lambda_max(NTK). When momentum > 0, the bound becomes 2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK) (see The Dynamics of Momentum section in https://distill.pub/2017/momentum/).
- Parameters
  - ntk_train_train (ndarray) – analytic or empirical NTK on the training data.
  - y_train_size (Optional[int]) – total training set output size, i.e. f(x_train).size == y_train.size. If y_train_size is None, it is inferred from ntk_train_train.shape assuming trace_axes=().
  - momentum (float) – the momentum parameter of the optimizer.
  - eps (float) – a small float added to the divisor to avoid division by zero.
- Return type
  - float
- Returns
The maximal feasible learning rate for infinite width NNs.
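The computation described above can be sketched in plain NumPy. Note that `max_learning_rate_sketch` below is a hypothetical illustration of the formula, not the library's implementation; it assumes `trace_axes=()`, so the NTK matrix is `y_train_size x y_train_size`.

```python
import numpy as np

def max_learning_rate_sketch(ntk_train_train, y_train_size=None,
                             momentum=0.0, eps=1e-12):
    """Sketch of eta_max = 2 * (1 + momentum) * y_train_size / lambda_max(NTK).

    Hypothetical illustration of the formula in the docs above, not the
    neural_tangents implementation.
    """
    if y_train_size is None:
        # With trace_axes=(), the NTK has shape (y_train_size, y_train_size).
        y_train_size = ntk_train_train.shape[0]
    # Largest eigenvalue of the (symmetric) NTK.
    lambda_max = np.max(np.linalg.eigvalsh(ntk_train_train))
    # eps guards against division by zero for a degenerate NTK.
    return 2 * (1 + momentum) * y_train_size / (lambda_max + eps)

# Example: for NTK = I_3, lambda_max = 1 and y_train_size = 3,
# so eta_max is approximately 2 * 1 * 3 / 1 = 6 with momentum = 0.
print(max_learning_rate_sketch(np.eye(3)))
```

With `momentum=0.9` on the same identity NTK, the bound scales by `(1 + 0.9)`, giving approximately 11.4.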