neural_tangents.stax.FanOut

neural_tangents.stax.FanOut(num)

Layer construction function for a fan-out layer.

This layer takes a single input and returns num copies of it, which can then be fed into different branches of a neural network (for example, in residual connections).

Parameters

num (int) – The number of outgoing edges (copies of the input) to fan out into.

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn]

Returns

(init_fn, apply_fn, kernel_fn).