FanInSum()
Fan-in sum.
This layer takes a number of inputs (e.g. produced by FanOut) and sums the inputs to produce a single output. Based on jax.example_libraries.stax.FanInSum.
Return type: tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]
Returns: (init_fn, apply_fn, kernel_fn).