neural_tangents.stax.Aggregate

neural_tangents.stax.Aggregate(aggregate_axis=None, batch_axis=0, channel_axis=-1, to_dense=<function <lambda>>, implementation='DENSE')

Layer constructor for an aggregation operator (graph neural network).

See e.g. https://arxiv.org/abs/1905.13192.

Specifically, each (N+2)-D input of shape (batch, X_1, …, X_N, channels) (with the positions of the batch and channel axes set by batch_axis and channel_axis) is accompanied by an array pattern specifying the directed edges (arcs, arrows) of the graph. The format of pattern depends on implementation:

implementation = “DENSE”:

Recommended for dense graphs, where the number of edges E is proportional to the number of vertices V to the power of 1.5 or more. In this case, pattern is a [weighted] adjacency (2K+1)-D tensor of shape (batch, X_i1, …, X_iK, X_i1, …, X_iK) (i.e. a leading batch dimension, repeated spatial dimensions, and no channel dimension), and the output tensor is lax.dot_general(inputs, pattern, ((aggregate_axes, range(1, K + 1)), ((batch_axis,), (0,)))) with the batch_axis and channel_axis preserved, where K = len(aggregate_axes).
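The dense formula can be checked directly in plain JAX, without neural_tangents. The sketch below assumes NCH inputs with a single aggregated axis (K = 1, batch_axis=0, channel_axis=1). Note that lax.dot_general expects its dimension numbers nested as ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)), and that it orders its result as (batch dims, lhs free dims, rhs free dims).

```python
import jax.numpy as jnp
from jax import lax, random

# Toy NCH inputs: batch N=2, channels C=3, one aggregated axis H=4 (K = 1).
x = random.normal(random.PRNGKey(0), (2, 3, 4))
A = random.bernoulli(random.PRNGKey(1), 0.5, (2, 4, 4)).astype(x.dtype)

aggregate_axes = (2,)  # the H axis of x
K = len(aggregate_axes)
batch_axis = 0

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dims = ((aggregate_axes, tuple(range(1, K + 1))), ((batch_axis,), (0,)))
out = lax.dot_general(x, A, dims)

# The result is laid out as (batch, C, H), which here already matches x;
# with other axis layouts the axes would have to be moved back.
print(out.shape)  # (2, 3, 4)
print(jnp.allclose(out, x @ A, atol=1e-5))
```

For this layout the contraction is exactly the batched matrix product x @ A used in the Example section.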

Having pattern[n, i1, …, iK, j1, …, jK] == w represents a directed edge (arc) from tail pixel / token (i1, …, iK) to head (j1, …, jK) with weight w in an individual input sample n. The apply_fn of this layer replaces the value at each vertex with the (weighted) sum of the values at all of its direct predecessors.
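To make the edge direction concrete, here is a tiny hand-built graph in plain JAX. It assumes the dense NCH layout with K = 1, where aggregation reduces to x @ pattern (as in the Example section); the values are chosen only for illustration.

```python
import jax.numpy as jnp

# One sample (N = 1), one channel (C = 1), V = 3 vertices (K = 1).
x = jnp.array([[[1.0, 10.0, 100.0]]])

# pattern[n, tail, head] == w: arcs 0 -> 2 (w = 0.5) and 1 -> 2 (w = 1.0).
pattern = jnp.zeros((1, 3, 3))
pattern = pattern.at[0, 0, 2].set(0.5).at[0, 1, 2].set(1.0)

# Each head receives the weighted sum of its direct predecessors.
out = x @ pattern
# Vertex 2: 0.5 * 1.0 + 1.0 * 10.0 = 10.5. Vertices 0 and 1 have no
# predecessors (and no self-loops), so they become 0.
```

A vertex keeps its own value only if the pattern contains a self-loop for it.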

Note that individual inputs can have more than K dimensions (e.g. channels, other coordinates), in which case slices along these coordinates are processed in the same way independently.

This implementation uses matrix multiplication, and for a graph with V vertices and E edges, apply_fn costs O(V^2) memory and time, while kernel_fn costs O(V^2) memory and O(V^3) time.

The adjacency tensor pattern can be specified in a sparse format. If you provide a to_dense function (defaults to identity), pattern is decoded into the dense representation described above (pattern_dense = to_dense(pattern)) each time apply_fn or kernel_fn is called. This avoids storing the whole graph in dense format in advance and instead converts it on the fly for each individual batch x / (x1, x2). However, this does not improve the runtime or memory of the Aggregate layer (in fact, it makes it slightly slower due to the extra to_dense call).
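As an illustration, a to_dense function for K = 1 might decode an edge list into a dense adjacency tensor. The helper edges_to_dense below is hypothetical (it is not part of neural_tangents); it assumes edges of shape (batch, n_edges, 2) holding (tail, head) pairs, with a negative head marking padding, and binds the vertex count with functools.partial so that it takes a single argument.

```python
import functools

import jax.numpy as jnp


def edges_to_dense(edges, n_vertices):
    """Hypothetical `to_dense` for K = 1: edge list -> dense adjacency.

    edges: integer array of shape (batch, n_edges, 2) of (tail, head) pairs;
      rows with a negative head are treated as padding.
    Returns a float array of shape (batch, n_vertices, n_vertices).
    """
    batch, n_edges, _ = edges.shape
    tails, heads = edges[..., 0], edges[..., 1]
    keep = (heads >= 0).astype(jnp.float32)  # weight 0 for padding edges
    b = jnp.broadcast_to(jnp.arange(batch)[:, None], (batch, n_edges))
    dense = jnp.zeros((batch, n_vertices, n_vertices))
    # `.at[...].add` accumulates, so repeated edges sum their weights;
    # padding rows add 0 at a clamped (harmless) index.
    return dense.at[b, tails, jnp.maximum(heads, 0)].add(keep)


# Hypothetical usage with V = 4 vertices.
to_dense = functools.partial(edges_to_dense, n_vertices=4)
edges = jnp.array([[[0, 1], [0, 1], [2, 3], [-1, -1]]])
A = to_dense(edges)
# A[0, 0, 1] == 2.0 (repeated edge), A[0, 2, 3] == 1.0, all other entries 0.
```

A function like this could then be passed as stax.Aggregate(..., to_dense=to_dense) with the default implementation="DENSE".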

implementation = “SPARSE”:

Recommended for sparse graphs, where E ~ O(V) or less. In this case, pattern must be an integer array of shape (batch, n_edges, K, 2), specifying n_edges directed edges (arcs) of weight w = 1 for each of the batch input samples (if K == 1, pattern can also have the shape (batch, n_edges, 2)). The trailing dimension of size 2 corresponds to tails (sources, senders) and heads (targets, receivers). Edges can be repeated, in which case an edge's weight equals its number of repetitions. If any of the K coordinates of an edge's head is negative (e.g. -1), that edge is discarded. This can be used for padding when different input samples have different n_edges. Note that this means you can't use negative indexing to specify vertices.
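Going the other way, a binary dense adjacency can be turned into this sparse format with -1 padding. The helper dense_to_edges below is a hypothetical sketch for K = 1 using the (batch, n_edges, 2) shape; it relies on jnp.nonzero with a static size and fill_value, and assumes n_edges is at least the largest edge count of any sample (extra edges would be silently truncated).

```python
import jax.numpy as jnp


def dense_to_edges(A, n_edges):
    """Hypothetical helper for K = 1: binary adjacency (batch, V, V) ->
    edge list (batch, n_edges, 2) of (tail, head) pairs, padded with -1."""
    per_sample = []
    for a in A:
        # Row index = tail, column index = head; unused slots are set to -1.
        tails, heads = jnp.nonzero(a, size=n_edges, fill_value=-1)
        per_sample.append(jnp.stack([tails, heads], axis=-1))
    return jnp.stack(per_sample)


A = jnp.array([
    [[0, 1, 0], [0, 0, 1], [0, 0, 0]],  # sample 0: arcs 0 -> 1 and 1 -> 2
    [[0, 0, 0], [1, 0, 0], [0, 0, 0]],  # sample 1: arc 1 -> 0 only
], dtype=bool)
edges = dense_to_edges(A, n_edges=2)
# Sample 1 has a single edge, so its second row is the padding (-1, -1).
```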

This implementation uses jax.ops.segment_sum instead of matrix multiplication. This makes apply_fn cost O(V + E) memory and O(V + E) time, and kernel_fn cost O(V^2) memory and O(V^2 + E^2 + V * E) time. This is beneficial for sparse graphs, i.e. E << V^2, but detrimental for dense graphs (when E ~ V^2).
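The segment-sum idea can be sketched in plain JAX for NCH inputs with K = 1. This aggregate_sparse function is a hypothetical illustration, not the library's implementation: each edge gathers its tail's value, and jax.ops.segment_sum adds these messages at their heads. To be explicit about padding, negative heads are routed to an extra dummy segment that is sliced off.

```python
import jax
import jax.numpy as jnp
from jax import random


def aggregate_sparse(x, edges):
    """Hypothetical sketch of segment-sum aggregation for NCH inputs (K = 1).

    x: float array of shape (batch, C, V).
    edges: integer array of shape (batch, n_edges, 2) of (tail, head) pairs;
      a negative head marks a padding edge.
    Returns an array of shape (batch, C, V).
    """
    n_vertices = x.shape[-1]

    def one_sample(x_n, e_n):
        tails, heads = e_n[:, 0], e_n[:, 1]
        # Route padding edges to a dummy segment that is dropped below.
        heads = jnp.where(heads < 0, n_vertices, heads)
        # One message per edge: the tail's values, shape (n_edges, C).
        messages = x_n[:, tails].T
        # Sum messages arriving at each head: O(V + E) work per sample.
        summed = jax.ops.segment_sum(messages, heads,
                                     num_segments=n_vertices + 1)
        return summed[:n_vertices].T  # (C, V)

    return jax.vmap(one_sample)(x, edges)


# Usage: compare against the dense formulation `x @ A` on a toy graph.
x = random.normal(random.PRNGKey(0), (2, 3, 4))
edges = jnp.array([
    [[0, 1], [0, 1], [2, 3], [3, 0]],        # sample 0: 0 -> 1 repeated
    [[1, 2], [-1, -1], [-1, -1], [-1, -1]],  # sample 1: one edge + padding
])
A = jnp.zeros((2, 4, 4))
A = A.at[0, 0, 1].set(2.0).at[0, 2, 3].set(1.0).at[0, 3, 0].set(1.0)
A = A.at[1, 1, 2].set(1.0)
out = aggregate_sparse(x, edges)
```

The repeated edge in sample 0 contributes twice, matching a dense weight of 2, and the padded rows of sample 1 contribute nothing.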

See also

AggregateTest in tests/stax_test.py for examples and conversion between sparse and dense patterns.

Example

>>> from jax import random
>>> from neural_tangents import stax
>>>
>>> # 1D inputs
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>>
>>> # 1) NHH dense binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>> # `A[n, h1, h2] == True`
>>> # means an edge from token `h1` to token `h2` in sample `n`.
>>>
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
...                                               batch_axis=0,
...                                               channel_axis=1)
>>>
>>> out = apply_fn((), x, pattern=A)
>>> # output is the same as `x @ A`, of shape (5, 3, 32)
>>>
>>> # Sparse NHH binary pattern with 10 edges
>>> n_edges = 10
>>> A_sparse = random.randint(random.PRNGKey(3),
...                           shape=(x.shape[0], n_edges, 1, 2),
...                           minval=0,
...                           maxval=x.shape[2])
>>>
>>> # Set `implementation="SPARSE"` to use the segment sum
>>> # implementation.
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
...                                               batch_axis=0,
...                                               channel_axis=1,
...                                               implementation="SPARSE")
>>>
>>> out = apply_fn((), x, pattern=A_sparse)
>>> # output is of shape (5, 3, 32), computed via `jax.ops.segment_sum`.
>>>
>>> # 2D inputs
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32, 16))  # NCHW
>>>
>>> # 2) NHWHW dense binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 16, 32, 16))
>>> # `A[n, h1, w1, h2, w2] == True`
>>> # means an edge from pixel `(h1, w1)` to pixel `(h2, w2)` in image `n`.
>>>
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(2, 3),
...                                               batch_axis=0,
...                                               channel_axis=1)
>>>
>>> out = apply_fn((), x, pattern=A)
>>> # output is of shape (5, 3, 32, 16), the same as
>>> # `(x.reshape((5, 3, 32 * 16)) @ A.reshape((5, 32 * 16, 32 * 16))
>>> #  ).reshape(x.shape)`
>>>
>>>
>>> # 3) NWW binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 16, 16))
>>> # `A[n, w1, w2] == True`
>>> # means an edge from column `w1` to column `w2` in image `n`.
>>>
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(3,),
...                                               batch_axis=0,
...                                               channel_axis=1)
>>>
>>> out = apply_fn((), x, pattern=A)
>>> # output is of shape (5, 3, 32, 16), the same as
>>> # `(x.reshape((5, 3 * 32, 16)) @ A).reshape(x.shape)`
>>>
>>>
>>> # 4) Infinite width example
>>> x1 = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>> x2 = random.normal(random.PRNGKey(2), (2, 3, 32))  # NCH
>>>
>>> # NHH binary adjacency matrices
>>> A1 = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>> A2 = random.bernoulli(random.PRNGKey(2), 0.5, (2, 32, 32))
>>>
>>> _, _, kernel_fn_id = stax.Identity()
>>>
>>> _, _, kernel_fn_agg = stax.Aggregate(aggregate_axis=2,
...                                      batch_axis=0,
...                                      channel_axis=1)
>>>
>>> nngp = kernel_fn_id(x1, x2, get='nngp', channel_axis=1)
>>> # initial NNGP of shape (5, 2, 32, 32)
>>> K_agg = kernel_fn_agg(x1, x2, get='nngp', pattern=(A1, A2))
>>> # output NNGP of the same shape (5, 2, 32, 32):
>>> # `K_agg[n1, n2] == A1[n1].T @ nngp[n1, n2] @ A2[n2]`
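The kernel identity in case 4 follows from linearity: if each output is A^T applied to the inputs along the aggregated axis, their covariance is A1^T K A2. The plain JAX sketch below checks this at finite width. Both helpers are hypothetical: aggregate_nngp applies the formula to all sample pairs with jnp.einsum, and empirical_nngp assumes the NNGP is estimated as a channel average of x1 x2^T.

```python
import jax.numpy as jnp
from jax import random


def aggregate_nngp(nngp, A1, A2):
    """K_agg[n1, n2] = A1[n1].T @ nngp[n1, n2] @ A2[n2] for all pairs (K = 1).

    nngp: (N1, N2, V, V); A1: (N1, V, V); A2: (N2, V, V).
    """
    return jnp.einsum('aji,abjk,bkl->abil', A1, nngp, A2)


def empirical_nngp(y1, y2):
    """Hypothetical finite-width NNGP estimate for NCH inputs:
    average over channels of y1[n1, :, i] * y2[n2, :, j]."""
    return jnp.einsum('acj,bck->abjk', y1, y2) / y1.shape[1]


C, V = 16, 8
x1 = random.normal(random.PRNGKey(1), (5, C, V))
x2 = random.normal(random.PRNGKey(2), (2, C, V))
A1 = random.bernoulli(random.PRNGKey(3), 0.5, (5, V, V)).astype(x1.dtype)
A2 = random.bernoulli(random.PRNGKey(4), 0.5, (2, V, V)).astype(x2.dtype)

K_in = empirical_nngp(x1, x2)               # (5, 2, V, V)
K_out = empirical_nngp(x1 @ A1, x2 @ A2)    # NNGP of the aggregated outputs
K_formula = aggregate_nngp(K_in, A1, A2)    # same result via the identity
```

Because aggregation is linear, the identity holds exactly (up to float rounding) for any channel count, which is why it carries over to the infinite-width kernel.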
Parameters
  • aggregate_axis (Union[int, Sequence[int], None]) – axes (non-batch and non-channel) to aggregate predecessor vertices over.

  • batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.

  • channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

  • to_dense (Optional[Callable[[ndarray], ndarray]]) – Ignored unless implementation == “DENSE”. A function to convert potentially sparse pattern matrices into dense (2K+1)-D tensors of shape (batch, X_i1, …, X_iK, X_i1, …, X_iK), with a leading batch dimension and no channel dimension, where K = len(aggregate_axes). It is called on the input pattern (or a pair (pattern1, pattern2)) every time apply_fn or kernel_fn is called. Defaults to identity, meaning that pattern is expected in the dense format.

  • implementation (str) – “DENSE” or “SPARSE”, specifying which implementation to use. “DENSE” uses matrix multiplications and is recommended for dense graphs (E ~> O(V^1.5)), while “SPARSE” uses jax.ops.segment_sum and is recommended for sparse graphs (E ~< O(V)). Note that different implementations require different pattern array formats; see the description above for details.

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn]

Returns

(init_fn, apply_fn, kernel_fn).