# Source code for gsnn.optim.MagnitudeEdgeInferer

'''
Tier-0 edge inference via magnitude correlation between node activations and
node gradients.

For each function node i (source) and j (target), per adjacent layer pair (n-1, n):

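    x_i^{n-1} = ||z_i^{n-1}||_p          (activation magnitude at layer n-1)
    y_j^n     = ||dL / dz_j^n||_p        (gradient magnitude at layer n)

where p = 1 or 2 is selected by ``reduction`` and each norm is taken over the
channels owned by the node.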

Score (``score='corr'``):
    s_n(i -> j) = corr_b(x_i^{n-1}, y_j^n)

Partial-correlation score (``score='partial'``) controls for the kept parents
of ``j`` in the partial graph ``G_partial``:

    S_j = parents_{G_partial}(j)
    s_n(i -> j) = pcorr_b( x_i^{n-1},  y_j^n  |  { x_k^{n-1} : k in S_j } )

This removes the contribution of edges the model already uses, mitigating
transitive / common-cause confounds that inflate the plain correlation score.

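This matches information flow: activations propagate forward n-1 -> n, gradients
propagate backward n -> n-1. Per-pair scores are aggregated across adjacent layer
pairs (``layer_agg='mean'``, or ``'max'``, which keeps the signed score with the
largest magnitude) and tested with one-sided Fisher-z p-values (H0: r <= 0) plus
Benjamini-Hochberg FDR correction (``df = n - 3 - |S_j|`` for partial correlation,
where ``n`` here denotes the number of samples, not the layer index).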

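Requires a trained GSNN with gradient checkpointing disabled (``checkpoint=False``).
By default (``use_pre_norm=True``) uses ``ResBlock._last_pre_norm_activation``
(post-``lin_in``, pre-norm) so magnitude scores remain meaningful even with layer /
RMS normalization; ``use_pre_norm=False`` uses ``ResBlock._last_activation`` and is
only allowed when ``model.norm`` is ``'batch'``, ``'groupbatch'`` or ``'none'``.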

See ``docs/notes/edge_inference_notes.md`` section 4 for design rationale.
'''

from __future__ import annotations

import math
from typing import Literal, Optional

import numpy as np
import pandas as pd
import scipy.special
import torch


class MagnitudeEdgeInferer:
    '''
    Post-hoc inferrer for function -> function edges via activation/gradient
    magnitude correlation across adjacent layers.

    For each pair of consecutive ResBlocks, correlates activation magnitudes at
    layer n-1 with gradient magnitudes at layer n (matching forward/backward
    information flow), then aggregates across pairs.

    Parameters
    ----------
    model : GSNN
        Trained model. Must have ``checkpoint=False``.
    data : HeteroData-like
        Graph container with ``node_names_dict`` and ``edge_index_dict``.
    reduction : {'l1', 'l2'}
        Norm used to reduce per-node channel vectors to scalars. Validated at
        construction time.
    use_pre_norm : bool
        If True (default), score magnitudes from post-``lin_in`` activations
        before the ResBlock normalization layer. This preserves cross-sample
        magnitude variation when ``norm`` is layer, RMS, etc. If False,
        ``model.norm`` must be ``'batch'``, ``'groupbatch'`` or ``'none'``.

    Raises
    ------
    ValueError
        If ``model.checkpoint`` is truthy, ``reduction`` is unknown,
        ``use_pre_norm=False`` is combined with an unsupported ``model.norm``,
        the number of function nodes in the homogeneous graph does not match
        ``data.node_names_dict['function']``, or a function node owns no
        channels in ``channel_groups``.
    '''

    def __init__(
        self,
        model,
        data,
        reduction: Literal['l1', 'l2'] = 'l1',
        use_pre_norm: bool = True,
    ):
        if getattr(model, 'checkpoint', False):
            raise ValueError(
                'MagnitudeEdgeInferer requires model.checkpoint=False; '
                'gradient checkpointing recomputes activations during backward.'
            )
        # Fail fast: otherwise an unknown reduction only surfaces inside
        # _reduce_channels, after a full forward/backward pass in fit().
        if reduction not in ('l1', 'l2'):
            raise ValueError(f'Unknown reduction: {reduction}')
        self.model = model
        self.data = data
        self.reduction = reduction
        self.use_pre_norm = use_pre_norm
        if not use_pre_norm:
            norm = getattr(model, 'norm', None)
            if norm not in ('batch', 'groupbatch', 'none'):
                raise ValueError(
                    'MagnitudeEdgeInferer with use_pre_norm=False requires '
                    'model.norm to be "batch", "groupbatch", or "none"; '
                    f'got {norm!r}. Layer/RMS norms invalidate post-norm magnitude '
                    'scores — use use_pre_norm=True (default) instead.'
                )

        self.function_nodes = list(data.node_names_dict['function'])
        self.N = len(self.function_nodes)
        self.L = len(model.ResBlocks)
        # Pair p correlates activations of block p with gradients of block p+1.
        self.n_pairs = max(self.L - 1, 0)

        # channel_groups maps each channel to the homogeneous node index that
        # owns it (compared against homo_idx below); taken from the first block.
        block0 = model.ResBlocks[0]
        self.channel_groups = block0.channel_groups.detach().cpu()
        function_mask = ~(model.input_node_mask | model.output_node_mask)
        self.function_homo_idxs = function_mask.nonzero(as_tuple=True)[0].cpu().tolist()
        if len(self.function_homo_idxs) != self.N:
            raise ValueError(
                f'Expected {self.N} function nodes in homogeneous graph, '
                f'got {len(self.function_homo_idxs)}.'
            )

        # Channel indices owned by each function node, in function_nodes order.
        self._channel_ixs_by_fn = []
        for homo_idx in self.function_homo_idxs:
            ixs = (self.channel_groups == homo_idx).nonzero(as_tuple=True)[0]
            if ixs.numel() == 0:
                raise ValueError(
                    f'Function node homo index {homo_idx} has no channels in channel_groups.'
                )
            self._channel_ixs_by_fn.append(ixs)

        edge_index = data.edge_index_dict['function', 'to', 'function']
        if hasattr(edge_index, 'detach'):
            edge_arr = edge_index.T.detach().cpu().numpy()
        else:
            edge_arr = np.asarray(edge_index.T)
        self.edges = {
            (self.function_nodes[i], self.function_nodes[j]) for i, j in edge_arr
        }

        # Kept parents of each target j (function-function edges in G_partial).
        # Used as the conditioning set for partial-correlation scoring.
        # Out-of-range indices and self-loops are ignored.
        self.parents_by_target: list[np.ndarray] = [
            np.zeros(0, dtype=np.int64) for _ in range(self.N)
        ]
        if edge_arr.size > 0:
            buf: list[list[int]] = [[] for _ in range(self.N)]
            for i, j in edge_arr:
                if 0 <= i < self.N and 0 <= j < self.N and i != j:
                    buf[int(j)].append(int(i))
            self.parents_by_target = [
                np.asarray(sorted(set(p)), dtype=np.int64) for p in buf
            ]
        self.reset_stats()

    def reset_stats(self) -> None:
        '''Clear accumulated streaming statistics.

        Memory: O(P * N^2) for both ``sum_xy`` and ``sum_xx`` (the latter is
        only required for ``score='partial'`` but is always tracked since the
        per-batch cost is one extra ``N x N`` matmul).
        '''
        P, N = self.n_pairs, self.N
        self.n: int = 0
        self.sum_x = np.zeros((P, N), dtype=np.float64)
        self.sum_y = np.zeros((P, N), dtype=np.float64)
        self.sum_x2 = np.zeros((P, N), dtype=np.float64)
        self.sum_y2 = np.zeros((P, N), dtype=np.float64)
        self.sum_xy = np.zeros((P, N, N), dtype=np.float64)
        self.sum_xx = np.zeros((P, N, N), dtype=np.float64)

    def _reduce_channels(self, tensor: torch.Tensor) -> torch.Tensor:
        '''
        Reduce (B, C_total) activations/gradients to (B, N) function-node
        magnitudes using the configured ``reduction`` over each node's channels.

        A trailing singleton dimension (B, C, 1) is squeezed first.
        '''
        if tensor.dim() == 3 and tensor.size(-1) == 1:
            tensor = tensor.squeeze(-1)
        if tensor.dim() != 2:
            raise ValueError(
                f'Expected activation/grad shape (B, C) or (B, C, 1), got {tuple(tensor.shape)}'
            )
        B = tensor.size(0)
        out = torch.empty(B, self.N, device=tensor.device, dtype=tensor.dtype)
        for fn_idx, ch_ixs in enumerate(self._channel_ixs_by_fn):
            ch_ixs = ch_ixs.to(tensor.device)
            vals = tensor[:, ch_ixs]
            if self.reduction == 'l1':
                out[:, fn_idx] = vals.abs().sum(dim=-1)
            elif self.reduction == 'l2':
                out[:, fn_idx] = vals.pow(2).sum(dim=-1).sqrt()
            else:
                raise ValueError(f'Unknown reduction: {self.reduction}')
        return out

    def _per_batch_magnitudes(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        crit: torch.nn.Module,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        '''
        Forward + backward pass; return activation and gradient magnitudes.

        The model is put in eval mode for the pass and its previous training
        mode is restored afterwards. Parameter ``.grad`` buffers are reset before
        backward and are left populated by this batch's loss.

        Returns
        -------
        x_mag, y_mag : Tensor
            Shapes (L, B, N) for activation magnitudes and gradient magnitudes.
        '''
        model = self.model
        was_training = model.training
        model.eval()
        for mod in model.ResBlocks:
            mod._store_activations = True
        act_attr = '_last_pre_norm_activation' if self.use_pre_norm else '_last_activation'
        try:
            yhat = model(x)
            loss = crit(yhat, y)
            acts = [getattr(mod, act_attr, None) for mod in model.ResBlocks]
            for act in acts:
                if act is None:
                    raise RuntimeError(
                        f'ResBlock.{act_attr} is None; ensure _store_activations=True.'
                    )
                # Activations are non-leaf tensors; keep their .grad after backward.
                act.retain_grad()
            model.zero_grad(set_to_none=True)
            loss.backward()
            grads = []
            for act in acts:
                if act.grad is None:
                    raise RuntimeError(
                        'Activation gradient is None after backward; '
                        'check that checkpoint=False and retain_grad() was called.'
                    )
                grads.append(act.grad)
            # Reduce detached tensors so the magnitude computation does not
            # extend the autograd graph of the (already consumed) forward pass.
            x_mag = torch.stack([self._reduce_channels(a.detach()) for a in acts], dim=0)
            y_mag = torch.stack([self._reduce_channels(g.detach()) for g in grads], dim=0)
        finally:
            for mod in model.ResBlocks:
                mod._store_activations = False
                for attr in ('_last_activation', '_last_pre_norm_activation'):
                    if hasattr(mod, attr):
                        delattr(mod, attr)
            model.train(was_training)
        return x_mag, y_mag

    def _update_stats(self, x_mag: torch.Tensor, y_mag: torch.Tensor) -> None:
        '''Update streaming sufficient statistics from one batch.

        Correlates activation magnitudes at layer n-1 with gradient magnitudes
        at layer n for each adjacent pair. ``x_mag`` / ``y_mag`` are indexed as
        ``[layer, batch, node]``.
        '''
        x_np = x_mag.detach().cpu().numpy()
        y_np = y_mag.detach().cpu().numpy()
        B = x_np.shape[1]
        self.n += B
        for p in range(self.n_pairs):
            act_l = x_np[p]
            grad_l = y_np[p + 1]
            self.sum_x[p] += act_l.sum(axis=0)
            self.sum_y[p] += grad_l.sum(axis=0)
            self.sum_x2[p] += (act_l ** 2).sum(axis=0)
            self.sum_y2[p] += (grad_l ** 2).sum(axis=0)
            # [i, j] = sum_b x_i * y_j  (source activation x target gradient)
            self.sum_xy[p] += act_l.T @ grad_l
            self.sum_xx[p] += act_l.T @ act_l

    def fit(
        self,
        dataloader,
        crit: Optional[torch.nn.Module] = None,
        device: str = 'cpu',
        verbose: bool = True,
    ) -> int:
        '''
        Accumulate magnitude statistics over ``dataloader``.

        Previously accumulated statistics are discarded first.

        Parameters
        ----------
        dataloader : Iterable
            Yields ``(x, y)`` batches.
        crit : callable, optional
            Loss function. Defaults to ``MSELoss``.
        device : str
            Device for forward/backward passes. The model is moved there.
        verbose : bool
            Print batch progress.

        Returns
        -------
        int
            Total number of samples processed.
        '''
        if crit is None:
            crit = torch.nn.MSELoss()
        self.reset_stats()
        self.model.to(device)
        n_batches = len(dataloader) if hasattr(dataloader, '__len__') else None
        for bi, (x, y) in enumerate(dataloader):
            x = x.to(device)
            y = y.to(device)
            x_mag, y_mag = self._per_batch_magnitudes(x, y, crit)
            self._update_stats(x_mag, y_mag)
            if verbose and n_batches is not None:
                print(f'[batch {bi + 1}/{n_batches}] n={self.n}', end='\r')
        if verbose:
            print()
        return self.n

    def _check_ready(self) -> None:
        '''Raise RuntimeError unless there are layer pairs and >= 3 samples.'''
        if self.n_pairs == 0:
            raise RuntimeError(
                f'Need at least 2 ResBlocks to form activation/gradient pairs; got L={self.L}.'
            )
        if self.n < 3:
            raise RuntimeError(
                f'Need at least 3 samples to compute correlations; got n={self.n}.'
            )

    def _compute_correlations(self) -> np.ndarray:
        '''Return per-pair correlation tensor of shape (L-1, N, N).

        Pair p compares activation magnitudes at layer p with gradient
        magnitudes at layer p+1. Entries whose source or target variance is
        ``<= 1e-12`` are NaN.
        '''
        self._check_ready()
        n = float(self.n)
        eps = 1e-12
        mean_x = self.sum_x / n
        mean_y = self.sum_y / n
        var_x = np.maximum(self.sum_x2 / n - mean_x ** 2, 0.0)
        var_y = np.maximum(self.sum_y2 / n - mean_y ** 2, 0.0)
        corr = np.full((self.n_pairs, self.N, self.N), np.nan, dtype=np.float64)
        for p in range(self.n_pairs):
            cov_xy = self.sum_xy[p] / n - mean_x[p][:, None] * mean_y[p][None, :]
            denom = np.sqrt(var_x[p][:, None] * var_y[p][None, :])
            with np.errstate(divide='ignore', invalid='ignore'):
                corr_p = cov_xy / np.maximum(denom, eps)
            corr_p[(var_x[p][:, None] <= eps) | (var_y[p][None, :] <= eps)] = np.nan
            corr[p] = np.clip(corr_p, -1.0, 1.0)
        return corr

    def _compute_partial_correlations(self, ridge: float = 1e-8) -> np.ndarray:
        '''Per-pair partial correlation tensor of shape (L-1, N, N).

        Pair p partials activation magnitudes at layer p (source) and gradient
        magnitudes at layer p+1 (target) on the kept parents of the target in
        ``G_partial``:

            S_j = parents_{G_partial}(j)
            pcorr(x_i, y_j | x_S_j)

        Computed via Schur complement on the joint covariance of
        (x_S, x_i, y_j) using the streamed sufficient statistics. Entries with
        ``i in S_j`` are set to NaN (partial correlation undefined / trivially
        0; these are kept edges, not candidates). Targets with empty ``S_j``
        reduce to plain Pearson correlation.

        Parameters
        ----------
        ridge : float
            Tikhonov regularizer added to the diagonal of ``Sigma_SS`` for
            numerical stability when conditioning sets are nearly collinear.
        '''
        self._check_ready()
        n = float(self.n)
        eps = 1e-12
        mean_x = self.sum_x / n
        mean_y = self.sum_y / n
        var_x_diag = np.maximum(self.sum_x2 / n - mean_x ** 2, 0.0)
        var_y_diag = np.maximum(self.sum_y2 / n - mean_y ** 2, 0.0)
        pcorr = np.full((self.n_pairs, self.N, self.N), np.nan, dtype=np.float64)
        for p in range(self.n_pairs):
            Sigma_xx = self.sum_xx[p] / n - np.outer(mean_x[p], mean_x[p])
            Sigma_xy = self.sum_xy[p] / n - np.outer(mean_x[p], mean_y[p])
            var_x = var_x_diag[p]
            var_y = var_y_diag[p]
            for j in range(self.N):
                S = self.parents_by_target[j]
                vy_j = var_y[j]
                if S.size == 0:
                    # No conditioning set: plain Pearson for this target column.
                    if vy_j <= eps:
                        continue
                    denom = np.sqrt(np.maximum(var_x, 0.0) * vy_j)
                    with np.errstate(divide='ignore', invalid='ignore'):
                        row = Sigma_xy[:, j] / np.maximum(denom, eps)
                    row[var_x <= eps] = np.nan
                    pcorr[p, :, j] = np.clip(row, -1.0, 1.0)
                    continue
                Sigma_SS = Sigma_xx[np.ix_(S, S)].copy()
                if ridge > 0:
                    Sigma_SS.flat[:: Sigma_SS.shape[0] + 1] += ridge
                Sigma_Sj = Sigma_xy[S, j]
                try:
                    alpha_j = np.linalg.solve(Sigma_SS, Sigma_Sj)
                except np.linalg.LinAlgError:
                    continue
                # Residual variance of y_j after regressing on x_S.
                var_y_partial = vy_j - Sigma_Sj @ alpha_j
                if not np.isfinite(var_y_partial) or var_y_partial <= eps:
                    continue
                Sigma_iS = Sigma_xx[:, S]
                try:
                    M = np.linalg.solve(Sigma_SS, Sigma_iS.T)
                except np.linalg.LinAlgError:
                    continue
                # Residual variance of every x_i after regressing on x_S.
                var_x_partial = var_x - np.einsum('ik,ki->i', Sigma_iS, M)
                # Residual covariance cov(x_i, y_j | x_S).
                num = Sigma_xy[:, j] - Sigma_iS @ alpha_j
                denom = np.sqrt(np.maximum(var_x_partial, 0.0) * var_y_partial)
                with np.errstate(divide='ignore', invalid='ignore'):
                    row = num / np.maximum(denom, eps)
                bad = (var_x_partial <= eps) | ~np.isfinite(row)
                row[bad] = np.nan
                # i in S_j: x_i is in the conditioning set, partial corr undefined.
                row[S] = np.nan
                # i == j: covered by exclude_self in evaluate, but mark explicitly.
                row[j] = np.nan
                pcorr[p, :, j] = np.clip(row, -1.0, 1.0)
        return pcorr

    @staticmethod
    def _aggregate_pairs(corr: np.ndarray, layer_agg: str) -> np.ndarray:
        '''Collapse the leading layer-pair axis of ``corr`` (P, N, N) to (N, N).

        ``'mean'`` is the NaN-ignoring mean; ``'max'`` keeps, per entry, the
        signed value with the largest absolute value. All-NaN entries stay NaN.
        '''
        if layer_agg == 'mean':
            with np.errstate(all='ignore'):
                return np.nanmean(corr, axis=0)
        if layer_agg == 'max':
            abs_corr = np.abs(corr)
            all_nan = np.all(np.isnan(abs_corr), axis=0)
            # nanargmax raises on all-NaN slices; use -inf sentinel instead.
            abs_filled = abs_corr.copy()
            abs_filled[np.isnan(abs_filled)] = -np.inf
            idx = np.argmax(abs_filled, axis=0)
            P, N, _ = corr.shape
            rows = np.arange(N)[:, None]
            cols = np.arange(N)[None, :]
            out = corr[idx, rows, cols]
            out[all_nan] = np.nan
            return out
        raise ValueError(f"Unknown layer_agg: {layer_agg}. Use 'mean' or 'max'.")

    @staticmethod
    def _fisher_pvalues(
        corr: np.ndarray,
        n: int,
        cond_size: int | np.ndarray = 0,
    ) -> np.ndarray:
        '''One-sided p-values testing corr > 0 (H0: r <= 0).

        For partial correlation, ``cond_size`` is the number of conditioning
        variables (df penalty); for plain Pearson, leave at 0. ``cond_size``
        may be a scalar or per-element array broadcastable against ``corr``.
        NaN correlations and entries with ``df < 2`` get p = 1.
        '''
        df = (n - 3) - np.asarray(cond_size, dtype=np.float64)
        safe_df = np.maximum(df, 1.0)
        z = np.arctanh(np.clip(corr, -0.9999, 0.9999)) * np.sqrt(safe_df)
        # Upper normal tail via erfc: 1 - erf(.) cancels to exactly 0 for
        # z >~ 8.3, which would erase the ordering of strong edges.
        p = 0.5 * scipy.special.erfc(z / math.sqrt(2.0))
        p = np.where(np.isnan(corr), 1.0, p)
        p = np.where(df < 2, 1.0, p)
        return p

    @staticmethod
    def _bh_fdr(pvals: np.ndarray) -> np.ndarray:
        '''Benjamini-Hochberg adjusted q-values, returned in input order.'''
        pvals = pvals.astype(float)
        m = len(pvals)
        if m == 0:
            return pvals
        order = np.argsort(pvals)
        ranked = pvals[order]
        bh = ranked * m / (np.arange(1, m + 1))
        # Enforce monotonicity from the largest p-value downwards.
        bh = np.minimum.accumulate(bh[::-1])[::-1]
        bh = np.clip(bh, 0.0, 1.0)
        q_values = np.empty_like(bh)
        q_values[order] = bh
        return q_values

    def evaluate(
        self,
        layer_agg: Literal['mean', 'max'] = 'mean',
        exclude_self: bool = True,
        score: Literal['corr', 'partial'] = 'corr',
        ridge: float = 1e-8,
    ) -> pd.DataFrame:
        '''
        Compute correlation scores, p-values, and q-values from accumulated stats.

        Parameters
        ----------
        layer_agg : {'mean', 'max'}
            How to aggregate adjacent-layer-pair correlations into one score matrix.
        exclude_self : bool
            If True, omit diagonal i->i pairs from the output.
        score : {'corr', 'partial'}
            ``'corr'``: Pearson correlation between activation magnitudes at
            layer n-1 and gradient magnitudes at layer n.
            ``'partial'``: partial correlation conditioning on the kept parents
            of the target in ``G_partial`` (function-function edges in
            ``data.edge_index_dict``). Removes contributions from edges the
            model already uses.
        ridge : float
            Tikhonov regularizer for the conditioning-set covariance matrix when
            ``score='partial'``. Only used in that mode.

        Returns
        -------
        pandas.DataFrame
            Columns: src_func, dst_func, src_idx, dst_idx, corr, has_edge,
            [n_cond], corr_a*_g*, p_value, q_value, sorted by ``corr``
            descending (NaN last). ``n_cond`` (size of the conditioning set for
            that target) is present only when ``score='partial'``. For kept
            edges (``has_edge=True``) the partial correlation is NaN since the
            source is in the conditioning set. When no candidate pairs exist
            (e.g. a single function node with ``exclude_self=True``) an empty
            frame with these columns is returned.

        Raises
        ------
        ValueError
            For an unknown ``score`` or ``layer_agg``.
        RuntimeError
            If fewer than two ResBlocks or fewer than three samples are available.
        '''
        if score == 'corr':
            corr_pairs = self._compute_correlations()
        elif score == 'partial':
            corr_pairs = self._compute_partial_correlations(ridge=ridge)
        else:
            raise ValueError(f"Unknown score: {score!r}. Use 'corr' or 'partial'.")
        corr = self._aggregate_pairs(corr_pairs, layer_agg)
        cond_sizes = np.asarray(
            [self.parents_by_target[j].size for j in range(self.N)],
            dtype=np.int64,
        )

        # All (src, dst) index pairs in row-major order (src outer, dst inner).
        src_idx = np.repeat(np.arange(self.N, dtype=np.int64), self.N)
        dst_idx = np.tile(np.arange(self.N, dtype=np.int64), self.N)
        if exclude_self:
            keep = src_idx != dst_idx
            src_idx, dst_idx = src_idx[keep], dst_idx[keep]

        src_names = [self.function_nodes[i] for i in src_idx]
        dst_names = [self.function_nodes[j] for j in dst_idx]
        # Built column-wise (dict insertion order fixes column order) so that
        # an empty candidate set still yields a frame with the expected columns.
        columns = {
            'src_func': src_names,
            'dst_func': dst_names,
            'src_idx': src_idx,
            'dst_idx': dst_idx,
            'corr': corr[src_idx, dst_idx],
            'has_edge': np.array(
                [(s, d) in self.edges for s, d in zip(src_names, dst_names)],
                dtype=bool,
            ),
        }
        if score == 'partial':
            columns['n_cond'] = cond_sizes[dst_idx]
        for p in range(self.n_pairs):
            columns[f'corr_a{p}_g{p + 1}'] = corr_pairs[p, src_idx, dst_idx]
        res = pd.DataFrame(columns)

        cs = cond_sizes[dst_idx] if score == 'partial' else 0
        pvals = self._fisher_pvalues(res['corr'].to_numpy(), self.n, cond_size=cs)
        res['p_value'] = pvals
        res['q_value'] = self._bh_fdr(pvals)
        res = res.sort_values('corr', ascending=False, na_position='last').reset_index(drop=True)
        return res

    @staticmethod
    def evaluate_target_ranking(
        res: pd.DataFrame,
        positive_edges: set[tuple[str, str]] | list[tuple[str, str]],
        score_col: str = 'corr',
        top_k: tuple[int, ...] = (1, 3, 5),
    ) -> tuple[pd.DataFrame, dict[str, float]]:
        '''Within-target ranking metrics for edge recovery.

        For each positive edge ``(src, dst)`` in ``positive_edges``, rank all
        candidate sources for ``dst`` by ``score_col`` (descending) and record
        the rank of the true missing parent. Rows with NaN ``score_col`` are
        not candidates; a positive edge with no scored row counts as a miss
        (reciprocal rank 0). Ties are broken by row order in ``res``.

        This answers: *for each target with a held-out edge, is the correct
        source ranked highest among competitors for that target?*

        Parameters
        ----------
        res : pandas.DataFrame
            Output of :meth:`evaluate`.
        positive_edges : set or list of (src, dst)
            Ground-truth edges to recover (typically held-out edges).
            ``src`` and ``dst`` must match ``src_func`` / ``dst_func`` in ``res``.
        score_col : str
            Column to rank on (default ``'corr'``).
        top_k : tuple of int
            Compute ``top@k`` hit rate for each ``k``.

        Returns
        -------
        detail : pandas.DataFrame
            One row per positive edge with columns ``src_func``, ``dst_func``,
            ``score``, ``rank`` (1 = best), ``n_candidates``,
            ``reciprocal_rank``, and ``top@{k}`` boolean flags.
        summary : dict
            ``n_positives``, ``mrr``, and ``top@{k}`` rates.

        Raises
        ------
        ValueError
            If ``positive_edges`` is empty.
        '''
        pos = {(s, d) for s, d in positive_edges}
        if not pos:
            raise ValueError('positive_edges is empty.')

        ranked = res.dropna(subset=[score_col]).copy()
        ranked['rank'] = (
            ranked.groupby('dst_func', sort=False)[score_col]
            .rank(ascending=False, method='first')
            .astype(int)
        )
        ranked['n_candidates'] = ranked.groupby('dst_func', sort=False)['dst_func'].transform('count')

        rows = []
        for src, dst in sorted(pos):
            sub = ranked[(ranked['src_func'] == src) & (ranked['dst_func'] == dst)]
            if sub.empty:
                rows.append({
                    'src_func': src,
                    'dst_func': dst,
                    'score': np.nan,
                    'rank': np.nan,
                    'n_candidates': int((ranked['dst_func'] == dst).sum()),
                    'reciprocal_rank': 0.0,
                    **{f'top@{k}': False for k in top_k},
                })
                continue
            row = sub.iloc[0]
            rank = int(row['rank'])
            rr = 1.0 / rank
            rows.append({
                'src_func': src,
                'dst_func': dst,
                'score': row[score_col],
                'rank': rank,
                'n_candidates': int(row['n_candidates']),
                'reciprocal_rank': rr,
                **{f'top@{k}': rank <= k for k in top_k},
            })

        detail = pd.DataFrame(rows)
        n = len(detail)
        summary: dict[str, float] = {'n_positives': float(n), 'mrr': detail['reciprocal_rank'].mean()}
        for k in top_k:
            col = f'top@{k}'
            summary[col] = detail[col].mean() if n else float('nan')
        return detail, summary