neural_tangents.stax.supports_masking
- neural_tangents.stax.supports_masking(remask_kernel)
Returns a decorator that turns layers into layers supporting masking.
Specifically:
1. init_fn is left unchanged.
2. apply_fn is turned from a function that accepts a mask=None keyword argument (which indicates inputs[mask] must be masked), into a function that accepts a mask_constant=None keyword argument (which indicates inputs[inputs == mask_constant] must be masked).
3. kernel_fn is modified to
3.a. propagate kernel.mask1 and kernel.mask2 through intermediary layers, and,
3.b. if remask_kernel == True, zero out covariances between entries of which at least one is masked.
4. If the decorated layer has a mask_fn, it is used to propagate masks forward through the layer, in both apply_fn and kernel_fn. If not, it is assumed the mask remains unchanged.
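The masking semantics of items 2 and 3.b can be illustrated in plain Python, independent of the library. This is a minimal sketch on small lists, not the code neural_tangents uses; the helper names mask_from_constant and remask are hypothetical, and the outer-product "kernel" is only a toy covariance chosen so the result is easy to read.

```python
def mask_from_constant(inputs, mask_constant=None):
    """Return a boolean mask: True where an entry equals mask_constant."""
    if mask_constant is None:
        return [False] * len(inputs)
    return [x == mask_constant for x in inputs]


def remask(kernel, mask1, mask2):
    """Zero out kernel[i][j] whenever entry i or entry j is masked."""
    return [
        [0.0 if (m1 or m2) else k for k, m2 in zip(row, mask2)]
        for row, m1 in zip(kernel, mask1)
    ]


x1 = [1.0, -1.0, 2.0]  # -1.0 plays the role of mask_constant
x2 = [3.0, -1.0]

mask1 = mask_from_constant(x1, mask_constant=-1.0)
mask2 = mask_from_constant(x2, mask_constant=-1.0)

# Toy covariance: outer product of the two inputs.
kernel = [[a * b for b in x2] for a in x1]
remasked = remask(kernel, mask1, mask2)
```

Every row or column of the toy kernel that touches a masked entry ends up zero, which is the effect remask_kernel == True is described as having.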
Must be applied before the layer decorator.
- Parameters
remask_kernel (bool) – True to zero out kernel covariance entries between masked inputs after applying kernel_fn. Some layers don’t need this, and setting remask_kernel=False can save compute.
- Returns
A decorator that turns functions returning (init_fn, apply_fn, kernel_fn[, mask_fn]) into functions returning (init_fn, apply_fn_with_masking, kernel_fn_with_masking).
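The shape of the returned decorator can be sketched with a toy analogue. Everything below is an assumption made for illustration: supports_masking_sketch and the Scale layer are hypothetical, inputs are flat Python lists, and the toy kernel_fn receives the masks as explicit arguments instead of reading them from a kernel object. It is not the library's implementation.

```python
import functools


def supports_masking_sketch(remask_kernel):
    """Toy stand-in for the decorator; not the neural_tangents code."""
    def decorator(layer):
        @functools.wraps(layer)
        def layer_with_masking(*args, **kwargs):
            # The wrapped constructor may or may not return a mask_fn.
            init_fn, apply_fn, kernel_fn, *rest = layer(*args, **kwargs)
            mask_fn = rest[0] if rest else (lambda mask: mask)

            def apply_fn_with_masking(params, inputs, mask_constant=None):
                # Translate mask_constant into the boolean mask apply_fn expects.
                mask = None
                if mask_constant is not None:
                    mask = [x == mask_constant for x in inputs]
                return apply_fn(params, inputs, mask=mask)

            def kernel_fn_with_masking(kernel, mask1, mask2):
                # Toy signature: masks are passed explicitly here.
                out = kernel_fn(kernel)
                mask1, mask2 = mask_fn(mask1), mask_fn(mask2)
                if remask_kernel:
                    out = [[0.0 if (m1 or m2) else k for k, m2 in zip(row, mask2)]
                           for row, m1 in zip(out, mask1)]
                return out, mask1, mask2

            return init_fn, apply_fn_with_masking, kernel_fn_with_masking
        return layer_with_masking
    return decorator


@supports_masking_sketch(remask_kernel=True)
def Scale(factor):
    """Hypothetical layer that multiplies unmasked inputs by `factor`."""
    def init_fn(rng, input_shape):
        return input_shape, ()

    def apply_fn(params, inputs, mask=None):
        mask = mask or [False] * len(inputs)
        return [0.0 if m else factor * x for x, m in zip(inputs, mask)]

    def kernel_fn(kernel):
        return [[factor ** 2 * k for k in row] for row in kernel]

    return init_fn, apply_fn, kernel_fn  # no mask_fn: mask passes through


init_fn, apply_fn, kernel_fn = Scale(2.0)
outputs = apply_fn((), [1.0, -1.0, 3.0], mask_constant=-1.0)
```

The decorated constructor always returns a three-element tuple, whether or not the original returned a mask_fn, and the wrapped apply_fn is called with mask_constant where the original took mask.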