gsnn.gsnn.optim.MagnitudeEdgeInferer

Tier-0 edge inference via magnitude correlation between node activations and node gradients.

For each pair of function nodes, i (source) and j (target), and each adjacent layer pair (n-1, n), define:

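$$x_i^{n-1} = \lVert z_i^{n-1} \rVert_p \quad \text{(activation magnitude at layer } n-1\text{)}$$

$$y_j^{n} = \lVert \nabla_{z_j^{n}} L \rVert_p \quad \text{(gradient magnitude at layer } n\text{)}$$

Each norm reduces a node's channel vector to a scalar, with p = 1 or 2 as selected by reduction.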
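A minimal NumPy sketch of the reduction step follows. It assumes the activations z^{n-1} and the gradients dL/dz^n have already been captured as arrays of shape (batch, nodes, channels). That layout and the helper name node_magnitudes are illustrative assumptions, not the library's internals.

```python
import numpy as np


def node_magnitudes(z: np.ndarray, reduction: str = "l1") -> np.ndarray:
    """Reduce per-node channel vectors to scalar magnitudes.

    z: array of shape (batch, nodes, channels) (assumed layout).
    Returns an array of shape (batch, nodes).
    """
    if reduction == "l1":
        return np.abs(z).sum(axis=-1)
    if reduction == "l2":
        return np.sqrt((z ** 2).sum(axis=-1))
    raise ValueError(f"unknown reduction: {reduction!r}")


# Random stand-ins for captured tensors (hypothetical data).
rng = np.random.default_rng(0)
z_prev = rng.normal(size=(8, 5, 4))     # activations z^{n-1}
grad_next = rng.normal(size=(8, 5, 4))  # gradients dL/dz^{n}

x = node_magnitudes(z_prev, "l1")       # x_i^{n-1}, shape (8, 5)
y = node_magnitudes(grad_next, "l1")    # y_j^{n},   shape (8, 5)
```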

Score (score='corr'):

$$s_n(i \to j) = \mathrm{corr}_b\left(x_i^{n-1},\, y_j^{n}\right)$$

Here corr_b is the Pearson correlation computed across samples b.
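For one layer pair, the plain score can be computed for all (i, j) at once. Center each column, then normalize the cross-product matrix. The NumPy sketch below assumes x and y are (batch, N) arrays holding x_i^{n-1} and y_j^n. The function name corr_matrix is illustrative.

```python
import numpy as np


def corr_matrix(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """S[i, j] = Pearson correlation across samples of x[:, i] and y[:, j].

    x: (batch, N) activation magnitudes at layer n-1.
    y: (batch, N) gradient magnitudes at layer n.
    """
    xc = x - x.mean(axis=0)              # center each source column
    yc = y - y.mean(axis=0)              # center each target column
    cross = xc.T @ yc                    # (N, N) centered cross-products
    norm_x = np.sqrt((xc ** 2).sum(axis=0))
    norm_y = np.sqrt((yc ** 2).sum(axis=0))
    return cross / np.outer(norm_x, norm_y)


# Toy data: target column j tracks a permuted source column.
rng = np.random.default_rng(1)
x = rng.normal(size=(64, 6))
y = 0.8 * x[:, [2, 0, 1, 3, 4, 5]] + 0.2 * rng.normal(size=(64, 6))
S = corr_matrix(x, y)                    # S[2, 0] is large: y_0 follows x_2
```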

Partial-correlation score (score='partial') controls for the kept parents of j in the partial graph G_partial:

$$S_j = \mathrm{parents}_{G_{\mathrm{partial}}}(j)$$

$$s_n(i \to j) = \mathrm{pcorr}_b\left(x_i^{n-1},\, y_j^{n} \;\middle|\; \{\, x_k^{n-1} : k \in S_j \,\}\right)$$

This removes the contribution of edges the model already uses, mitigating transitive / common-cause confounds that inflate the plain correlation score.
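One way to compute the partial score is to regress the conditioning set out of both x_i^{n-1} and y_j^n, then correlate the residuals. The NumPy sketch below makes several assumptions for illustration:

- It works on raw samples rather than on the streamed sums the class keeps.
- It adds ridge to the diagonal of the conditioning-set Gram matrix.
- The function name partial_corr is hypothetical.

The toy data shows a common cause inflating the plain score.

```python
import numpy as np


def partial_corr(x: np.ndarray, y_j: np.ndarray, i: int, cond: list, ridge: float = 1e-8) -> float:
    """pcorr_b(x_i, y_j | {x_k : k in cond}) via least-squares residuals.

    x: (batch, N) source magnitudes at layer n-1.
    y_j: (batch,) gradient magnitudes of target j at layer n.
    cond: indices k in S_j (kept parents of j).
    """
    if i in cond:
        return float("nan")  # source is itself in the conditioning set
    xc = x - x.mean(axis=0)
    r_y = y_j - y_j.mean()
    r_x = xc[:, i]
    if cond:
        c = xc[:, cond]                             # (batch, |S_j|)
        gram = c.T @ c + ridge * np.eye(len(cond))  # ridge-regularized Gram matrix
        r_x = r_x - c @ np.linalg.solve(gram, c.T @ r_x)
        r_y = r_y - c @ np.linalg.solve(gram, c.T @ r_y)
    return float(r_x @ r_y / np.sqrt((r_x @ r_x) * (r_y @ r_y)))


# Common-cause toy: node 0 drives both source node 1 and the target.
rng = np.random.default_rng(2)
k = rng.normal(size=2000)
x = np.column_stack([k, k + 0.3 * rng.normal(size=2000)])
y_j = k + 0.3 * rng.normal(size=2000)

plain = partial_corr(x, y_j, i=1, cond=[])     # inflated by the shared cause
partial = partial_corr(x, y_j, i=1, cond=[0])  # near zero once node 0 is conditioned on
```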

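Pairing activations at layer n-1 with gradients at layer n matches information flow: activations propagate forward (n-1 -> n), and gradients propagate backward (n -> n-1). Scores from all adjacent layer pairs are aggregated (mean or max). They are then tested with Fisher-z p-values followed by Benjamini-Hochberg FDR correction. For partial correlation, df = n_samples - 3 - |S_j|, where n_samples is the number of samples (not the layer index n).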
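The aggregation and testing steps can be sketched with NumPy and the standard library, under these assumptions:

- The per-layer-pair scores are stacked into a (P, N, N) array.
- The helper names are hypothetical.
- Clipping r before atanh and returning NaN when the degrees of freedom are non-positive are guards added here, not documented behavior.

The Fisher-z statistic is atanh(r) * sqrt(n_samples - 3 - n_cond), with n_samples the number of samples and n_cond = |S_j|. Its two-sided normal p-value is erfc(|z| / sqrt(2)).

```python
import math

import numpy as np


def aggregate_layers(per_pair: np.ndarray, layer_agg: str = "mean") -> np.ndarray:
    """Collapse (P, N, N) per-layer-pair scores into one (N, N) matrix."""
    if layer_agg == "mean":
        return per_pair.mean(axis=0)
    if layer_agg == "max":
        return per_pair.max(axis=0)
    raise ValueError(f"unknown layer_agg: {layer_agg!r}")


def fisher_z_pvalue(r: float, n_samples: int, n_cond: int = 0) -> float:
    """Two-sided p-value for H0: (partial) correlation == 0."""
    dof = n_samples - 3 - n_cond
    if dof <= 0 or math.isnan(r):
        return float("nan")
    r = min(max(r, -1 + 1e-12), 1 - 1e-12)  # keep atanh finite
    z = math.atanh(r) * math.sqrt(dof)
    return math.erfc(abs(z) / math.sqrt(2.0))


def bh_qvalues(p) -> np.ndarray:
    """Benjamini-Hochberg adjusted p-values (q-values); p must not contain NaN."""
    p = np.asarray(p, dtype=float)
    m = p.size
    order = np.argsort(p)
    scaled = p[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity: each q is the minimum over all larger ranks.
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    q = np.empty(m)
    q[order] = np.minimum(scaled, 1.0)
    return q
```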

Requires a trained GSNN with gradient checkpointing disabled (checkpoint=False). By default, magnitudes are read from ResBlock._last_pre_norm_activation (post-lin_in, pre-norm). This keeps scores meaningful under layer or RMS normalization, which would otherwise rescale away cross-sample magnitude variation.

See docs/notes/edge_inference_notes.md section 4 for design rationale.

Classes

MagnitudeEdgeInferer(model, data[, ...])

Post-hoc inferrer for function -> function edges via activation/gradient magnitude correlation across adjacent layers.

class gsnn.gsnn.optim.MagnitudeEdgeInferer.MagnitudeEdgeInferer(model, data, reduction: Literal['l1', 'l2'] = 'l1', use_pre_norm: bool = True)

Bases: object

Post-hoc inferrer for function -> function edges via activation/gradient magnitude correlation across adjacent layers.

For each pair of consecutive ResBlocks, correlates activation magnitudes at layer n-1 with gradient magnitudes at layer n (matching forward/backward information flow), then aggregates across pairs.

Parameters:
  • model (GSNN) – Trained model. Must have checkpoint=False.

  • data (HeteroData-like) – Graph container with node_names_dict and edge_index_dict.

  • reduction ({'l1', 'l2'}) – Norm used to reduce per-node channel vectors to scalars.

  • use_pre_norm (bool) – If True (default), score magnitudes from post-lin_in activations before the ResBlock normalization layer. This preserves cross-sample magnitude variation when norm is layer, RMS, etc.
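The NumPy sketch below shows why pre-norm activations matter. It applies a plain RMS normalization (no learned gain, a simplifying assumption) to three samples that differ only in scale. Before normalization, their L2 magnitudes differ by a factor of 8. Afterwards, every sample has an L2 norm close to sqrt(channels), so a magnitude-based correlation would see almost no variation.

```python
import numpy as np


def rms_norm(h: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Plain RMS normalization over the channel axis (no learned gain)."""
    return h / np.sqrt((h ** 2).mean(axis=-1, keepdims=True) + eps)


rng = np.random.default_rng(0)
base = rng.normal(size=4)
scales = np.array([0.5, 1.0, 4.0])            # same direction, different magnitude
h = scales[:, None] * base[None, :]           # post-lin_in activations, shape (3, 4)

pre = np.linalg.norm(h, axis=-1)              # varies with scale
post = np.linalg.norm(rms_norm(h), axis=-1)   # ~sqrt(4) = 2 for every sample
```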

evaluate(layer_agg: Literal['mean', 'max'] = 'mean', exclude_self: bool = True, score: Literal['corr', 'partial'] = 'corr', ridge: float = 1e-08) -> pandas.DataFrame

Compute correlation scores, p-values, and BH-FDR q-values from the statistics accumulated by fit().

Parameters:
  • layer_agg ({'mean', 'max'}) – How to aggregate adjacent-layer-pair correlations into one score matrix.

  • exclude_self (bool) – If True, omit diagonal i->i pairs from the output.

  • score ({'corr', 'partial'}) – 'corr': Pearson correlation between activation magnitudes at layer n-1 and gradient magnitudes at layer n. 'partial': partial correlation conditioning on the kept parents of the target in G_partial (function-function edges in data.edge_index_dict). Removes contributions from edges the model already uses.

  • ridge (float) – Tikhonov regularizer for the conditioning-set covariance matrix. Only used when score='partial'.

Returns:

Columns: src_func, dst_func, src_idx, dst_idx, corr, corr_a*_g*, p_value, q_value, has_edge. When score='partial', the frame also includes n_cond (the size of the conditioning set for that target). For kept edges (has_edge=True), the partial correlation is NaN because the source is already in the conditioning set.

Return type:

pandas.DataFrame

static evaluate_target_ranking(res: pandas.DataFrame, positive_edges: set[tuple[str, str]] | list[tuple[str, str]], score_col: str = 'corr', top_k: tuple[int, ...] = (1, 3, 5)) -> tuple[pandas.DataFrame, dict[str, float]]

Within-target ranking metrics for edge recovery.

For each positive edge (src, dst) in positive_edges, rank all candidate sources for dst by score_col (descending) and record the rank of the true missing parent. This answers: for each target with a held-out edge, is the correct source ranked highest among competitors for that target?

Parameters:
  • res (pandas.DataFrame) – Output of evaluate().

  • positive_edges (set or list of (src, dst)) – Ground-truth edges to recover (typically held-out edges). src and dst must match src_func / dst_func in res.

  • score_col (str) – Column to rank on (default 'corr').

  • top_k (tuple of int) – Compute top@k hit rate for each k.

Returns:

  • detail (pandas.DataFrame) – One row per positive edge with columns src_func, dst_func, score, rank (1 = best), n_candidates, reciprocal_rank, and top@{k} boolean flags.

  • summary (dict) – n_positives, mrr (mean reciprocal rank), and top@{k} hit rates.
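The ranking logic can be sketched in plain Python, using a list of dicts as a stand-in for the evaluate() DataFrame. The following choices are assumptions of this sketch, not documented behavior of evaluate_target_ranking:

- Ties keep input order (Python's sort is stable).
- NaN scores are dropped, for example kept edges under score='partial'.
- Positives whose source was not scored are skipped.

```python
import math
from collections import defaultdict


def target_ranking(rows, positive_edges, score_col="corr", top_k=(1, 3, 5)):
    """Rank candidate sources per target and score the true (src, dst) edges."""
    by_dst = defaultdict(list)
    for row in rows:
        if not math.isnan(row[score_col]):      # skip NaN scores
            by_dst[row["dst_func"]].append(row)

    detail = []
    for src, dst in positive_edges:
        cands = sorted(by_dst[dst], key=lambda r: r[score_col], reverse=True)
        names = [r["src_func"] for r in cands]
        if src not in names:
            continue                             # true source was not scored
        rank = names.index(src) + 1              # 1 = best
        rec = {
            "src_func": src,
            "dst_func": dst,
            "score": cands[rank - 1][score_col],
            "rank": rank,
            "n_candidates": len(cands),
            "reciprocal_rank": 1.0 / rank,
        }
        for k in top_k:
            rec[f"top@{k}"] = rank <= k
        detail.append(rec)

    n = len(detail)
    summary = {"n_positives": n}
    summary["mrr"] = sum(d["reciprocal_rank"] for d in detail) / n if n else float("nan")
    for k in top_k:
        summary[f"top@{k}"] = sum(d[f"top@{k}"] for d in detail) / n if n else float("nan")
    return detail, summary


rows = [
    {"src_func": "A", "dst_func": "D", "corr": 0.9},
    {"src_func": "B", "dst_func": "D", "corr": 0.5},
    {"src_func": "C", "dst_func": "D", "corr": 0.7},
    {"src_func": "A", "dst_func": "E", "corr": 0.1},
    {"src_func": "B", "dst_func": "E", "corr": 0.3},
]
detail, summary = target_ranking(rows, [("C", "D"), ("B", "E")], top_k=(1, 3))
```

Here C is the second-best source for D and B is the best source for E, giving an MRR of 0.75.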

fit(dataloader, crit: Optional[torch.nn.Module] = None, device: str = 'cpu', verbose: bool = True) -> int

Accumulate magnitude statistics over dataloader.

Parameters:
  • dataloader (Iterable) – Yields (x, y) batches.

  • crit (callable, optional) – Loss function. Defaults to MSELoss.

  • device (str) – Device for forward/backward passes.

  • verbose (bool) – Print batch progress.

Returns:

Total number of samples processed.

Return type:

int

reset_stats() -> None

Clear accumulated streaming statistics.

Memory: O(P * N^2) for each of sum_xy and sum_xx, where P is the number of adjacent layer pairs and N is the number of function nodes. sum_xx is only required for score='partial', but it is always tracked because the per-batch cost is just one extra N x N matmul.
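The memory cost comes from keeping N x N cross-product matrices per layer pair, with N function nodes. The NumPy sketch below covers a single layer pair and assumes batches arrive as (batch, N) magnitude arrays. The class and attribute names echo sum_xy and sum_xx, but the bookkeeping is illustrative, not the library's implementation. Correlations are recovered from the sums at the end, so no per-sample data is stored.

```python
import numpy as np


class StreamingCorr:
    """Streaming sufficient statistics for corr(x_i, y_j), one layer pair."""

    def __init__(self, n_nodes: int):
        self.n_nodes = n_nodes
        self.reset_stats()

    def reset_stats(self) -> None:
        n = self.n_nodes
        self.count = 0
        self.sum_x = np.zeros(n)
        self.sum_y = np.zeros(n)
        self.sum_y2 = np.zeros(n)
        self.sum_xy = np.zeros((n, n))  # O(N^2): source-target cross-products
        self.sum_xx = np.zeros((n, n))  # O(N^2): source-source, needed for partial corr

    def update(self, x: np.ndarray, y: np.ndarray) -> None:
        """x, y: (batch, N) activation / gradient magnitudes."""
        self.count += x.shape[0]
        self.sum_x += x.sum(axis=0)
        self.sum_y += y.sum(axis=0)
        self.sum_y2 += (y ** 2).sum(axis=0)
        self.sum_xy += x.T @ y
        self.sum_xx += x.T @ x           # the one extra N x N matmul per batch

    def corr(self) -> np.ndarray:
        c = self.count
        cov_xy = self.sum_xy - np.outer(self.sum_x, self.sum_y) / c
        var_x = np.diag(self.sum_xx) - self.sum_x ** 2 / c
        var_y = self.sum_y2 - self.sum_y ** 2 / c
        return cov_xy / np.sqrt(np.outer(var_x, var_y))


rng = np.random.default_rng(3)
x_all = rng.normal(size=(90, 4))
y_all = rng.normal(size=(90, 4))
acc = StreamingCorr(n_nodes=4)
for start in range(0, 90, 32):           # three batches: 32, 32, 26 samples
    acc.update(x_all[start:start + 32], y_all[start:start + 32])
streamed = acc.corr()
```

The raw-sum formulas can lose precision when magnitudes are large relative to their spread. A more robust variant might center on the fly or use Welford-style updates.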