neural_tangents.stax.supports_masking
neural_tangents.stax.supports_masking(remask_kernel)
Returns a decorator that turns layers into layers supporting masking.
Specifically:

1. init_fn is left unchanged.
2. apply_fn is turned from a function that accepts a mask=None keyword argument (which indicates that inputs[mask] must be masked) into a function that accepts a mask_constant=None keyword argument (which indicates that inputs[inputs == mask_constant] must be masked).

3. kernel_fn is modified to:

   3.a. propagate kernel.mask1 and kernel.mask2 through intermediary layers, and

   3.b. if remask_kernel == True, zero out covariances between entries of which at least one is masked.

4. If the decorated layer has a mask_fn, it is used to propagate masks forward through the layer in both apply_fn and kernel_fn. If not, the mask is assumed to remain unchanged.

Must be applied before the layer decorator, that is, written below @layer in a stack of decorators, since Python applies stacked decorators from the bottom up.

See also
Example of masking application in examples/imdb.py.

Parameters:
remask_kernel (bool) – True to zero out kernel covariance entries between masked inputs after applying kernel_fn. Some layers don't need this, and setting remask_kernel=False can save compute.

Returns:
A decorator that turns functions returning (init_fn, apply_fn, kernel_fn[, mask_fn]) into functions returning (init_fn, apply_fn_with_masking, kernel_fn_with_masking).
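The apply_fn change in item 2 only alters how the mask is specified: the caller passes a sentinel value, and the wrapper derives a boolean mask by elementwise comparison before delegating to the original function. Below is a minimal pure-Python sketch of that conversion on flat lists. with_mask_constant and toy_apply_fn are hypothetical names (not neural_tangents APIs), real inputs are arrays rather than lists, and zeroing masked entries is just one illustrative choice for what the inner function does with the mask.

```python
from typing import Callable, List, Optional


# Hypothetical inner apply_fn using the "mask=" convention (True = masked).
# Zeroing masked entries is an illustrative choice, not the library's behavior.
def toy_apply_fn(params, inputs: List[float],
                 mask: Optional[List[bool]] = None) -> List[float]:
    if mask is None:
        return list(inputs)
    return [0.0 if m else x for x, m in zip(inputs, mask)]


def with_mask_constant(apply_fn: Callable) -> Callable:
    """Hypothetical wrapper turning a `mask=` API into a `mask_constant=` API."""
    def wrapped(params, inputs, mask_constant=None, **kwargs):
        # No sentinel given: nothing is masked.
        if mask_constant is None:
            mask = None
        else:
            # Mirrors inputs[inputs == mask_constant]: mark every entry
            # equal to the sentinel as masked.
            mask = [x == mask_constant for x in inputs]
        return apply_fn(params, inputs, mask=mask, **kwargs)
    return wrapped


masked_apply = with_mask_constant(toy_apply_fn)
print(masked_apply(None, [1.0, -1.0, 2.0, -1.0], mask_constant=-1.0))
# prints [1.0, 0.0, 2.0, 0.0]
```

The wrapper keeps the inner function's keyword interface untouched, which is why init_fn and the rest of the layer need no changes for this part.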
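Step 3.b can be pictured on a small kernel matrix whose entry k[i][j] is the covariance between entry i of the first input and entry j of the second: every row whose mask1 entry is set and every column whose mask2 entry is set is zeroed. The sketch below assumes flat boolean masks (True = masked) and nested lists instead of arrays; zero_masked_covariances is a hypothetical helper, not the library's implementation. Passing remask_kernel=False corresponds to skipping this pass entirely, which is presumably where the compute saving comes from.

```python
from typing import List, Optional

Matrix = List[List[float]]


def zero_masked_covariances(k: Matrix,
                            mask1: Optional[List[bool]],
                            mask2: Optional[List[bool]]) -> Matrix:
    """Return a copy of k with k[i][j] set to 0.0 if entry i or entry j is masked.

    Toy model: k[i][j] is the covariance between entry i of the first input
    and entry j of the second; masks are flat lists (True = masked) and
    None means nothing is masked.
    """
    n1 = len(k)
    n2 = len(k[0]) if k else 0
    m1 = mask1 if mask1 is not None else [False] * n1
    m2 = mask2 if mask2 is not None else [False] * n2
    # A covariance is kept only if neither of its two entries is masked.
    return [[0.0 if (m1[i] or m2[j]) else k[i][j] for j in range(n2)]
            for i in range(n1)]


k = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
print(zero_masked_covariances(k, [False, True], [True, False, False]))
# prints [[0.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
```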
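Items 3.a and 4 describe forward mask propagation: after each layer, the mask is transformed by that layer's mask_fn if it has one and passed through unchanged otherwise, and kernel_fn does the same for kernel.mask1 and kernel.mask2. The toy below models a layer as an (apply, mask_fn) pair. run_with_mask, relu and crop_first are hypothetical, and the real mask_fn signature may take more arguments than the single mask assumed here.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Mask = List[bool]
# Toy layer: (apply, mask_fn), where mask_fn may be None (hypothetical model).
ToyLayer = Tuple[Callable[[List[float]], List[float]],
                 Optional[Callable[[Mask], Mask]]]


def run_with_mask(layers: Sequence[ToyLayer], inputs: List[float],
                  mask: Mask) -> Tuple[List[float], Mask]:
    """Apply layers in order, pushing the mask forward alongside the values."""
    x, m = inputs, mask
    for apply, mask_fn in layers:
        x = apply(x)
        # Use the layer's mask_fn if it has one; otherwise the mask is
        # assumed unchanged.
        m = mask_fn(m) if mask_fn is not None else m
    return x, m


# Elementwise layer: output position i depends only on input position i,
# so the mask can pass through unchanged.
relu: ToyLayer = (lambda xs: [max(v, 0.0) for v in xs], None)
# Cropping layer dropping the first position: the mask must be cropped too,
# or it would point at the wrong positions.
crop_first: ToyLayer = (lambda xs: xs[1:], lambda m: m[1:])

out, out_mask = run_with_mask([relu, crop_first], [-1.0, 2.0, -3.0, 4.0],
                              [False, False, True, False])
print(out, out_mask)
# prints [2.0, 0.0, 4.0] [False, True, False]
```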