'''
Online Tier-0 edge inference via auxiliary linear regression during GSNN training.
For each adjacent layer pair (n-1, n) and source aggregator k, fit a shared
(N, N) weight matrix W so that activation magnitudes at layer n-1 predict
gradient magnitudes at layer n:
Y_hat[:, j] = sum_i W[i, j] * Xtilde[:, i]
Magnitudes are taken from ``ResBlock._last_pre_norm_activation`` (post-``lin_in``,
pre-norm) and corresponding activation gradients, matching
``MagnitudeEdgeInferer`` information flow.
The regressor trains jointly with the GSNN (detached features, separate optimizer).
Held-out validation edges drive best-checkpoint selection, mitigating gradient
absorption at equilibrium.
See ``docs/notes/edge_inference_notes.md`` section 4 and tutorial 14.
'''
from __future__ import annotations
import copy
import math
from typing import Literal, Optional, Sequence
import numpy as np
import pandas as pd
import scipy.special
import torch
import torch.nn as nn
from gsnn.optim.MagnitudeEdgeInferer import MagnitudeEdgeInferer
[docs]class MagnitudeEdgeRegressor(nn.Module):
'''
Online auxiliary linear regressor for function -> function edge inference.
Learns a single shared weight matrix ``W`` of shape ``(N, N)`` during GSNN
training. Source activations (layer n-1) predict target gradient magnitudes
(layer n) across adjacent ResBlock pairs and multiple source aggregators.
Parameters
----------
model : GSNN
Model being trained. Must have ``checkpoint=False``.
data : HeteroData-like
Graph container with ``node_names_dict`` and ``edge_index_dict``.
aggregators : sequence of str
Source-side channel reductions: ``'sum'``, ``'max'``, ``'mean'``, ``'l2'``.
Target gradients always use L1 (sum of absolute values).
use_pre_norm : bool
If True (default), use post-``lin_in`` pre-norm activations.
standardize : bool
If True (default), EMA z-score features per (pair, aggregator).
lr, weight_decay : float
AdamW hyperparameters for ``W`` only.
dropout : float
Dropout probability applied to ``W`` when forming predictions in
``aux_step`` (default 0.0).
ridge : float
Additional L2 penalty on ``W`` beyond ``weight_decay``.
score_mode : {'abs', 'relu', 'signed'}
How to convert ``W`` entries into edge scores for ranking.
ema_momentum : float
Momentum for running mean/variance updates during standardization.
'''
# Channel reducers keyed by aggregator name. Each callable collapses the last
# dimension of ``v`` into a single non-negative magnitude per leading index.
_AGG_FNS = {
# L1 magnitude: sum of absolute values.
'sum': lambda v: v.abs().sum(dim=-1),
# Largest absolute value.
'max': lambda v: v.abs().max(dim=-1).values,
# Mean absolute value.
'mean': lambda v: v.abs().mean(dim=-1),
# Euclidean (L2) norm.
'l2': lambda v: v.pow(2).sum(dim=-1).sqrt(),
}
def __init__(
self,
model,
data,
*,
aggregators: Sequence[str] = ('sum', 'max'),
use_pre_norm: bool = True,
standardize: bool = True,
lr: float = 1e-2,
weight_decay: float = 1e-4,
dropout: float = 0.0,
ridge: float = 0.0,
score_mode: Literal['abs', 'relu', 'signed'] = 'abs',
ema_momentum: float = 0.1,
):
super().__init__()
if getattr(model, 'checkpoint', False):
raise ValueError(
'MagnitudeEdgeRegressor requires model.checkpoint=False; '
'gradient checkpointing recomputes activations during backward.'
)
if not use_pre_norm:
norm = getattr(model, 'norm', None)
if norm not in ('batch', 'groupbatch', 'none'):
raise ValueError(
'MagnitudeEdgeRegressor with use_pre_norm=False requires '
'model.norm to be "batch", "groupbatch", or "none"; '
f'got {norm!r}. Use use_pre_norm=True (default) instead.'
)
for agg in aggregators:
if agg not in self._AGG_FNS:
raise ValueError(
f'Unknown aggregator {agg!r}. Use one of {tuple(self._AGG_FNS)}.'
)
self.__dict__['model'] = model
self.data = data
self.aggregators = tuple(aggregators)
self.use_pre_norm = use_pre_norm
self.standardize = standardize
self.ridge = float(ridge)
self.score_mode = score_mode
self.ema_momentum = float(ema_momentum)
self.dropout = torch.nn.Dropout(dropout)
self.function_nodes = list(data.node_names_dict['function'])
self.N = len(self.function_nodes)
self.L = len(model.ResBlocks)
self.n_pairs = max(self.L - 1, 0)
self.n_aggs = len(self.aggregators)
block0 = model.ResBlocks[0]
self.channel_groups = block0.channel_groups.detach().cpu()
function_mask = ~(model.input_node_mask | model.output_node_mask)
self.function_homo_idxs = function_mask.nonzero(as_tuple=True)[0].cpu().tolist()
if len(self.function_homo_idxs) != self.N:
raise ValueError(
f'Expected {self.N} function nodes in homogeneous graph, '
f'got {len(self.function_homo_idxs)}.'
)
self._channel_ixs_by_fn = []
for homo_idx in self.function_homo_idxs:
ixs = (self.channel_groups == homo_idx).nonzero(as_tuple=True)[0]
if ixs.numel() == 0:
raise ValueError(
f'Function node homo index {homo_idx} has no channels in channel_groups.'
)
self._channel_ixs_by_fn.append(ixs)
edge_index = data.edge_index_dict['function', 'to', 'function']
if hasattr(edge_index, 'detach'):
edge_arr = edge_index.T.detach().cpu().numpy()
else:
edge_arr = np.asarray(edge_index.T)
self.edges = {
(self.function_nodes[i], self.function_nodes[j])
for i, j in edge_arr
}
# Shared (N, N) weight matrix; columns predict each target gradient magnitude.
self.W = nn.Parameter(torch.zeros(self.N, self.N))
nn.init.xavier_uniform_(self.W, gain=0.1)
self.optim = torch.optim.AdamW(
[self.W], lr=lr, weight_decay=weight_decay,
)
P, K, N = self.n_pairs, self.n_aggs, self.N
self.register_buffer('running_mean_x', torch.zeros(P, K, N))
self.register_buffer('running_var_x', torch.ones(P, K, N))
self.register_buffer('running_mean_y', torch.zeros(P, N))
self.register_buffer('running_var_y', torch.ones(P, N))
self.register_buffer('n_batches_tracked', torch.tensor(0, dtype=torch.long))
# Streaming Gram for p-value computation in evaluate().
self.register_buffer('gram', torch.zeros(N, N))
self.register_buffer('n_samples', torch.tensor(0, dtype=torch.long))
self.register_buffer('residual_sum_sq', torch.tensor(0.0))
self._act_attr = '_last_pre_norm_activation' if use_pre_norm else '_last_activation'
self._acts: list[torch.Tensor] | None = None
self._best_state: dict | None = None
self._best_metric: float | None = None
[docs] def pre_forward(self) -> None:
'''Enable activation caching on all ResBlocks before ``model(x)``.'''
for mod in self.model.ResBlocks:
mod._store_activations = True
self._acts = None
[docs] def arm_retained_grads(self) -> None:
'''Call after ``model(x)`` and before ``loss.backward()`` to retain grads.'''
acts = []
for mod in self.model.ResBlocks:
act = getattr(mod, self._act_attr, None)
if act is None:
raise RuntimeError(
f'ResBlock.{self._act_attr} is None; call pre_forward() first.'
)
act.retain_grad()
acts.append(act)
self._acts = acts
def _cleanup_hooks(self) -> None:
for mod in self.model.ResBlocks:
mod._store_activations = False
for attr in ('_last_activation', '_last_pre_norm_activation'):
if hasattr(mod, attr):
delattr(mod, attr)
self._acts = None
def _reduce_channels(
self,
tensor: torch.Tensor,
aggregator: str,
) -> torch.Tensor:
'''Reduce (B, C_total) to (B, N) function-node magnitudes.'''
if tensor.dim() == 3 and tensor.size(-1) == 1:
tensor = tensor.squeeze(-1)
if tensor.dim() != 2:
raise ValueError(
f'Expected activation/grad shape (B, C) or (B, C, 1), got {tuple(tensor.shape)}'
)
B = tensor.size(0)
fn = self._AGG_FNS[aggregator]
out = torch.empty(B, self.N, device=tensor.device, dtype=tensor.dtype)
for fn_idx, ch_ixs in enumerate(self._channel_ixs_by_fn):
ch_ixs = ch_ixs.to(tensor.device)
out[:, fn_idx] = fn(tensor[:, ch_ixs])
return out
def _standardize(
self,
x: torch.Tensor,
y: torch.Tensor,
pair_idx: int,
agg_idx: int,
training: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.standardize:
return x, y
eps = 1e-8
if training:
batch_mean_x = x.mean(dim=0)
batch_var_x = x.var(dim=0, unbiased=False)
batch_mean_y = y.mean(dim=0)
batch_var_y = y.var(dim=0, unbiased=False)
m = self.ema_momentum
self.running_mean_x[pair_idx, agg_idx].lerp_(batch_mean_x, m)
self.running_var_x[pair_idx, agg_idx].lerp_(batch_var_x, m)
self.running_mean_y[pair_idx].lerp_(batch_mean_y, m)
self.running_var_y[pair_idx].lerp_(batch_var_y, m)
self.n_batches_tracked += 1
mean_x, var_x = batch_mean_x, batch_var_x
mean_y, var_y = batch_mean_y, batch_var_y
else:
mean_x = self.running_mean_x[pair_idx, agg_idx]
var_x = self.running_var_x[pair_idx, agg_idx]
mean_y = self.running_mean_y[pair_idx]
var_y = self.running_var_y[pair_idx]
x_std = (x - mean_x) / torch.sqrt(var_x + eps)
y_std = (y - mean_y) / torch.sqrt(var_y + eps)
return x_std, y_std
def _update_gram(self, x_std: torch.Tensor, y_hat: torch.Tensor, y_std: torch.Tensor) -> None:
'''Accumulate streaming Gram matrix and residual variance for p-values.'''
B = x_std.size(0)
with torch.no_grad():
self.gram.add_(x_std.T @ x_std)
self.n_samples += B
self.residual_sum_sq += ((y_hat - y_std) ** 2).sum().item()
[docs] def aux_step(self) -> dict[str, float]:
'''
Build features from cached activations/grads, update ``W``, return metrics.
Must be called after ``loss.backward()`` so activation gradients exist.
Features are detached — no gradient flows into the GSNN from this step.
'''
if self._acts is None:
raise RuntimeError(
'No cached activations; call pre_forward(), model(x), '
'arm_retained_grads(), then loss.backward() before aux_step().'
)
try:
grads = []
for act in self._acts:
if act.grad is None:
raise RuntimeError(
'Activation gradient is None after backward; '
'call arm_retained_grads() before loss.backward().'
)
grads.append(act.grad)
y_layers = [self._reduce_channels(g.detach(), 'sum') for g in grads]
total_loss = torch.tensor(0.0, device=self.W.device)
n_terms = 0
gram_updates: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
for p in range(self.n_pairs):
y = y_layers[p + 1]
for k, agg in enumerate(self.aggregators):
x = self._reduce_channels(self._acts[p].detach(), agg)
x_std, y_std = self._standardize(
x, y, pair_idx=p, agg_idx=k, training=self.training,
)
y_hat = x_std @ self.dropout(self.W)
mse = ((y_hat - y_std) ** 2).mean()
total_loss = total_loss + mse
n_terms += 1
gram_updates.append((x_std.detach(), y_hat.detach(), y_std.detach()))
if self.ridge > 0:
total_loss = total_loss + self.ridge * (self.W ** 2).sum()
self.optim.zero_grad()
total_loss.backward()
self.optim.step()
with torch.no_grad():
for x_std, y_hat, y_std in gram_updates:
self._update_gram(x_std, y_hat, y_std)
avg_loss = (total_loss / max(n_terms, 1)).item()
return {'aux_loss': avg_loss, 'n_terms': float(n_terms)}
finally:
self._cleanup_hooks()
[docs] def score_matrix(self) -> np.ndarray:
"""Return (N, N) edge score matrix derived from ``W``."""
w = self.W.detach().cpu().numpy()
if self.score_mode == 'abs':
return np.abs(w)
if self.score_mode == 'relu':
return np.maximum(w, 0.0)
if self.score_mode == 'signed':
return w
raise ValueError(f'Unknown score_mode: {self.score_mode!r}')
@staticmethod
def _bh_fdr(pvals: np.ndarray) -> np.ndarray:
pvals = pvals.astype(float)
m = len(pvals)
order = np.argsort(pvals)
ranked = pvals[order]
bh = ranked * m / (np.arange(1, m + 1))
bh = np.minimum.accumulate(bh[::-1])[::-1]
bh = np.clip(bh, 0.0, 1.0)
q_values = np.empty_like(bh)
q_values[order] = bh
return q_values
def _coefficient_pvalues(self) -> np.ndarray:
'''One-sided p-values testing W_ij > 0 via Gram-based SE approximation.'''
N = self.N
pvals = np.ones((N, N), dtype=np.float64)
n = int(self.n_samples.item())
if n < N + 2:
return pvals
w = self.W.detach().cpu().numpy()
gram = self.gram.detach().cpu().numpy() / max(n, 1)
ridge = self.ridge + 1e-8
try:
gram_inv = np.linalg.inv(gram + ridge * np.eye(N))
except np.linalg.LinAlgError:
return pvals
rss = float(self.residual_sum_sq.item())
sigma2 = rss / max(n * self.n_pairs * self.n_aggs * N - N, 1)
for i in range(N):
se_i = math.sqrt(max(sigma2 * gram_inv[i, i], 1e-12))
z = w[i, :] / se_i
p = 0.5 * (1.0 - scipy.special.erf(z / math.sqrt(2.0)))
pvals[i, :] = np.clip(p, 0.0, 1.0)
return pvals
[docs] def evaluate(self, *, exclude_self: bool = True) -> pd.DataFrame:
'''
Build edge score DataFrame from current ``W``.
Returns columns compatible with ``MagnitudeEdgeInferer.evaluate``:
``src_func, dst_func, src_idx, dst_idx, score, has_edge, p_value, q_value``.
'''
scores = self.score_matrix()
pvals = self._coefficient_pvalues()
rows = []
for i, src in enumerate(self.function_nodes):
for j, dst in enumerate(self.function_nodes):
if exclude_self and i == j:
continue
rows.append({
'src_func': src,
'dst_func': dst,
'src_idx': i,
'dst_idx': j,
'score': scores[i, j],
'has_edge': (src, dst) in self.edges,
'p_value': pvals[i, j],
})
res = pd.DataFrame(rows)
res['q_value'] = self._bh_fdr(res['p_value'].to_numpy())
res = res.sort_values('score', ascending=False, na_position='last').reset_index(drop=True)
return res
[docs] def evaluate_against(
self,
positive_edges: set[tuple[str, str]] | list[tuple[str, str]],
*,
top_k: tuple[int, ...] = (1, 3, 5),
) -> dict[str, float]:
'''
Score held-out edges against non-edges using current ``W``.
Returns global ROC-AUC plus within-target MRR and top@k rates.
'''
from sklearn.metrics import roc_auc_score
pos = {(s, d) for s, d in positive_edges}
if not pos:
raise ValueError('positive_edges is empty.')
res = self.evaluate(exclude_self=True)
pos = {(s, d) for s, d in positive_edges}
def _category(row):
pair = (row['src_func'], row['dst_func'])
if pair in pos:
return 'held_out'
if pair in self.edges:
return 'in_graph'
return 'non_edge'
res = res.assign(edge_category=res.apply(_category, axis=1))
pos_scores = res.loc[res.edge_category == 'held_out', 'score'].dropna().values
neg_scores = res.loc[res.edge_category == 'non_edge', 'score'].dropna().values
out: dict[str, float] = {}
if len(pos_scores) > 0 and len(neg_scores) > 0:
y_true = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
y_score = np.concatenate([pos_scores, neg_scores])
out['auc'] = float(roc_auc_score(y_true, y_score))
else:
out['auc'] = float('nan')
_, summary = self.evaluate_target_ranking(
res, positive_edges=pos, score_col='score', top_k=top_k,
)
out.update(summary)
return out
@staticmethod
def evaluate_target_ranking(
    res: pd.DataFrame,
    positive_edges: set[tuple[str, str]] | list[tuple[str, str]],
    score_col: str = 'score',
    top_k: tuple[int, ...] = (1, 3, 5),
) -> tuple[pd.DataFrame, dict[str, float]]:
    '''
    Forward to ``MagnitudeEdgeInferer.evaluate_target_ranking``.

    All arguments are passed through unchanged and the delegate's result is
    returned as-is.
    '''
    ranking = MagnitudeEdgeInferer.evaluate_target_ranking(
        res,
        positive_edges=positive_edges,
        score_col=score_col,
        top_k=top_k,
    )
    return ranking
[docs] def maybe_save_best(self, metric: float, mode: str = 'max') -> bool:
'''Save ``state_dict`` if ``metric`` improves over the previous best.'''
if metric is None or (isinstance(metric, float) and math.isnan(metric)):
return False
improved = (
self._best_metric is None
or (mode == 'max' and metric > self._best_metric)
or (mode == 'min' and metric < self._best_metric)
)
if improved:
self._best_metric = float(metric)
self._best_state = copy.deepcopy(self.state_dict())
return True
return False
[docs] def load_best(self) -> None:
'''Restore weights from the best validation checkpoint.'''
if self._best_state is None:
raise RuntimeError('No best checkpoint saved; call maybe_save_best() first.')
self.load_state_dict(self._best_state)