gsnn.gsnn.optim.MagnitudeEdgeRegressor

Online Tier-0 edge inference via auxiliary linear regression during GSNN training.

For each adjacent layer pair (n-1, n) and source aggregator k, fit a shared (N, N) weight matrix W so that activation magnitudes at layer n-1 predict gradient magnitudes at layer n:

Y_hat[:, j] = sum_i W[i, j] * Xtilde[:, i]
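As a concrete reading of the index formula above, the following sketch evaluates it with plain Python lists. The toy values of W and Xtilde are made up for illustration, and treating the leading axis of Xtilde as a batch of samples is an assumption rather than something this page states.

```python
def predict(x_tilde, w):
    """Y_hat[b][j] = sum_i w[i][j] * x_tilde[b][i], i.e. Y_hat = Xtilde @ W."""
    n = len(w)
    return [
        [sum(w[i][j] * row[i] for i in range(n)) for j in range(n)]
        for row in x_tilde
    ]


# Toy example with N = 3 functions and 2 samples (illustrative values only).
W = [
    [0.0, 2.0, 0.0],  # source 0 contributes to target 1
    [0.0, 0.0, 0.5],  # source 1 contributes to target 2
    [1.0, 0.0, 0.0],  # source 2 contributes to target 0
]
X_tilde = [
    [1.0, 2.0, 3.0],
    [0.0, 1.0, 0.0],
]
Y_hat = predict(X_tilde, W)
```

Under this convention, row i of W holds the outgoing weights of source function i, and column j collects the contributions to target function j.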

Magnitudes are taken from ResBlock._last_pre_norm_activation (post-lin_in, pre-norm) and the corresponding activation gradients, matching the information flow of MagnitudeEdgeInferer.

The regressor trains jointly with the GSNN (detached features, separate optimizer). Held-out validation edges drive best-checkpoint selection, mitigating gradient absorption at equilibrium.

See docs/notes/edge_inference_notes.md section 4 and tutorial 14.

Classes

MagnitudeEdgeInferer(model, data[, ...])

Post-hoc inferer for function -> function edges, based on activation/gradient magnitude correlation across adjacent layers.

MagnitudeEdgeRegressor(*args, **kwargs)

Online auxiliary linear regressor for function -> function edge inference.

class gsnn.gsnn.optim.MagnitudeEdgeRegressor.MagnitudeEdgeRegressor(*args: Any, **kwargs: Any)

Bases: Module

Online auxiliary linear regressor for function -> function edge inference.

Learns a single shared weight matrix W of shape (N, N) during GSNN training. Source activations (layer n-1) predict target gradient magnitudes (layer n) across adjacent ResBlock pairs and multiple source aggregators.

Parameters:
  • model (GSNN) – Model being trained. Must have checkpoint=False.

  • data (HeteroData-like) – Graph container with node_names_dict and edge_index_dict.

  • aggregators (sequence of str) – Source-side channel reductions: 'sum', 'max', 'mean', 'l2'. Target gradients always use L1 (sum of absolute values).

  • use_pre_norm (bool) – If True (default), use post-lin_in pre-norm activations.

  • standardize (bool) – If True (default), z-score the features using exponential moving average (EMA) statistics tracked per (layer pair, aggregator).

  • lr (float) – AdamW learning rate, applied to W only.

  • weight_decay (float) – AdamW weight decay, applied to W only.

  • ridge (float) – Additional L2 penalty on W beyond weight_decay.

  • score_mode ({'abs', 'relu', 'signed'}) – How to convert W entries into edge scores for ranking.

  • ema_momentum (float) – Momentum for running mean/variance updates during standardization.
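The aggregators and standardize options can be pictured with the minimal sketch below. It is not the library's implementation: applying the reductions to absolute channel values, the helper names aggregate and EmaZScore, and the momentum convention (weight placed on the newest batch statistic) are all assumptions made for illustration.

```python
import math


def aggregate(channels, how):
    """Reduce one function node's channel activations to a single magnitude."""
    mags = [abs(c) for c in channels]  # assumption: reductions act on |activation|
    if how == "sum":
        return sum(mags)
    if how == "max":
        return max(mags)
    if how == "mean":
        return sum(mags) / len(mags)
    if how == "l2":
        return math.sqrt(sum(m * m for m in mags))
    raise ValueError(f"unknown aggregator: {how}")


class EmaZScore:
    """Hypothetical running z-score, one instance per (layer pair, aggregator)."""

    def __init__(self, momentum=0.1, eps=1e-8):
        self.momentum, self.eps = momentum, eps
        self.mean, self.var = 0.0, 1.0

    def update(self, batch):
        # Move the running statistics toward the current batch statistics.
        m = sum(batch) / len(batch)
        v = sum((b - m) ** 2 for b in batch) / len(batch)
        self.mean += self.momentum * (m - self.mean)
        self.var += self.momentum * (v - self.var)

    def __call__(self, batch):
        scale = math.sqrt(self.var + self.eps)
        return [(b - self.mean) / scale for b in batch]


activations = [3.0, -4.0]  # two channels of one function node
features = {how: aggregate(activations, how) for how in ("sum", "max", "mean", "l2")}
```

Each aggregator yields its own feature column, which is why standardization statistics are kept separately per (pair, aggregator): the reductions live on different scales.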

arm_retained_grads() -> None

Call after model(x) and before loss.backward() so that gradients of the cached activations are retained.

aux_step() -> dict[str, float]

Build features from cached activations/grads, update W, return metrics.

Must be called after loss.backward() so activation gradients exist. Features are detached — no gradient flows into the GSNN from this step.
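The ordering constraints stated for pre_forward, arm_retained_grads and aux_step can be sketched as a single training step. The function train_step and the _Recorder stub are hypothetical; the stub only logs calls so the sketch runs without gsnn or torch. Placing optimizer.step() after aux_step() is an assumption, since this page does not specify the relative order.

```python
def train_step(model, regressor, optimizer, loss_fn, x, y):
    """One joint step; the call order follows the method descriptions."""
    optimizer.zero_grad()
    regressor.pre_forward()          # enable activation caching on the ResBlocks
    out = model(x)                   # forward pass fills the caches
    regressor.arm_retained_grads()   # after model(x), before loss.backward()
    loss = loss_fn(out, y)
    loss.backward()                  # activation gradients exist from here on
    metrics = regressor.aux_step()   # updates W only; features are detached
    optimizer.step()                 # GSNN update (placement is an assumption)
    return loss, metrics


class _Recorder:
    """Hypothetical stand-in that logs every call instead of doing real work."""

    def __init__(self, log):
        self.log = log

    def __call__(self, *args, **kwargs):
        self.log.append("forward")
        return self

    def __getattr__(self, name):
        # Any unknown attribute becomes a method that records its own name.
        def method(*args, **kwargs):
            self.log.append(name)
            return {} if name == "aux_step" else self
        return method


calls = []
stub = _Recorder(calls)
_, metrics = train_step(stub, stub, stub, lambda out, y: stub, x=None, y=None)
```

With a real GSNN, the same function would receive the model, a MagnitudeEdgeRegressor instance, the GSNN optimizer and a loss function in place of the stub.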

evaluate(*, exclude_self: bool = True) -> pandas.DataFrame

Build edge score DataFrame from current W.

Returns columns compatible with MagnitudeEdgeInferer.evaluate: src_func, dst_func, src_idx, dst_idx, score, has_edge, p_value, q_value.

evaluate_against(positive_edges: set[tuple[str, str]] | list[tuple[str, str]], *, top_k: tuple[int, ...] = (1, 3, 5)) -> dict[str, float]

Score held-out edges against non-edges using current W.

Returns global ROC-AUC plus within-target MRR and top@k rates.
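To make the within-target metrics concrete, here is a minimal sketch using the standard definitions of mean reciprocal rank and top@k. The function name, the returned dictionary keys, and ranking each held-out edge against all candidate sources of its target are assumptions; the library may filter candidates or break ties differently.

```python
def target_ranking_metrics(scores, positive_edges, top_k=(1, 3, 5)):
    """scores maps (src, dst) -> score; positive_edges holds held-out (src, dst)."""
    by_target = {}
    for (src, dst), s in scores.items():
        by_target.setdefault(dst, []).append((s, src))

    reciprocal_ranks = []
    hits = {k: 0 for k in top_k}
    for src, dst in set(positive_edges):
        if (src, dst) not in scores:
            continue  # edge has no score, so it cannot be ranked
        # Rank candidate sources for this target, highest score first.
        ranked = [name for _, name in sorted(by_target[dst], key=lambda t: -t[0])]
        rank = ranked.index(src) + 1
        reciprocal_ranks.append(1.0 / rank)
        for k in top_k:
            hits[k] += rank <= k

    n = max(len(reciprocal_ranks), 1)
    out = {"mrr": sum(reciprocal_ranks) / n}
    out.update({f"top@{k}": hits[k] / n for k in top_k})
    return out


# Illustrative scores for two targets, C and D.
scores = {
    ("A", "C"): 0.9, ("B", "C"): 0.2, ("D", "C"): 0.5,
    ("A", "D"): 0.1, ("B", "D"): 0.7, ("C", "D"): 0.3,
}
metrics = target_ranking_metrics(scores, [("A", "C"), ("C", "D")])
```

In this toy case A ranks first among the sources of C and C ranks second among the sources of D, so the MRR is the mean of 1 and 1/2.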

static evaluate_target_ranking(res: pandas.DataFrame, positive_edges: set[tuple[str, str]] | list[tuple[str, str]], score_col: str = 'score', top_k: tuple[int, ...] = (1, 3, 5)) -> tuple[pandas.DataFrame, dict[str, float]]

Delegate to MagnitudeEdgeInferer.evaluate_target_ranking.

load_best() -> None

Restore weights from the best validation checkpoint.

maybe_save_best(metric: float, mode: str = 'max') -> bool

Save state_dict if metric improves over the previous best.
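The best-checkpoint pattern behind maybe_save_best and load_best can be sketched as follows. BestCheckpoint is a hypothetical stand-in, not the library class: it takes the state explicitly, whereas the documented method snapshots the regressor's own state_dict, and the validation values are invented.

```python
import copy


class BestCheckpoint:
    """Hypothetical illustration of metric-gated checkpointing."""

    def __init__(self):
        self.best_metric = None
        self.best_state = None

    def maybe_save_best(self, state, metric, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        improved = (
            self.best_metric is None
            or (mode == "max" and metric > self.best_metric)
            or (mode == "min" and metric < self.best_metric)
        )
        if improved:
            self.best_metric = metric
            # Deep copy so later training updates do not alter the snapshot.
            self.best_state = copy.deepcopy(state)
        return improved

    def load_best(self):
        return copy.deepcopy(self.best_state)


tracker = BestCheckpoint()
state = {"W": [[0.0]]}
history = []
for step, val_auc in enumerate([0.61, 0.70, 0.66]):  # illustrative validation scores
    state["W"][0][0] = float(step)                   # pretend W changed this step
    history.append(tracker.maybe_save_best(state, val_auc))
restored = tracker.load_best()
```

The returned flag makes it easy to log when a new best was found; here the third step does not improve on the second, so the restored state is the one from step 1.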

pre_forward() -> None

Enable activation caching on all ResBlocks before model(x).

score_matrix() -> numpy.ndarray

Return (N, N) edge score matrix derived from W.
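One plausible reading of the three score_mode options is sketched below. The mapping of each mode to an elementwise transform is inferred from the option names only and is an assumption; to_scores is a hypothetical helper and the values of W are illustrative.

```python
def to_scores(w, score_mode="abs"):
    """Convert regression weights into edge scores, one entry per (src, dst)."""
    if score_mode == "abs":
        f = abs                       # magnitude regardless of sign
    elif score_mode == "relu":
        f = lambda v: max(v, 0.0)     # keep positive weights only
    elif score_mode == "signed":
        f = lambda v: v               # rank by the raw weight
    else:
        raise ValueError(f"unknown score_mode: {score_mode}")
    return [[f(v) for v in row] for row in w]


W = [[0.5, -2.0], [1.5, 0.0]]
abs_scores = to_scores(W, "abs")
relu_scores = to_scores(W, "relu")
signed_scores = to_scores(W, "signed")
```

The choice matters for ranking: under 'abs' the strongly negative entry W[0][1] becomes the top-scoring edge, while under 'relu' and 'signed' it falls to the bottom.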