gsnn.gsnn.optim.MagnitudeEdgeInferer
Tier-0 edge inference via magnitude correlation between node activations and node gradients.
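For each pair of function nodes i (source) and j (target), and for each adjacent layer pair (n-1, n):
  x_i^{n-1} = \lVert z_i^{n-1} \rVert_p        (activation magnitude at layer n-1)
  y_j^{n} = \lVert \nabla_{z_j^{n}} L \rVert_p    (gradient magnitude at layer n)
where p = 1 or p = 2 according to reduction='l1' or reduction='l2'.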
- Correlation score (score='corr'):
  s_n(i \to j) = \mathrm{corr}_b(x_i^{n-1}, y_j^{n}), the Pearson correlation across samples b.
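Below is a minimal NumPy sketch of the magnitudes and the plain correlation score for a single layer pair. It assumes that the layer n-1 activations and the layer n gradients are available as dense (batch, nodes, channels) arrays. The helper names `node_magnitudes` and `corr_scores` are hypothetical. According to evaluate(), the class itself works from accumulated statistics rather than from stored tensors.

```python
import numpy as np


def node_magnitudes(t: np.ndarray, reduction: str = "l1") -> np.ndarray:
    """Reduce a (batch, nodes, channels) array to (batch, nodes) p-norms."""
    p = 1 if reduction == "l1" else 2
    return np.linalg.norm(t, ord=p, axis=-1)


def corr_scores(act_prev: np.ndarray, grad_next: np.ndarray,
                reduction: str = "l1", eps: float = 1e-12) -> np.ndarray:
    """Pearson corr over the batch: S[i, j] = corr_b(x_i^{n-1}, y_j^n)."""
    x = node_magnitudes(act_prev, reduction)   # activation magnitudes, layer n-1
    y = node_magnitudes(grad_next, reduction)  # gradient magnitudes, layer n
    xs = (x - x.mean(axis=0)) / (x.std(axis=0) + eps)  # standardize each node
    ys = (y - y.mean(axis=0)) / (y.std(axis=0) + eps)
    return xs.T @ ys / x.shape[0]              # (source nodes, target nodes)


rng = np.random.default_rng(0)
acts = rng.normal(size=(256, 4, 8))   # hypothetical z^{n-1}
grads = rng.normal(size=(256, 4, 8))  # hypothetical dL/dz^n
grads[:, 2, :] = 3.0 * acts[:, 0, :]  # make y_2 proportional to x_0
S = corr_scores(acts, grads)
print(S[0, 2])  # close to 1.0
```

The batch dimension is the sample axis b. Every (i, j) entry is therefore a correlation across samples, not across channels.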
Partial-correlation score (score='partial') controls for the kept parents
of j in the partial graph G_partial:
  S_j = \mathrm{parents}_{G_{\mathrm{partial}}}(j)
  s_n(i \to j) = \mathrm{pcorr}_b\left(x_i^{n-1}, y_j^{n} \mid \{x_k^{n-1} : k \in S_j\}\right)
This removes the contribution of edges the model already uses, mitigating transitive / common-cause confounds that inflate the plain correlation score.
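The partial correlation can be read off the joint covariance of (x_i, y_j, x_{S_j}). First subtract the part explained by the conditioning set, then normalize what remains. Following the description of evaluate()'s `ridge` parameter, a ridge term is added to the conditioning-set covariance Σ_SS. What follows is a sketch for one (i, j) pair under that assumption and is not the library's code. `partial_corr` is a hypothetical name.

```python
import numpy as np


def partial_corr(x_i: np.ndarray, y_j: np.ndarray, cond: np.ndarray,
                 ridge: float = 1e-8) -> float:
    """pcorr_b(x_i, y_j | cond) from the joint sample covariance.

    x_i, y_j: (B,) magnitudes; cond: (B, k) magnitudes x_k for k in S_j.
    """
    data = np.column_stack([x_i, y_j, cond])         # (B, 2 + k)
    cov = np.cov(data, rowvar=False)
    c = cov[:2, :2]                                   # covariance of (x_i, y_j)
    k = cond.shape[1]
    if k > 0:
        s_as = cov[:2, 2:]
        s_ss = cov[2:, 2:] + ridge * np.eye(k)        # Tikhonov-regularized Sigma_SS
        c = c - s_as @ np.linalg.solve(s_ss, s_as.T)  # conditional covariance
    return float(c[0, 1] / np.sqrt(c[0, 0] * c[1, 1]))


# Common-cause confound: z drives both x and y, with no direct x -> y link.
rng = np.random.default_rng(1)
z = rng.normal(size=2000)
x = z + 0.5 * rng.normal(size=2000)
y = z + 0.5 * rng.normal(size=2000)
print(np.corrcoef(x, y)[0, 1])         # plain correlation, about 0.8
print(partial_corr(x, y, z[:, None]))  # near 0 after conditioning on z
```

With an empty conditioning set (k = 0), this reduces to the plain Pearson correlation.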
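Pairing activations at layer n-1 with gradients at layer n matches the direction of information flow. Activations propagate forward (n-1 -> n) and gradients propagate backward (n -> n-1). The per-pair scores are aggregated across layer pairs (mean or max) and tested with Fisher-z p-values followed by Benjamini-Hochberg FDR correction. For the partial-correlation score, the Fisher-z statistic uses N - 3 - |S_j| degrees of freedom, where N is the number of samples (not the layer index n).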
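The sketch below covers the aggregation and testing steps under some assumptions. The per-pair score matrices are taken to be combined elementwise, and "max" is taken to mean the elementwise maximum of the signed scores. The Fisher-z test is two-sided. The helper names are hypothetical.

```python
import math
import numpy as np


def aggregate_pairs(per_pair: list, how: str = "mean") -> np.ndarray:
    """Combine per-layer-pair (N, N) score matrices into one (N, N) matrix."""
    stack = np.stack(per_pair)  # (n_pairs, N, N)
    # Assumption: 'max' is the elementwise maximum of the signed scores.
    return stack.mean(axis=0) if how == "mean" else stack.max(axis=0)


def fisher_z_pvalue(r: float, n_samples: int, n_cond: int = 0) -> float:
    """Two-sided Fisher-z p-value for H0: (partial) correlation == 0."""
    r = min(max(r, -1 + 1e-12), 1 - 1e-12)  # keep atanh finite
    stat = math.atanh(r) * math.sqrt(n_samples - 3 - n_cond)
    return math.erfc(abs(stat) / math.sqrt(2.0))  # P(|Z| > |stat|)


def bh_qvalues(p) -> np.ndarray:
    """Benjamini-Hochberg adjusted p-values (q-values)."""
    p = np.asarray(p, dtype=float)
    m = p.size
    order = np.argsort(p)
    scaled = p[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity: q_(k) = min over l >= k of p_(l) * m / l.
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    q = np.empty(m)
    q[order] = np.minimum(scaled, 1.0)
    return q


print(bh_qvalues([0.01, 0.04, 0.03, 0.5]))  # approx [0.04, 0.0533, 0.0533, 0.5]
```

Conditioning on more parents lowers the degrees of freedom. As a result, the same |r| yields a larger p-value under score='partial'.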
Requires a trained GSNN with gradient checkpointing disabled (checkpoint=False).
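By default (use_pre_norm=True), magnitudes are read from ResBlock._last_pre_norm_activation, which is taken after lin_in and before normalization. This keeps magnitude scores meaningful under layer or RMS normalization, which would otherwise remove per-sample magnitude variation.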
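The toy example below shows why the pre-norm activation is used. It applies a parameter-free layer norm (no affine weights and eps = 0, both simplifying assumptions) to two samples that point in the same direction but differ 4x in magnitude. The L1 magnitude separates the two samples before normalization and is identical after it.

```python
import numpy as np


def layer_norm(h: np.ndarray, eps: float = 0.0) -> np.ndarray:
    """Parameter-free layer norm over the channel axis."""
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(var + eps)


h = np.array([0.5, -1.0, 2.0, 0.25])
samples = np.stack([h, 4.0 * h])                 # same direction, 4x magnitude
pre = np.abs(samples).sum(axis=-1)               # L1 magnitude before the norm
post = np.abs(layer_norm(samples)).sum(axis=-1)  # L1 magnitude after the norm
print(pre)   # differ by a factor of 4
print(post)  # identical for both samples
```

A per-channel affine transform after the norm would not restore the lost magnitude difference, because it treats both samples identically.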
See docs/notes/edge_inference_notes.md section 4 for design rationale.
Classes
MagnitudeEdgeInferer: Post-hoc inferrer for function -> function edges via activation/gradient magnitude correlation across adjacent layers.
- class gsnn.gsnn.optim.MagnitudeEdgeInferer.MagnitudeEdgeInferer(model, data, reduction: Literal['l1', 'l2'] = 'l1', use_pre_norm: bool = True)
Bases: object
Post-hoc inferrer for function -> function edges via activation/gradient magnitude correlation across adjacent layers.
For each pair of consecutive ResBlocks, correlates activation magnitudes at layer n-1 with gradient magnitudes at layer n (matching forward/backward information flow), then aggregates across pairs.
- Parameters:
model (GSNN) – Trained model. Must have checkpoint=False.
data (HeteroData-like) – Graph container with node_names_dict and edge_index_dict.
reduction ({'l1', 'l2'}) – Norm used to reduce per-node channel vectors to scalars.
use_pre_norm (bool) – If True (default), score magnitudes from post-lin_in activations taken before the ResBlock normalization layer. This preserves cross-sample magnitude variation when norm is layer, RMS, etc.
- evaluate(layer_agg: Literal['mean', 'max'] = 'mean', exclude_self: bool = True, score: Literal['corr', 'partial'] = 'corr', ridge: float = 1e-08) -> pandas.DataFrame
Compute correlation scores, p-values, and q-values from accumulated stats.
- Parameters:
layer_agg ({'mean', 'max'}) – How to aggregate adjacent-layer-pair correlations into one score matrix.
exclude_self (bool) – If True, omit diagonal i->i pairs from the output.
score ({'corr', 'partial'}) –
  'corr': Pearson correlation between activation magnitudes at layer n-1 and gradient magnitudes at layer n.
  'partial': partial correlation conditioning on the kept parents of the target in G_partial (function-function edges in data.edge_index_dict). Removes contributions from edges the model already uses.
ridge (float) – Tikhonov regularizer for the conditioning-set covariance matrix when score='partial'. Only used in that mode.
- Returns:
Columns: src_func, dst_func, src_idx, dst_idx, corr, corr_a*_g*, p_value, q_value, has_edge. When score='partial', the output also includes n_cond (size of the conditioning set for that target). For kept edges (has_edge=True) the partial correlation is NaN, since the source is in the conditioning set.
- Return type:
pandas.DataFrame
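A typical next step is to keep only candidate edges that are absent from the graph and significant after FDR correction. The sketch below does this on a small synthetic frame that has only a subset of the documented columns. The values are made up for illustration, and both the helper `candidate_new_edges` and the `alpha` threshold are hypothetical.

```python
import pandas as pd

# Synthetic stand-in for evaluate() output (illustrative values only).
res = pd.DataFrame({
    "src_func": ["A", "A", "B", "C"],
    "dst_func": ["B", "C", "C", "B"],
    "corr": [0.62, 0.05, 0.41, 0.38],
    "q_value": [1e-6, 0.70, 0.003, 0.01],
    "has_edge": [True, False, False, False],
})


def candidate_new_edges(res: pd.DataFrame, alpha: float = 0.05,
                        score_col: str = "corr") -> pd.DataFrame:
    """Edges not in the graph whose BH q-value is below alpha, best first."""
    mask = (~res["has_edge"]) & (res["q_value"] < alpha)
    return res.loc[mask].sort_values(score_col, ascending=False)


print(candidate_new_edges(res)[["src_func", "dst_func", "corr"]])
```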
- static evaluate_target_ranking(res: pandas.DataFrame, positive_edges: set[tuple[str, str]] | list[tuple[str, str]], score_col: str = 'corr', top_k: tuple[int, ...] = (1, 3, 5)) -> tuple[pandas.DataFrame, dict[str, float]]
Within-target ranking metrics for edge recovery.
For each positive edge (src, dst) in positive_edges, rank all candidate sources for dst by score_col (descending) and record the rank of the true missing parent. This answers the question: for each target with a held-out edge, is the correct source ranked highest among the competing sources for that target?
- Parameters:
res (pandas.DataFrame) – Output of evaluate().
positive_edges (set or list of (src, dst)) – Ground-truth edges to recover (typically held-out edges). src and dst must match src_func / dst_func in res.
score_col (str) – Column to rank on (default 'corr').
top_k (tuple of int) – Compute the top@k hit rate for each k.
- Returns:
detail (pandas.DataFrame) – One row per positive edge with columns src_func, dst_func, score, rank (1 = best), n_candidates, reciprocal_rank, and top@{k} boolean flags.
summary (dict) – n_positives, mrr, and top@{k} rates.
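The ranking procedure described above can be sketched in pandas as follows. This is not the library's implementation, and `target_ranking` is a hypothetical name. Three behaviours are assumptions that the source does not specify. Ties are resolved in the true source's favour. NaN scores (for example, kept edges under score='partial') are dropped from the candidate pool. Positives that are missing from `res` are skipped.

```python
import pandas as pd


def target_ranking(res: pd.DataFrame, positive_edges, score_col: str = "corr",
                   top_k=(1, 3, 5)):
    """Rank each true source among all candidate sources of its target."""
    rows = []
    for src, dst in positive_edges:
        # Candidates: every scored source for this target; NaN scores dropped.
        cands = res.loc[res["dst_func"] == dst].dropna(subset=[score_col])
        hit = cands.loc[cands["src_func"] == src, score_col]
        if hit.empty:
            continue  # assumption: skip positives missing from res
        score = float(hit.iloc[0])
        # 1 = best; only strictly higher scores push the true source down.
        rank = int((cands[score_col] > score).sum()) + 1
        row = {"src_func": src, "dst_func": dst, "score": score, "rank": rank,
               "n_candidates": len(cands), "reciprocal_rank": 1.0 / rank}
        for k in top_k:
            row[f"top@{k}"] = rank <= k
        rows.append(row)
    detail = pd.DataFrame(rows)
    summary = {"n_positives": len(detail)}
    if rows:
        summary["mrr"] = float(detail["reciprocal_rank"].mean())
        for k in top_k:
            summary[f"top@{k}"] = float(detail[f"top@{k}"].mean())
    return detail, summary


res = pd.DataFrame({
    "src_func": ["A", "B", "C", "A", "B"],
    "dst_func": ["D", "D", "D", "E", "E"],
    "corr": [0.2, 0.7, 0.4, 0.9, 0.1],
})
detail, summary = target_ranking(res, [("C", "D"), ("A", "E")], top_k=(1, 3))
print(summary)  # mrr 0.75, top@1 0.5, top@3 1.0
```

Here C ranks second of three candidates for D, because B scores higher, and A ranks first for E. The mean reciprocal rank is therefore (1/2 + 1) / 2 = 0.75.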