gsnn.optim.FunctionEdgeInferer

Post-hoc inference of latent function -> function edges in a trained GSNN.

Overview

FunctionEdgeInferer produces a dense (N, N) evidence matrix W over function nodes from a trained gsnn.models.GSNN.GSNN model and an evaluation batch. The score W[i, j] is a rank (Spearman) or Pearson correlation between the layer-l per-node activations at node i and the layer-(l+1) per-node gradients at node j, pooled across layers and batch elements. High values are interpreted as evidence for a directed edge i -> j. This is a Hebbian-style heuristic: it rewards upstream activity at i that co-varies with the downstream credit signal at j.
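Written out, the score is an ordinary correlation between two columns of observations. Assuming the observation indexing of the pipeline below (each observation m is a pair (l, b) of a layer l = 1, ..., L - 1 and a batch element b), a_{mi} denotes the reduced |activation| of node i at layer l for element b, g_{mj} the reduced |gradient| of node j at layer l + 1 for the same element, and bars denote means over the M observations:

\[
W_{ij} = \frac{\sum_{m=1}^{M} (a_{mi} - \bar a_i)(g_{mj} - \bar g_j)}
{\sqrt{\sum_{m=1}^{M} (a_{mi} - \bar a_i)^2}\,\sqrt{\sum_{m=1}^{M} (g_{mj} - \bar g_j)^2}},
\qquad M = (L - 1)\,B .
\]

For method='spearman' each column of a and g is replaced by its ranks before the same formula is applied.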

The pipeline is:

  1. Toggle activation capture on every ResBlock and run a forward + backward pass on the supplied batch.

  2. Read each block’s per-channel pre-/post-norm activations and gradients.

  3. Reduce |act| / |grad| within each node group (sum / mean / max) to obtain per-node activity at every layer.

  4. Stack (layer-l act, layer-(l+1) grad) pairs across layers + batch as M = (L - 1) * B observations.

  5. Compute the cross-correlation matrix between the N activation columns and the N gradient columns.
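Steps 3 to 5 can be sketched in NumPy as follows. This is an illustrative sketch, not the module's code: it assumes the per-layer activations and gradients have already been captured as (B, C) arrays, that every block shares one channel-to-node map (the hypothetical `node_of_channel` array, standing in for the blocks' channel_groups), and it omits the `norm` option and device handling.

```python
import numpy as np


def node_activity(x, node_of_channel, n_nodes, agg="sum"):
    """Reduce |x| of shape (B, C) to per-node activity of shape (B, N).

    node_of_channel: hypothetical (C,) int array giving each channel's node id;
    every node is assumed to own at least one channel.
    """
    mag = np.abs(x)
    batch = mag.shape[0]
    if agg == "max":
        out = np.full((batch, n_nodes), -np.inf)
        np.maximum.at(out, (slice(None), node_of_channel), mag)  # per-node max
        return out
    out = np.zeros((batch, n_nodes))
    np.add.at(out, (slice(None), node_of_channel), mag)  # per-node sum
    if agg == "mean":
        out /= np.bincount(node_of_channel, minlength=n_nodes)  # divide by group size
    return out


def stack_pairs(acts, grads, node_of_channel, n_nodes, agg="sum"):
    """Pair layer-l activity with layer-(l+1) gradient activity.

    acts, grads: lists of L arrays of shape (B, C), one per ResBlock.
    Returns A, G of shape ((L - 1) * B, N); row m of A pairs with row m of G.
    """
    A = np.concatenate([node_activity(a, node_of_channel, n_nodes, agg) for a in acts[:-1]])
    G = np.concatenate([node_activity(g, node_of_channel, n_nodes, agg) for g in grads[1:]])
    return A, G


def cross_corr(A, G):
    """Pearson correlation of every column of A with every column of G -> (N, N)."""
    Az = (A - A.mean(axis=0)) / A.std(axis=0)
    Gz = (G - G.mean(axis=0)) / G.std(axis=0)
    return Az.T @ Gz / A.shape[0]
```

Dropping the last layer's activations and the first layer's gradients is what yields the (L - 1) * B observation count.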

A short mrr() utility is provided to score the resulting matrix against ground-truth edges by mean reciprocal rank (MRR).

Tractability

Let

N = number of function nodes
L = number of ResBlocks
B = batch size used for attribution
M = (L - 1) * B = number of (activation, gradient) observations

Dominant costs:

quantity                          memory          compute
activations / gradients buffers   O(L * B * N)    forward + backward pass
ranked observations (Spearman)    O(M * N)        N * O(M log M)
correlation matrix W              O(N^2)          O(M * N^2)

Concrete back-of-envelope figures (float32):

  • 1k nodes, 10k obs -> ~40 MB working set, sub-second on CPU.

  • 5k nodes, 50k obs -> ~1 GB working set, on the order of seconds on GPU.

  • 10k nodes, 10k obs -> ~400 MB just for W, seconds on GPU.

  • 10k nodes, 100k obs -> ~4 GB per (acts, grads) tensor; feasible on a 24 GB GPU with chunked matmul.

  • 10k nodes, 300k obs -> ~12 GB per tensor; not feasible with the prior pandas/Spearman implementation (the intermediate (2N, 2N) float64 correlation alone is ~3.2 GB, and the pandas Spearman ranker is effectively single-threaded). The GPU rank-Pearson path in this module runs in well under a minute on a single 24 GB GPU, using row_chunk to stream the final matmul.
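The memory figures above are plain arithmetic on dense buffers (4 bytes per float32 element, 8 per float64, 1 GB = 1e9 bytes). The helper below is a hypothetical `working_set` function written for this page, not part of the module; it reproduces the per-tensor sizes and applies the row_chunk rule of thumb quoted under fit(). It says nothing about runtime.

```python
GB = 1e9  # decimal gigabytes, matching the figures above


def dense_gb(rows, cols, bytes_per_elem=4):
    """Size in GB of a dense (rows, cols) array; 4 bytes/elem = float32."""
    return rows * cols * bytes_per_elem / GB


def working_set(n_nodes, n_obs, row_chunk=None):
    """Rough per-buffer sizes (GB) for the attribution pipeline."""
    sizes = {
        "acts_or_grads": dense_gb(n_obs, n_nodes),  # one (M, N) float32 tensor
        "W": dense_gb(n_nodes, n_nodes),  # (N, N) float32 result
        "pandas_corr_2N": dense_gb(2 * n_nodes, 2 * n_nodes, 8),  # old (2N, 2N) float64
    }
    if row_chunk is not None:
        # rough guide from fit(): row_chunk * (N + M) * 4 bytes
        sizes["chunk"] = row_chunk * (n_nodes + n_obs) * 4 / GB
    return sizes
```

For example, working_set(10_000, 300_000, row_chunk=1024)["chunk"] comes to about 1.27 GB, far below the 12 GB of each full (acts, grads) tensor.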

For configurations beyond ~5k nodes, the previous pd.DataFrame.corr path was a hard wall, both in the compute time of the full correlation and in the size of the float64 result. The implementation here:

  • keeps everything in float32 on device;

  • computes Spearman as rank-Pearson via argsort;

  • only computes the off-diagonal (a x g) block, never the full (2N x 2N) matrix;

  • exposes row_chunk so the output matmul can be streamed when N^2 no longer fits;

  • preserves all model state (activation flags, .grad buffers, training/eval mode) on exit.
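The rank-Pearson and off-diagonal-block ideas can be sketched in NumPy as below. This is a CPU illustration under stated assumptions, not the module's torch code: `column_ranks`, `zscore` and `corr_block` are hypothetical names, ranks come from a double argsort (so ties follow argsort order, as noted in the caveats), and the chunking loop only mirrors the spirit of row_chunk.

```python
import numpy as np


def column_ranks(X):
    """Column-wise 0-based ranks via a double argsort; ties follow argsort order."""
    return np.argsort(np.argsort(X, axis=0, kind="stable"), axis=0).astype(np.float32)


def zscore(X):
    """Centre and scale each column (population std); constant columns give NaN."""
    X = X - X.mean(axis=0)
    return X / X.std(axis=0)


def corr_block(A, G, method="spearman", row_chunk=None):
    """(N, N) correlation between columns of A and columns of G.

    Only the off-diagonal (act x grad) block is formed, never the full (2N, 2N)
    matrix; row_chunk bounds how many rows of W are produced per matmul.
    """
    A = np.asarray(A, dtype=np.float32)
    G = np.asarray(G, dtype=np.float32)
    if method == "spearman":
        A, G = column_ranks(A), column_ranks(G)  # Spearman = Pearson on ranks
    Az, Gz = zscore(A), zscore(G)
    m, n = Az.shape
    step = row_chunk or n
    W = np.empty((n, Gz.shape[1]), dtype=np.float32)
    for start in range(0, n, step):
        stop = min(start + step, n)
        W[start:stop] = Az[:, start:stop].T @ Gz / m  # one row chunk of W
    return W
```

Because both inputs are z-scored once, each chunk is a single matmul, and the result matches np.corrcoef's off-diagonal block without materialising the (2N, 2N) matrix.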

Caveats:

  • Spearman ties are broken by argsort order rather than by the textbook average-rank convention. For continuous-valued neural-network activations ties are vanishingly rare so the bias is negligible; it can become visible on heavily quantised / ReLU-zeroed activations.

  • The (activation, gradient) score is a heuristic, not a causal estimator; shared upstream drivers can inflate apparent edge strength.
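The tie caveat is easy to see on a small ReLU-zeroed column. The sketch below (hypothetical helpers `argsort_ranks` and `average_ranks`, written for illustration) ranks a column with three exact zeros both ways. Argsort breaks the zero ties by row position, so the ranks inherit the row order and correlate more strongly with any signal that also trends along the rows: against a strictly increasing gradient column the argsort version gives 0.9, the average-rank version about 0.78.

```python
import numpy as np


def argsort_ranks(x):
    """0-based ranks with ties broken by position (stable argsort order)."""
    return np.argsort(np.argsort(x, kind="stable")).astype(float)


def average_ranks(x):
    """0-based ranks where tied values share the mean of their positions."""
    ranks = argsort_ranks(x)
    _, inv = np.unique(x, return_inverse=True)
    inv = inv.ravel()
    return (np.bincount(inv, weights=ranks) / np.bincount(inv))[inv]


def pearson(u, v):
    return float(np.corrcoef(u, v)[0, 1])


act = np.array([0.0, 0.0, 0.0, 2.0, 1.0])  # ReLU-zeroed: three exact zeros
grad = np.array([0.1, 0.2, 0.3, 0.4, 0.5])  # strictly increasing along the rows

rho_argsort = pearson(argsort_ranks(act), argsort_ranks(grad))
rho_average = pearson(average_ranks(act), argsort_ranks(grad))
```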

Functions

mrr(A, edge_index, all_true_edges[, ...])

Mean reciprocal rank of true (src -> dst) edges scored by A.

Classes

FunctionEdgeInferer(model, crit, edge_index)

Infer a (N, N) function-to-function edge-evidence matrix.

class gsnn.optim.FunctionEdgeInferer.FunctionEdgeInferer(model, crit, edge_index, use_prenorm: bool = True, device: Union[str, torch.device] = 'cpu', norm: str = 'l1', agg: str = 'sum')

Bases: object

Infer a (N, N) function-to-function edge-evidence matrix.

See the module docstring for the underlying scoring heuristic and a discussion of tractability.

Parameters:
  • model (GSNN) – Trained model whose ResBlocks expose _store_activations, channel_groups and recorded activations.

  • crit (callable) – Loss criterion (e.g. torch.nn.MSELoss).

  • edge_index (LongTensor of shape (2, E)) – Known directed edges in the function-to-function subgraph; used by _penalize_dependencies() and made available to callers that want to compare W against the prior graph.

  • use_prenorm (bool, default True) – Use pre-norm activations rather than the block's fully processed (post-norm) outputs.

  • device (str or torch.device, default 'cpu') – Torch device used for the forward/backward pass and the correlation.

  • norm (str, default 'l1') – Normalisation applied to |act| / |grad|: 'l1', 'l2' or 'none'.

  • agg (str, default 'sum') – Reduction within each node group: 'sum', 'mean' or 'max'.

fit(x, y, method: str = 'spearman', penalty_factor: float = 0.0, scale_by_act_mean: bool = False, estimate: Optional[bool] = False, estimate_iters: int = 10, estimate_n_samples: int = 2500, row_chunk: Optional[int] = None, verbose: bool = False) -> numpy.ndarray

Compute the edge-evidence matrix, optionally with bootstrap averaging.

Parameters:
  • x (Tensor) – Forward-pass inputs (typically held-out data).

  • y (Tensor) – Targets corresponding to x, used by crit to compute the loss for the backward pass.

  • method (str, default 'spearman') – Correlation type: 'spearman' or 'pearson'.

  • penalty_factor (float, default 0.) – If non-zero, post-multiplies W via _penalize_dependencies() to soften scores around known edges.

  • scale_by_act_mean (bool, default False) – Multiply rows of W by the mean source-node activation. Useful when correlations are dominated by rarely firing nodes.

  • estimate (bool, default False) – If truthy, sample estimate_n_samples observations with replacement estimate_iters times and average the resulting matrices.

  • estimate_n_samples (int, default 2500) – Number of observations to sample for each estimate.

  • estimate_iters (int, default 10) – Number of estimates to average.

  • row_chunk (optional int) – If set, the (N, N) matmul inside _fast_corr() is computed in row chunks of this size. Use to bound peak memory at large N (rough guide: row_chunk * (N + M) * 4 bytes).

  • verbose (bool, default False) – Print a progress line during bootstrap.
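The bootstrap option (estimate) and scale_by_act_mean can be sketched as below. This is a hypothetical NumPy helper, not the method itself: it uses Pearson only, and it assumes the activation mean used for scaling is taken over all M observations rather than per resample. The key detail is that activation and gradient rows are resampled with the same indices, so each (act, grad) pair stays aligned.

```python
import numpy as np


def pearson_block(A, G):
    """Pearson correlation between every column of A and every column of G."""
    Az = (A - A.mean(axis=0)) / A.std(axis=0)
    Gz = (G - G.mean(axis=0)) / G.std(axis=0)
    return Az.T @ Gz / A.shape[0]


def bootstrap_edge_scores(A, G, n_samples=2500, iters=10, scale_by_act_mean=False, seed=0):
    """Average pearson_block over `iters` resamples of `n_samples` paired rows.

    A, G: (M, N) stacked activity / gradient observations; row m of A pairs with row m of G.
    """
    rng = np.random.default_rng(seed)
    m = A.shape[0]
    W = np.zeros((A.shape[1], G.shape[1]))
    for _ in range(iters):
        idx = rng.integers(0, m, size=n_samples)  # rows drawn with replacement
        W += pearson_block(A[idx], G[idx])  # same idx keeps (act, grad) pairs aligned
    W /= iters
    if scale_by_act_mean:
        W *= A.mean(axis=0)[:, None]  # scale row i by mean activity of source node i
    return W
```

Averaging over resamples smooths the estimate; with resampling, duplicated rows would also create ties, which is one more reason the tie caveat above matters for the Spearman path.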

Returns:

W – Score matrix; W[i, j] is evidence for edge i -> j.

Return type:

ndarray of shape (N, N)

gsnn.optim.FunctionEdgeInferer.mrr(A, edge_index, all_true_edges, query_chunk_size: int = 2048) -> float

Mean reciprocal rank of true (src -> dst) edges scored by A.

For each query edge (s, d) in edge_index, the rank of A[s, d] is computed only against false candidate tails A[s, j], i.e. tails j such that (s, j) is not present in all_true_edges. Other true edges that share the same source as the query are masked out of the candidate pool, so true edges never compete against one another. Ties are broken with the average-rank convention. Returns a scalar in (0, 1].

Parameters:
  • A (Tensor of shape (N, N)) – Score matrix; larger values indicate stronger evidence of an edge.

  • edge_index (LongTensor of shape (2, E_q)) – Query edges whose ranks contribute to the MRR.

  • all_true_edges (LongTensor of shape (2, E_total)) – Union of all known/true edges (training + validation + test). Used to build the “competing-true” mask.

  • query_chunk_size (int, default 2048) – Number of query edges processed at once. Tunes the O(query_chunk_size * N) peak working set; lower this if you OOM on very large N or E_q.

Returns:

Mean reciprocal rank across all query edges.

Return type:

float
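The filtered ranking described above can be sketched in NumPy as follows. This is a hypothetical `mrr_sketch`, not the library function: it assumes every query edge also appears in all_true_edges, and for simplicity it builds a dense (N, N) boolean mask, so its peak memory is O(N^2) rather than the O(query_chunk_size * N) working set of mrr(). The average-rank convention gives rank = 1 + (# false tails scoring higher) + 0.5 * (# false tails scoring equal).

```python
import numpy as np


def mrr_sketch(A, edge_index, all_true_edges, query_chunk_size=2048):
    """Filtered MRR of query edges (s, d) under score matrix A (N, N)."""
    A = np.asarray(A, dtype=np.float64)
    n = A.shape[0]
    true_mask = np.zeros((n, n), dtype=bool)
    true_mask[all_true_edges[0], all_true_edges[1]] = True
    src, dst = edge_index
    rr = []
    for start in range(0, len(src), query_chunk_size):
        s = src[start:start + query_chunk_size]
        d = dst[start:start + query_chunk_size]
        rows = A[s]  # (q, N) candidate scores for each query's source
        target = A[s, d][:, None]  # (q, 1) score of the query edge
        false = ~true_mask[s]  # competing true tails are masked out
        greater = ((rows > target) & false).sum(axis=1)
        equal = ((rows == target) & false).sum(axis=1)
        rank = 1.0 + greater + 0.5 * equal  # average-rank tie convention
        rr.append(1.0 / rank)
    return float(np.concatenate(rr).mean())
```

With all_true_edges containing (0, 1) and (0, 2), a query (0, 2) is never outranked by the true edge (0, 1), even if A[0, 1] is larger.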