neural_tangents.stax.Dense
- neural_tangents.stax.Dense(out_dim, W_std=1.0, b_std=None, parameterization='ntk', batch_axis=0, channel_axis=-1)
Layer constructor function for a dense (fully-connected) layer.
Based on jax.example_libraries.stax.Dense.
- Parameters
out_dim (int) – The output feature / channel dimension. This is ignored by the kernel_fn in “ntk” parameterization.
W_std (float) – Specifies the standard deviation of the weights.
b_std (Optional[float]) – Specifies the standard deviation of the biases. None means no bias.
parameterization (str) – Either “ntk” or “standard”.
Under “ntk” parameterization (https://arxiv.org/abs/1806.07572, page 3), weights and biases are initialized as \(W_{ij} \sim \mathcal{N}(0,1)\), \(b_i \sim \mathcal{N}(0,1)\), and the finite width layer equation is \(z_i = \frac{\sigma_W}{\sqrt{N}} \sum_j W_{ij} x_j + \sigma_b b_i\).
Under “standard” parameterization (https://arxiv.org/abs/2001.07301), weights and biases are initialized as \(W_{ij} \sim \mathcal{N}(0, \sigma_W^2/N)\), \(b_i \sim \mathcal{N}(0,\sigma_b^2)\), and the finite width layer equation is \(z_i = \sum_j W_{ij} x_j + b_i\).
In both cases \(\sigma_W\) is W_std, \(\sigma_b\) is b_std, and \(N\) is the number of input features.
batch_axis (int) – Specifies which axis contains different elements of the batch. Defaults to 0, the leading axis.
channel_axis (int) – Specifies which axis contains the features / channels. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
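The two parameterizations described for the parameterization argument can be compared with a minimal NumPy sketch. This is not the library's implementation: the helper names dense_ntk and dense_standard are hypothetical, the weights are stored as an (n_in, n_out) array so the sum over j becomes x @ W, and a float b_std is assumed (the b_std=None case is not covered).

```python
import numpy as np


def dense_ntk(x, W, b, W_std, b_std):
    """'ntk' layer equation: W and b are unit-variance, scaling is applied in the forward pass."""
    n_in = x.shape[-1]  # N in the formulas: number of input features
    return (W_std / np.sqrt(n_in)) * (x @ W) + b_std * b


def dense_standard(x, W, b):
    """'standard' layer equation: scaling is already baked into W and b at initialization."""
    return x @ W + b


rng = np.random.default_rng(0)
n_in, n_out, W_std, b_std = 256, 4, 1.5, 0.1
x = rng.normal(size=(8, n_in))

# 'ntk' initialization: W_ij ~ N(0, 1), b_i ~ N(0, 1).
W_unit = rng.normal(size=(n_in, n_out))
b_unit = rng.normal(size=(n_out,))

# Rescaling the same draws yields samples from the 'standard' initialization:
# W_ij ~ N(0, W_std^2 / N), b_i ~ N(0, b_std^2).
W_scaled = W_unit * W_std / np.sqrt(n_in)
b_scaled = b_unit * b_std

z_ntk = dense_ntk(x, W_unit, b_unit, W_std, b_std)
z_std = dense_standard(x, W_scaled, b_scaled)

# At initialization the two forward passes agree on matched draws.
same_outputs = np.allclose(z_ntk, z_std)
```

With matched draws the outputs coincide at initialization. The parameterizations differ in how the trainable parameters are scaled, so the gradients with respect to W and b, and hence the training dynamics, are not the same.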
- Return type
Tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]
- Returns
(init_fn, apply_fn, kernel_fn).