neural_tangents.stax.parallel

neural_tangents.stax.parallel(*layers)

Combinator for composing layers in parallel.

The layer resulting from this combinator is often used with the FanOut, FanInSum, and FanInConcat layers. Based on jax.example_libraries.stax.parallel.

Parameters:

*layers (tuple[InitFn, ApplyFn, AnalyticKernelFn]) – a sequence of layers, each an (init_fn, apply_fn, kernel_fn) triple.

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn]

Returns:

A new layer, meaning an (init_fn, apply_fn, kernel_fn) triple, representing the parallel composition of the given sequence of layers. In particular, the returned layer takes a sequence of inputs, one per layer, and returns a sequence of outputs with the same length as the argument layers.
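
The input/output contract described above can be illustrated with a minimal, dependency-free sketch. This is not the library's implementation: parallel_sketch and the toy scale layer are hypothetical names introduced here, layers are simplified to (init_fn, apply_fn) pairs with kernel_fn omitted, and the same rng is passed to every branch, whereas jax.example_libraries.stax.parallel splits it per layer.

```python
def scale(factor):
    """Hypothetical toy layer that multiplies each input element by `factor`."""
    def init_fn(rng, input_shape):
        # Returns (output_shape, params); this layer has no parameters.
        return input_shape, ()

    def apply_fn(params, x, **kwargs):
        return [factor * v for v in x]

    return init_fn, apply_fn


def parallel_sketch(*layers):
    """Illustrative parallel combinator; kernel_fn handling is omitted."""
    init_fns, apply_fns = zip(*layers)

    def init_fn(rng, input_shapes):
        # One input shape per layer. Assumption: rng is reused unsplit for simplicity.
        results = [init(rng, shape) for init, shape in zip(init_fns, input_shapes)]
        output_shapes, params = zip(*results)
        return list(output_shapes), list(params)

    def apply_fn(params, inputs, **kwargs):
        # The i-th layer is applied to the i-th input with the i-th params.
        return [f(p, x, **kwargs) for f, p, x in zip(apply_fns, params, inputs)]

    return init_fn, apply_fn


init_fn, apply_fn = parallel_sketch(scale(2.0), scale(-1.0))
out_shapes, params = init_fn(None, [(3,), (3,)])
outputs = apply_fn(params, [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
print(outputs)  # one output per layer, in the same order as the layers
```

Two layers go in, so the combined layer expects two inputs and produces two outputs. In practice a FanOut layer typically supplies the inputs and a FanInSum or FanInConcat layer merges the outputs.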