gsnn.optim.FunctionEdgeInferer
Post-hoc inference of latent function -> function edges in a trained GSNN.
Overview
FunctionEdgeInferer produces a dense (N, N) evidence matrix W
over function nodes from a trained gsnn.models.GSNN.GSNN model and an
evaluation batch. The score W[i, j] is a rank (or Pearson) correlation
between layer-\(l\) per-node activations at node i and
layer-\(l+1\) per-node gradients at node j, pooled across layers and
batch elements. High values are interpreted as evidence for a directed edge
i -> j (a Hebbian-style “what fired together with a downstream credit
signal”).
The pipeline is:
1. Toggle activation capture on every ResBlock and run a forward + backward pass on the supplied batch.
2. Read each block's per-channel pre-/post-norm activations and gradients.
3. Reduce |act| / |grad| within each node group (sum/mean/max) to obtain per-node activity at every layer.
4. Stack (layer-l act, layer-(l+1) grad) pairs across layers + batch as M = (L - 1) * B observations.
5. Compute the cross-correlation matrix between the N activation columns and the N gradient columns.
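The last two steps of the pipeline (stacking the pairs and computing the cross-correlation) can be sketched in NumPy as follows. This is only an illustration under stated assumptions: the module itself works on torch tensors on the chosen device, and the function name cross_correlation and the (L, B, N) array layout are hypothetical choices made for this example.

```python
import numpy as np

def cross_correlation(acts, grads):
    """Pearson cross-correlation between activations and next-layer gradients.

    acts, grads: arrays of shape (L, B, N) holding per-node |act| / |grad|
    for each layer and batch element (an assumed layout).
    Returns an (N, N) matrix W where W[i, j] correlates layer-l activity at
    node i with layer-(l+1) gradient at node j.
    """
    L, B, N = acts.shape
    # Pair layer l activations with layer l+1 gradients: M = (L - 1) * B rows.
    a = acts[:-1].reshape((L - 1) * B, N)
    g = grads[1:].reshape((L - 1) * B, N)
    # Standardise each column; the small constant guards constant columns.
    a = (a - a.mean(axis=0)) / (a.std(axis=0) + 1e-8)
    g = (g - g.mean(axis=0)) / (g.std(axis=0) + 1e-8)
    # Only the activation-by-gradient block is formed.
    return (a.T @ g) / a.shape[0]
```

Passing rank-transformed columns instead of raw values turns the same product into a Spearman-style score.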
A short mrr() utility is provided to score the resulting matrix against
ground-truth edges via mean-reciprocal-rank.
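A minimal pure-Python sketch of a filtered mean-reciprocal-rank score of the kind mrr() reports. It assumes edges are given as (src, dst) tuples and the score matrix as a nested list, which differs from the tensor inputs of the actual utility; filtered_mrr is a hypothetical name, and whether the real candidate pool excludes self-loops is not specified here.

```python
def filtered_mrr(A, query_edges, all_true_edges):
    """Mean reciprocal rank of query edges against non-true tails only.

    A: N x N nested list of scores (larger means stronger evidence).
    query_edges, all_true_edges: iterables of (src, dst) tuples.
    """
    true_set = set(all_true_edges)
    reciprocal_ranks = []
    for s, d in query_edges:
        score = A[s][d]
        # Candidate pool: tails j for which (s, j) is not a known true edge,
        # so true edges never compete against one another.
        negatives = [A[s][j] for j in range(len(A[s])) if (s, j) not in true_set]
        higher = sum(1 for v in negatives if v > score)
        tied = sum(1 for v in negatives if v == score)
        # Average-rank convention: tied scores share the mean of their ranks.
        rank = 1 + higher + tied / 2
        reciprocal_ranks.append(1.0 / rank)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


A = [[0.1, 0.9, 0.5],
     [0.2, 0.3, 0.8],
     [0.0, 0.0, 0.0]]
# (0, 1) is a known true edge, so it does not push (0, 2) down to rank 2.
print(filtered_mrr(A, [(0, 2)], [(0, 1), (0, 2)]))  # 1.0
```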
Tractability
Let
- N = number of function nodes
- L = number of ResBlocks
- B = batch size used for attribution
- M = (L - 1) * B = number of (activation, gradient) observations
Dominant costs:
| quantity | memory | compute |
|---|---|---|
| activations / gradients buffers | | forward + backward pass |
| ranked observations (Spearman) | | |
| correlation matrix | | |
Concrete back-of-envelope figures (float32):
- 1k nodes, 10k obs -> ~40 MB working set, sub-second on CPU.
- 5k nodes, 50k obs -> ~1 GB working set, ~seconds on GPU.
- 10k nodes, 10k obs -> ~400 MB just for W, seconds on GPU.
- 10k nodes, 100k obs -> ~4 GB per (acts, grads) tensor; feasible on a 24 GB GPU with chunked matmul.
- 10k nodes, 300k obs -> ~12 GB per tensor; not feasible with the prior pandas/Spearman implementation (the intermediate (2N, 2N) float64 correlation alone is ~3.2 GB and the pandas Spearman ranker is effectively single-threaded). The GPU rank-Pearson path in this module runs in well under a minute on a single 24 GB GPU using row_chunk to stream the final matmul.
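The figures above follow from simple element counts, which the short helper below reproduces. It is a sketch rather than part of the module: the function names are hypothetical, sizes use decimal units (1 GB = 10^9 bytes), and the chunk estimate is the rough guide row_chunk * (N + M) * 4 bytes quoted for the row_chunk parameter of fit().

```python
def tensor_bytes(n_nodes, n_obs, bytes_per_el=4):
    """Size of one (M, N) observation tensor (acts or grads), float32 by default."""
    return n_nodes * n_obs * bytes_per_el

def corr_bytes(n_nodes, bytes_per_el=4):
    """Size of the (N, N) output matrix W."""
    return n_nodes * n_nodes * bytes_per_el

def full_corr_bytes(n_nodes, bytes_per_el=8):
    """Size of a full (2N, 2N) float64 correlation, as in the pandas path."""
    return (2 * n_nodes) ** 2 * bytes_per_el

def chunk_bytes(row_chunk, n_nodes, n_obs, bytes_per_el=4):
    """Rough peak memory of one streamed matmul chunk."""
    return row_chunk * (n_nodes + n_obs) * bytes_per_el

print(tensor_bytes(10_000, 300_000) / 1e9)   # 12.0  (GB per tensor)
print(corr_bytes(10_000) / 1e6)              # 400.0 (MB for W)
print(full_corr_bytes(10_000) / 1e9)         # 3.2   (GB, float64)
```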
For configurations beyond ~5k nodes the previous pd.DataFrame.corr path
was a hard wall, both in the compute time of the full correlation and in the size of its float64 result.
The implementation here:
- keeps everything in float32 on device;
- computes Spearman as rank-Pearson via argsort;
- only computes the off-diagonal (a x g) block, never the full (2N x 2N) matrix;
- exposes row_chunk so the output matmul can be streamed when N^2 no longer fits;
- preserves all model state (activation flags, .grad buffers, training/eval mode) on exit.
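The numerical points in this list (float32 throughout, Spearman as Pearson on argsort ranks, only the activation-by-gradient block, and row-chunked output) can be illustrated with a NumPy stand-in. This is a sketch under assumptions, not the module's _fast_corr(): the function names are hypothetical and the real code runs in torch on the selected device.

```python
import numpy as np

def argsort_ranks(x):
    """Column-wise ranks 0..M-1 via a double argsort (ties broken by sort order)."""
    order = np.argsort(x, axis=0, kind="stable")
    return np.argsort(order, axis=0, kind="stable")

def unit_columns(x):
    """Centre each column and scale it to unit Euclidean norm."""
    x = x - x.mean(axis=0, keepdims=True)
    return x / (np.linalg.norm(x, axis=0, keepdims=True) + 1e-12)

def chunked_cross_corr(a, g, method="spearman", row_chunk=None):
    """Correlate the N columns of a with the N columns of g.

    a, g: (M, N) arrays. Returns an (N, N) float32 matrix; the full
    (2N, 2N) correlation is never formed.
    """
    if method == "spearman":
        a, g = argsort_ranks(a), argsort_ranks(g)
    a = unit_columns(a.astype(np.float32))
    g = unit_columns(g.astype(np.float32))
    n = a.shape[1]
    step = n if row_chunk is None else row_chunk
    out = np.empty((n, g.shape[1]), dtype=np.float32)
    for start in range(0, n, step):
        # Each matmul produces only `step` rows of the output.
        out[start:start + step] = a[:, start:start + step].T @ g
    return out
```

Because the columns are centred and unit-norm, each dot product is already a correlation coefficient, so chunking the rows changes peak memory but not the result.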
Caveats:
- Spearman ties are broken by argsort order rather than by the textbook average-rank convention. For continuous-valued neural-network activations ties are vanishingly rare so the bias is negligible; it can become visible on heavily quantised / ReLU-zeroed activations.
- The (activation, gradient) score is a heuristic, not a causal estimator; shared up-stream drivers can inflate apparent edge strength.
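The tie-breaking caveat can be seen on a small ReLU-zeroed vector. The pure-Python sketch below contrasts argsort-order ranks with the average-rank convention; both helper names are hypothetical and exist only for this illustration.

```python
def argsort_ranks(values):
    """1-based ranks where ties are broken by position (stable sort order)."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0] * len(values)
    for rank, index in enumerate(order, start=1):
        ranks[index] = rank
    return ranks

def average_ranks(values):
    """1-based ranks where tied values share the mean of their ranks."""
    ordinal = argsort_ranks(values)
    averaged = []
    for v in values:
        tied = [r for r, w in zip(ordinal, values) if w == v]
        averaged.append(sum(tied) / len(tied))
    return averaged

acts = [0.0, 0.0, 0.0, 1.5, 0.7]   # three activations zeroed by a ReLU
print(argsort_ranks(acts))          # [1, 2, 3, 5, 4]
print(average_ranks(acts))          # [2.0, 2.0, 2.0, 5.0, 4.0]
```

With argsort order the three zeros receive distinct ranks that depend only on their position, which is the source of the bias on heavily tied data.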
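Functions
| function | summary |
|---|---|
| contextmanager | @contextmanager decorator. |
| mrr | Mean reciprocal rank of true (src -> dst) edges scored by A. |
Classes
| class | summary |
|---|---|
| FunctionEdgeInferer | Infer a (N, N) function-to-function edge-evidence matrix. |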
- class gsnn.optim.FunctionEdgeInferer.FunctionEdgeInferer(model, crit, edge_index, use_prenorm: bool = True, device: Union[str, torch.device] = 'cpu', norm: str = 'l1', agg: str = 'sum')
Bases: object
Infer a (N, N) function-to-function edge-evidence matrix.
See the module docstring for the underlying scoring heuristic and a discussion of tractability.
- Parameters:
  - model (GSNN) – Trained model whose ResBlocks expose _store_activations, channel_groups and recorded activations.
  - crit (callable) – Loss criterion (e.g. torch.nn.MSELoss).
  - edge_index (LongTensor of shape (2, E)) – Known directed edges in the function-to-function subgraph; used by _penalize_dependencies() and made available to callers that want to compare W against the prior graph.
  - use_prenorm (bool, default True) – Use pre-norm activations rather than post-everything ones.
  - device (str or torch.device, default 'cpu') – Device used for the forward/backward pass and correlation.
  - norm (str, default 'l1') – 'l1', 'l2' or 'none'; applied to |act| / |grad|.
  - agg (str, default 'sum') – 'sum', 'mean' or 'max' reduction within each node group.
- fit(x, y, method: str = 'spearman', penalty_factor: float = 0.0, scale_by_act_mean: bool = False, estimate: Optional[bool] = False, estimate_iters: int = 10, estimate_n_samples: int = 2500, row_chunk: Optional[int] = None, verbose: bool = False) -> numpy.ndarray
Compute the edge-evidence matrix, optionally with bootstrap averaging.
- Parameters:
  - x (Tensor) – Forward-pass inputs (typically held-out data).
  - y (Tensor) – Targets matching x (typically held-out data).
  - method (str, default 'spearman') – 'spearman' or 'pearson'.
  - penalty_factor (float, default 0.) – If non-zero, post-multiplies W via _penalize_dependencies() to soften scores around known edges.
  - scale_by_act_mean (bool, default False) – Multiply rows of W by the mean source-node activation. Useful when correlations are dominated by rarely-firing nodes.
  - estimate (bool, default False) – If truthy, sample estimate_n_samples observations with replacement estimate_iters times and average the resulting matrices.
  - estimate_n_samples (int, default 2500) – Number of observations to sample for each estimate.
  - estimate_iters (int, default 10) – Number of estimates to average.
  - row_chunk (optional int) – If set, the (N, N) matmul inside _fast_corr() is computed in row chunks of this size. Use to bound peak memory at large N (rough guide: row_chunk * (N + M) * 4 bytes).
  - verbose (bool, default False) – Print a progress line during bootstrap.
- Returns:
W – Score matrix; W[i, j] is evidence for edge i -> j.
- Return type:
ndarray of shape (N, N)
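The bootstrap averaging controlled by estimate, estimate_iters and estimate_n_samples can be sketched as below. This is an illustration only: it assumes the resampled unit is a row of the (M, N) observation matrices, and the names bootstrap_average and pearson_block are hypothetical stand-ins rather than functions of this module.

```python
import numpy as np

def pearson_block(a, g):
    """(N, N) Pearson cross-correlation between columns of a and columns of g."""
    a = (a - a.mean(axis=0)) / (a.std(axis=0) + 1e-8)
    g = (g - g.mean(axis=0)) / (g.std(axis=0) + 1e-8)
    return (a.T @ g) / a.shape[0]

def bootstrap_average(a, g, score_fn, n_iters=10, n_samples=2500, seed=0):
    """Average score_fn over n_iters resamples of n_samples rows each."""
    rng = np.random.default_rng(seed)
    m = a.shape[0]
    total = None
    for _ in range(n_iters):
        # Draw observation indices with replacement.
        idx = rng.integers(0, m, size=n_samples)
        w = score_fn(a[idx], g[idx])
        total = w if total is None else total + w
    return total / n_iters
```

Each resample only touches n_samples rows, so the per-iteration cost is bounded by the sample size rather than by M.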
- gsnn.optim.FunctionEdgeInferer.mrr(A, edge_index, all_true_edges, query_chunk_size: int = 2048) -> float
Mean reciprocal rank of true (src -> dst) edges scored by A.
For each query edge (s, d) in edge_index, the rank of A[s, d] is computed only against false candidate tails A[s, j], i.e. tails j such that (s, j) is not present in all_true_edges. Other true edges that share the same source as the query are masked out of the candidate pool, so true edges never compete against one another. Ties are broken with the average-rank convention. Returns a scalar in \((0, 1]\).
- Parameters:
  - A (Tensor of shape (N, N)) – Score matrix; larger values indicate stronger evidence of an edge.
  - edge_index (LongTensor of shape (2, E_q)) – Query edges whose ranks contribute to the MRR.
  - all_true_edges (LongTensor of shape (2, E_total)) – Union of all known/true edges (training + validation + test). Used to build the “competing-true” mask.
  - query_chunk_size (int, default 2048) – Number of query edges processed at once. Tunes the O(query_chunk_size * N) peak working set; lower this if you OOM on very large N or E_q.
- Returns:
Mean reciprocal rank across all query edges.
- Return type: