'''
Post-hoc Tier-0 edge inference via node2vec on the augmented function graph.
After training a GSNN, ``MagnitudeEdgeInferer`` accumulates activation/gradient
magnitude correlations. High-scoring non-edges are mined as inferred positives
and pooled with the kept-graph edges into a single augmented graph. A single
shared embedding table is learned by skip-gram with negative sampling on
random walks, so all nodes live in the same space; kept and inferred edges
both shape neighborhood structure (walk transitions are weighted by MEI
correlation when ``walk_corr_weighted=True``).
Held-out edges are scored by ``<emb[i], emb[j]>``.
See tutorial 17 and ``docs/notes/edge_inference_notes.md`` section 4.
'''
from __future__ import annotations
import copy
import math
from collections.abc import Iterable
from typing import Literal
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from gsnn.optim.MagnitudeEdgeInferer import MagnitudeEdgeInferer
class MagnitudeEdgeKGE(nn.Module):
    '''
    Post-hoc node2vec edge inferrer for function -> function edges.

    Consumes a fitted ``MagnitudeEdgeInferer``, mines inferred positive edges
    from its correlation scores, builds an augmented directed graph from
    (kept + inferred) edges, and trains a single node embedding table by
    skip-gram with negative sampling on random walks. Parameter count is
    ``O(N * d)``. A candidate edge ``(i, j)`` is scored by the dot product
    ``<emb[i], emb[j]>`` (see ``score_matrix``).

    Parameters
    ----------
    mei : MagnitudeEdgeInferer
        Fitted inferrer with accumulated statistics (``mei.n >= 3``).
    embedding_dim : int
        Embedding dimension.
    score : {'corr', 'partial'}
        Score requested from ``mei.evaluate``; its table is used to mine
        inferred positives.
    layer_agg : {'mean', 'max'}
        MEI layer aggregation for the score matrix.
    mining_strategy : {'fdr', 'topk_per_target'}
        How to select inferred positives from the MEI score table.
    fdr_alpha : float
        Threshold on the ``q_value`` column (BH-FDR) when
        ``mining_strategy='fdr'``.
    top_k_per_target : int
        Top sources (ranked by ``corr``) per target when
        ``mining_strategy='topk_per_target'``.
    exclude_edges : iterable of (src, dst)
        Held-out val/test edges to remove from inferred positives (anti-leakage).
    walks_per_node : int
        Number of random walks starting from each node per epoch.
    walk_length : int
        Length of each random walk (number of nodes). Must be >= 1.
    window_size : int
        Skip-gram context window (max distance within a walk). Values larger
        than ``walk_length - 1`` are effectively clipped to it.
    n_negatives : int
        Negative samples per positive (center, context) pair; ``<= 0``
        disables negative sampling.
    walk_undirected : bool
        If True, treat the augmented graph as undirected for walk traversal.
        Walks rarely die at sinks, so coverage is better. Skip-gram positives
        are still emitted symmetrically. Default True.
    walk_corr_weighted : bool
        If True, transition probability ``P(j | i)`` along the walk is
        proportional to ``max(corr[i, j], 0) ** walk_alpha`` rather than
        uniform over neighbors. Brings back MEI's continuous signal that
        the binary mining step otherwise discards. Default True.
    walk_alpha : float
        Power applied to ``max(corr, 0)`` before normalizing into transition
        probabilities. ``alpha=1.0`` is linear; larger values concentrate
        walks on high-correlation edges; smaller values flatten toward
        uniform. Must be >= 0 when ``walk_corr_weighted``. Default 1.0.
    kept_edge_weight : float or None
        Lower bound on the walk weight of kept (true) function-function edges
        when ``walk_corr_weighted=True``: each kept edge gets
        ``max(max(corr, 0) ** walk_alpha, kept_edge_weight)``. If None, the
        bound defaults to the maximum inferred-edge weight (or 1.0 when no
        inferred edge has positive weight), so kept edges are at least as
        likely to be traversed as the strongest inferred edge.
    lr, weight_decay : float
        AdamW optimizer settings.

    Raises
    ------
    ValueError
        Unknown ``score`` / ``layer_agg`` / ``mining_strategy``,
        ``walk_length < 1``, or negative ``walk_alpha`` with weighted walks.
    RuntimeError
        If ``mei.n < 3``.
    '''

    _NODE_TYPE = 'function'

    def __init__(
        self,
        mei: MagnitudeEdgeInferer,
        *,
        embedding_dim: int = 64,
        score: Literal['corr', 'partial'] = 'corr',
        layer_agg: Literal['mean', 'max'] = 'max',
        mining_strategy: Literal['fdr', 'topk_per_target'] = 'fdr',
        fdr_alpha: float = 0.05,
        top_k_per_target: int = 5,
        exclude_edges: Iterable[tuple[str, str]] = (),
        walks_per_node: int = 20,
        walk_length: int = 10,
        window_size: int = 5,
        n_negatives: int = 5,
        walk_undirected: bool = True,
        walk_corr_weighted: bool = True,
        walk_alpha: float = 1.0,
        kept_edge_weight: float | None = None,
        lr: float = 1e-2,
        weight_decay: float = 1e-4,
    ):
        super().__init__()
        if score not in ('corr', 'partial'):
            raise ValueError(f"Unknown score {score!r}. Use 'corr' or 'partial'.")
        if layer_agg not in ('mean', 'max'):
            raise ValueError(f"Unknown layer_agg {layer_agg!r}. Use 'mean' or 'max'.")
        if mining_strategy not in ('fdr', 'topk_per_target'):
            raise ValueError(
                f"Unknown mining_strategy {mining_strategy!r}. "
                "Use 'fdr' or 'topk_per_target'."
            )
        # Fail early with a clear message: walk_length < 1 breaks walk
        # allocation, and a negative walk_alpha can raise ZeroDivisionError
        # (0.0 ** negative) for zero/negative correlations.
        if int(walk_length) < 1:
            raise ValueError(f'walk_length must be >= 1; got {walk_length}.')
        if walk_corr_weighted and float(walk_alpha) < 0:
            raise ValueError(f'walk_alpha must be >= 0; got {walk_alpha}.')
        if mei.n < 3:
            raise RuntimeError(
                f'MagnitudeEdgeKGE requires a fitted MEI with n >= 3; got n={mei.n}. '
                'Call mei.fit(dataloader) first.'
            )
        # Bypass nn.Module.__setattr__: if ``mei`` is itself a Module it is not
        # registered as a submodule, so its parameters stay out of
        # ``self.parameters()`` / ``self.state_dict()``.
        self.__dict__['mei'] = mei
        self.data = mei.data
        self.function_nodes = list(mei.function_nodes)
        self.N = mei.N
        self.edges = set(mei.edges)
        self.score = score
        self.layer_agg = layer_agg
        self.mining_strategy = mining_strategy
        self.fdr_alpha = float(fdr_alpha)
        self.top_k_per_target = int(top_k_per_target)
        self.exclude_edges = {(s, d) for s, d in exclude_edges}
        self.embedding_dim = int(embedding_dim)
        self.walks_per_node = int(walks_per_node)
        self.walk_length = int(walk_length)
        self.window_size = int(window_size)
        self.n_negatives = int(n_negatives)
        self.walk_undirected = bool(walk_undirected)
        self.walk_corr_weighted = bool(walk_corr_weighted)
        self.walk_alpha = float(walk_alpha)
        self.kept_edge_weight = (
            float(kept_edge_weight) if kept_edge_weight is not None else None
        )
        self.emb = nn.Embedding(self.N, self.embedding_dim)
        nn.init.normal_(self.emb.weight, mean=0.0, std=0.1)
        self.optim = torch.optim.AdamW(
            self.parameters(), lr=lr, weight_decay=weight_decay,
        )
        self._mei_df = mei.evaluate(layer_agg=layer_agg, score=score)
        self._prepare_positives()
        self._build_adjacency()
        self._best_state: dict | None = None
        self._best_metric: float | None = None

    def _prepare_positives(self) -> None:
        '''Mine inferred positives from MEI scores; materialize edge tensors.

        Sets ``_inferred_df`` (mined rows, sorted by ``corr`` descending),
        ``inferred_heads/tails`` (mined edges), ``true_heads/tails`` (kept
        graph edges in sorted order) and ``pos_heads/tails`` (kept edges
        followed by inferred edges).

        Raises
        ------
        ValueError
            If a kept edge endpoint is not in ``function_nodes``.
        '''
        df = self._mei_df.dropna(subset=['corr'])
        df = df[df['src_idx'] != df['dst_idx']]
        # NOTE(review): mining and walk weights read the ``corr`` column even
        # when ``score='partial'``; confirm ``mei.evaluate`` fills it with the
        # requested score.
        # Column-wise membership tests instead of a row-wise ``apply``.
        pairs = list(zip(df['src_func'], df['dst_func']))
        is_kept = np.fromiter(
            (p in self.edges for p in pairs), dtype=bool, count=len(pairs),
        )
        is_excluded = np.fromiter(
            (p in self.exclude_edges for p in pairs), dtype=bool, count=len(pairs),
        )
        candidates = df[~(is_kept | is_excluded)]

        if self.mining_strategy == 'fdr':
            inferred = candidates[candidates['q_value'] <= self.fdr_alpha]
        else:
            inferred = (
                candidates
                .sort_values('corr', ascending=False)
                .groupby('dst_func', sort=False, group_keys=False)
                .head(self.top_k_per_target)
            )
        inferred = inferred.sort_values('corr', ascending=False).reset_index(drop=True)
        self._inferred_df = inferred
        self.inferred_heads = torch.tensor(
            inferred['src_idx'].to_numpy(dtype=np.int64), dtype=torch.long,
        )
        self.inferred_tails = torch.tensor(
            inferred['dst_idx'].to_numpy(dtype=np.int64), dtype=torch.long,
        )

        # O(1) node -> index lookup; setdefault keeps the first occurrence,
        # matching ``list.index`` semantics.
        node_index: dict = {}
        for i, node in enumerate(self.function_nodes):
            node_index.setdefault(node, i)
        true_pairs = sorted(self.edges)
        try:
            true_heads = [node_index[s] for s, _ in true_pairs]
            true_tails = [node_index[d] for _, d in true_pairs]
        except KeyError as err:
            raise ValueError(
                f'Kept edge endpoint {err.args[0]!r} is not in function_nodes.'
            ) from err
        self.true_heads = torch.tensor(true_heads, dtype=torch.long)
        self.true_tails = torch.tensor(true_tails, dtype=torch.long)
        self.pos_heads = torch.cat([self.true_heads, self.inferred_heads])
        self.pos_tails = torch.cat([self.true_tails, self.inferred_tails])

    def _build_adjacency(self) -> None:
        '''Build per-node weighted neighbor lists for walk generation.

        With ``walk_corr_weighted``, each inferred edge weighs
        ``max(MEI corr, 0) ** walk_alpha`` (0 when the pair has no ``corr``).
        Each kept edge weighs the same quantity, floored at
        ``kept_edge_weight``. When that is None, the floor is the maximum
        inferred weight, or 1.0 if no inferred edge has positive weight.
        Without weighting, every edge weighs 1.0. Parallel edges (and, when
        undirected, both directions) accumulate. Per-node weights are
        normalized into transition probabilities; nodes with all-zero
        weights fall back to uniform.

        Sets ``_adj`` (sorted neighbor indices per node), ``_adj_p`` (matching
        probabilities) and ``_has_neighbors`` (bool mask over nodes).
        '''
        n_nodes = self.N
        inferred_pairs = list(zip(
            self.inferred_heads.tolist(), self.inferred_tails.tolist(),
        ))
        true_pairs = list(zip(
            self.true_heads.tolist(), self.true_tails.tolist(),
        ))

        if self.walk_corr_weighted:
            mei_df = self._mei_df.dropna(subset=['corr'])
            corr_lookup: dict[tuple[int, int], float] = {
                (int(i), int(j)): float(c)
                for i, j, c in zip(mei_df['src_idx'], mei_df['dst_idx'], mei_df['corr'])
            }

            def _weight(pair: tuple[int, int]) -> float:
                return float(max(corr_lookup.get(pair, 0.0), 0.0)) ** self.walk_alpha

            inferred_w = [_weight(p) for p in inferred_pairs]
            if self.kept_edge_weight is not None:
                kept_floor = self.kept_edge_weight
            else:
                # Without a usable positive inferred weight, a floor of 0 would
                # make kept edges with corr <= 0 untraversable; use 1.0 as in
                # the no-inferred-edges case.
                strongest = max(inferred_w, default=0.0)
                kept_floor = strongest if strongest > 0 else 1.0
            true_w = [max(_weight(p), kept_floor) for p in true_pairs]
        else:
            inferred_w = [1.0] * len(inferred_pairs)
            true_w = [1.0] * len(true_pairs)

        nbr_dicts: list[dict[int, float]] = [dict() for _ in range(n_nodes)]
        for (h, t), w in zip(true_pairs + inferred_pairs, true_w + inferred_w):
            nbr_dicts[h][t] = nbr_dicts[h].get(t, 0.0) + w
            if self.walk_undirected:
                nbr_dicts[t][h] = nbr_dicts[t].get(h, 0.0) + w

        self._adj: list[np.ndarray] = []
        self._adj_p: list[np.ndarray] = []
        for nbrs in nbr_dicts:
            if not nbrs:
                self._adj.append(np.empty(0, dtype=np.int64))
                self._adj_p.append(np.empty(0, dtype=np.float64))
                continue
            keys = np.asarray(sorted(nbrs), dtype=np.int64)
            weights = np.asarray([nbrs[k] for k in keys], dtype=np.float64)
            total = weights.sum()
            if total > 0:
                p = weights / total
            else:
                p = np.full(len(keys), 1.0 / len(keys), dtype=np.float64)
            self._adj.append(keys)
            self._adj_p.append(p)
        self._has_neighbors = np.asarray(
            [arr.size > 0 for arr in self._adj], dtype=bool,
        )

    def _generate_walks(self, rng: np.random.Generator) -> np.ndarray:
        '''Random walks of shape ``(walks_per_node * N, walk_length)``.

        Every node starts ``walks_per_node`` walks (start order shuffled).
        A walk that reaches a node without neighbors stops; its remaining
        positions are ``-1``.
        '''
        starts = np.tile(np.arange(self.N, dtype=np.int64), self.walks_per_node)
        rng.shuffle(starts)
        walks = np.full((starts.size, self.walk_length), -1, dtype=np.int64)
        walks[:, 0] = starts
        for step in range(1, self.walk_length):
            prev = walks[:, step - 1]
            alive = np.flatnonzero(prev >= 0)
            if alive.size == 0:
                break
            # Group live walks by current node so each node's transition
            # distribution is sampled once per step (vectorized), instead of
            # one ``rng.choice`` call per walk.
            order = alive[np.argsort(prev[alive], kind='stable')]
            nodes, first, counts = np.unique(
                prev[order], return_index=True, return_counts=True,
            )
            for node, lo, cnt in zip(nodes.tolist(), first.tolist(), counts.tolist()):
                nbrs = self._adj[node]
                if nbrs.size == 0:
                    continue  # sink: these walks stay -1 from here on
                rows = order[lo:lo + cnt]
                walks[rows, step] = rng.choice(nbrs, size=cnt, p=self._adj_p[node])
        return walks

    def _walks_to_pairs(self, walks: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        '''Convert walks to ``(center, context)`` skip-gram pairs.

        For every offset ``1..window_size`` (capped at ``walk_length - 1``),
        emits both directions of each pair of live (``>= 0``), distinct nodes
        at that distance within a walk.
        '''
        length = walks.shape[1]
        # Offsets >= length would slice mismatched widths (negative stop index
        # wraps around) and fail to broadcast.
        max_offset = min(self.window_size, length - 1)
        centers: list[np.ndarray] = []
        contexts: list[np.ndarray] = []
        for offset in range(1, max_offset + 1):
            left = walks[:, : length - offset]
            right = walks[:, offset:]
            ok = (left >= 0) & (right >= 0) & (left != right)
            fwd, bwd = left[ok], right[ok]
            centers.extend((fwd, bwd))
            contexts.extend((bwd, fwd))
        if not centers:
            return (
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.int64),
            )
        return (
            np.concatenate(centers).astype(np.int64),
            np.concatenate(contexts).astype(np.int64),
        )

    def _skipgram_loss(
        self,
        centers: torch.Tensor,
        contexts: torch.Tensor,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        '''Skip-gram with negative sampling: mean log-sigmoid loss.

        Negatives are drawn uniformly over all nodes. When ``generator`` is
        given they are drawn on the CPU from it, then moved to the
        embeddings' device. Otherwise the global torch RNG is used.
        '''
        device = centers.device
        c = self.emb(centers)
        p = self.emb(contexts)
        pos_score = (c * p).sum(dim=-1)
        pos_loss = -F.logsigmoid(pos_score)
        if self.n_negatives > 0:
            shape = (centers.size(0), self.n_negatives)
            if generator is None:
                neg = torch.randint(0, self.N, shape, device=device)
            else:
                neg = torch.randint(0, self.N, shape, generator=generator).to(device)
            n = self.emb(neg)
            neg_score = torch.einsum('bd,bkd->bk', c, n)
            neg_loss = -F.logsigmoid(-neg_score).mean(dim=-1)
        else:
            neg_loss = torch.zeros_like(pos_loss)
        return (pos_loss + neg_loss).mean()

    def fit(
        self,
        n_epochs: int = 100,
        batch_size: int = 2048,
        validation_edges: Iterable[tuple[str, str]] | None = None,
        verbose: bool = True,
        seed: int = 0,
    ) -> dict[str, list[float]]:
        '''
        Train embeddings via skip-gram with negative sampling.

        Walks are regenerated each epoch. ``seed`` drives walk generation,
        pair shuffling and negative sampling, so runs that start from the
        same embeddings with the same ``seed`` yield the same history.

        Parameters
        ----------
        n_epochs : int
            Number of passes (each over freshly generated walks).
        batch_size : int
            Number of (center, context) pairs per optimizer step.
        validation_edges : iterable of (src, dst), optional
            Held-out edges. When given, ``evaluate_against`` runs after every
            epoch and the best-AUC weights are kept via ``maybe_save_best``.
        verbose : bool
            Print progress about ten times per run.
        seed : int
            Seed for the numpy and torch generators used during training.

        Returns
        -------
        history : dict
            ``train_loss`` per epoch, plus ``val_auc`` per epoch when
            ``validation_edges`` is given.

        Raises
        ------
        RuntimeError
            If the augmented graph has no edges.
        '''
        if not self._has_neighbors.any():
            raise RuntimeError('Augmented graph has no edges; cannot train.')
        device = next(self.parameters()).device
        val_edges = list(validation_edges) if validation_edges is not None else None
        rng = np.random.default_rng(seed)
        # Dedicated CPU generator for negatives: reproducible and independent
        # of the global torch RNG state.
        neg_gen = torch.Generator().manual_seed(seed)
        history: dict[str, list[float]] = {'train_loss': []}
        if val_edges is not None:
            history['val_auc'] = []
        for epoch in range(n_epochs):
            self.train()
            walks = self._generate_walks(rng)
            centers_np, contexts_np = self._walks_to_pairs(walks)
            n_pairs = centers_np.size
            if n_pairs == 0:
                history['train_loss'].append(0.0)
                if val_edges is not None:
                    history['val_auc'].append(float('nan'))
                continue
            perm = rng.permutation(n_pairs)
            centers = torch.from_numpy(centers_np[perm]).to(device)
            contexts = torch.from_numpy(contexts_np[perm]).to(device)
            epoch_loss = 0.0
            n_steps = 0
            for start in range(0, n_pairs, batch_size):
                end = min(start + batch_size, n_pairs)
                loss = self._skipgram_loss(
                    centers[start:end], contexts[start:end], generator=neg_gen,
                )
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                epoch_loss += float(loss.item())
                n_steps += 1
            history['train_loss'].append(epoch_loss / max(n_steps, 1))
            if val_edges is not None:
                val_metrics = self.evaluate_against(val_edges)
                val_auc = val_metrics['auc']
                history['val_auc'].append(val_auc)
                self.maybe_save_best(val_auc)
            if verbose and (
                epoch == 0
                or (epoch + 1) % max(1, n_epochs // 10) == 0
                or epoch == n_epochs - 1
            ):
                msg = (
                    f'epoch {epoch + 1:4d}/{n_epochs} | pairs {n_pairs:6d} | '
                    f'loss {history["train_loss"][-1]:.4f}'
                )
                if val_edges is not None:
                    msg += f' | val AUC {history["val_auc"][-1]:.3f}'
                print(msg)
        return history

    def score_matrix(self) -> np.ndarray:
        '''Return the (N, N) float64 matrix ``emb @ emb.T`` of edge scores.

        Puts the module in eval mode as a side effect.
        '''
        self.eval()
        with torch.inference_mode():
            scores = self.emb.weight @ self.emb.weight.T
        return scores.detach().cpu().numpy().astype(np.float64)

    def evaluate(self, *, exclude_self: bool = True) -> pd.DataFrame:
        '''
        Build edge score DataFrame from node embeddings.

        Columns: ``src_func, dst_func, src_idx, dst_idx, score, has_edge``,
        one row per ordered node pair (self-pairs dropped when
        ``exclude_self``). ``has_edge`` marks kept graph edges. Rows are
        sorted by ``score`` descending.
        '''
        score_mat = self.score_matrix()
        n_nodes = len(self.function_nodes)
        # Row-major (src, dst) grid: same row order as a nested src/dst loop.
        src_idx = np.repeat(np.arange(n_nodes, dtype=np.int64), n_nodes)
        dst_idx = np.tile(np.arange(n_nodes, dtype=np.int64), n_nodes)
        if exclude_self:
            off_diag = src_idx != dst_idx
            src_idx, dst_idx = src_idx[off_diag], dst_idx[off_diag]
        src_func = [self.function_nodes[i] for i in src_idx]
        dst_func = [self.function_nodes[j] for j in dst_idx]
        has_edge = np.fromiter(
            ((s, d) in self.edges for s, d in zip(src_func, dst_func)),
            dtype=bool, count=len(src_func),
        )
        res = pd.DataFrame({
            'src_func': src_func,
            'dst_func': dst_func,
            'src_idx': src_idx,
            'dst_idx': dst_idx,
            'score': score_mat[src_idx, dst_idx],
            'has_edge': has_edge,
        })
        res = res.sort_values('score', ascending=False, na_position='last').reset_index(drop=True)
        return res

    def evaluate_against(
        self,
        positive_edges: set[tuple[str, str]] | list[tuple[str, str]],
        *,
        top_k: tuple[int, ...] = (1, 3, 5),
    ) -> dict[str, float]:
        '''ROC-AUC and within-target ranking metrics on held-out edges.

        ``auc`` compares held-out edge scores against non-edge scores; kept
        graph edges are excluded from both sides. The remaining keys come from
        ``evaluate_target_ranking``.

        Raises
        ------
        ValueError
            If ``positive_edges`` is empty.
        '''
        from sklearn.metrics import roc_auc_score
        pos = {(s, d) for s, d in positive_edges}
        if not pos:
            raise ValueError('positive_edges is empty.')
        res = self.evaluate(exclude_self=True)

        def _category(pair: tuple) -> str:
            if pair in pos:
                return 'held_out'
            if pair in self.edges:
                return 'in_graph'
            return 'non_edge'

        categories = [_category(p) for p in zip(res['src_func'], res['dst_func'])]
        res = res.assign(edge_category=categories)
        pos_scores = res.loc[res.edge_category == 'held_out', 'score'].dropna().values
        neg_scores = res.loc[res.edge_category == 'non_edge', 'score'].dropna().values
        out: dict[str, float] = {}
        if len(pos_scores) > 0 and len(neg_scores) > 0:
            y_true = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
            y_score = np.concatenate([pos_scores, neg_scores])
            out['auc'] = float(roc_auc_score(y_true, y_score))
        else:
            out['auc'] = float('nan')
        _, summary = self.evaluate_target_ranking(
            res, positive_edges=pos, score_col='score', top_k=top_k,
        )
        out.update(summary)
        return out

    @staticmethod
    def evaluate_target_ranking(
        res: pd.DataFrame,
        positive_edges: set[tuple[str, str]] | list[tuple[str, str]],
        score_col: str = 'score',
        top_k: tuple[int, ...] = (1, 3, 5),
    ) -> tuple[pd.DataFrame, dict[str, float]]:
        """Delegate to ``MagnitudeEdgeInferer.evaluate_target_ranking``."""
        return MagnitudeEdgeInferer.evaluate_target_ranking(
            res, positive_edges=positive_edges, score_col=score_col, top_k=top_k,
        )

    def maybe_save_best(self, metric: float, mode: str = 'max') -> bool:
        '''Save ``state_dict`` if ``metric`` improves over the previous best.

        ``None`` and NaN metrics (of any float-convertible type, e.g. numpy
        scalars or 0-d tensors) are ignored. Returns True when a checkpoint
        was saved.

        Raises
        ------
        ValueError
            If ``mode`` is not ``'max'`` or ``'min'``.
        '''
        if mode not in ('max', 'min'):
            raise ValueError(f"Unknown mode {mode!r}. Use 'max' or 'min'.")
        if metric is None:
            return False
        value = float(metric)
        if math.isnan(value):
            return False
        improved = (
            self._best_metric is None
            or (mode == 'max' and value > self._best_metric)
            or (mode == 'min' and value < self._best_metric)
        )
        if not improved:
            return False
        self._best_metric = value
        self._best_state = copy.deepcopy(self.state_dict())
        return True

    def load_best(self) -> None:
        '''Restore weights from the best validation checkpoint.

        Raises
        ------
        RuntimeError
            If no checkpoint has been saved yet.
        '''
        if self._best_state is None:
            raise RuntimeError('No best checkpoint saved; call maybe_save_best() first.')
        self.load_state_dict(self._best_state)