Source code for gsnn.optim.OutputEdgeInferer

'''
Lightweight optimizer to infer output edges from intermediate GSNN node activations.

This module estimates a per-function-node linear mapping from channel activations to
each output node using a simple batched regression. The learned weights can be
interpreted as evidence for candidate edges from function nodes to output nodes.

Assumptions:
- The GSNN `model` exposes `get_node_activations(x, agg=...)` returning a dict
  mapping function node names to tensors of shape (B, C), where B is batch size
  and C is the channel dimension for that node.
- The target `y` has shape (B, O), where O is the number of output nodes.

Notes:
- Provides `fit(dataloader, model, epochs=...)` for training.
'''

import torch
import numpy as np
import pandas as pd
from sklearn.metrics import explained_variance_score
import math
import scipy.special


def _safe_corrcoef(a, b):
            if a.size == 0 or b.size == 0:
                return np.nan
            if np.std(a) == 0 or np.std(b) == 0:
                return np.nan
            return float(np.corrcoef(a, b)[0, 1])

def _safe_evs(y_true, y_pred):
    try:
        if y_true.size == 0 or y_pred.size == 0:
            return np.nan
        # If y_true is constant, EVS is undefined; return 0 to be conservative
        if np.allclose(np.std(y_true), 0.0):
            return 0.0
        return float(explained_variance_score(y_true, y_pred))
    except Exception:
        return np.nan


[docs]class OutputEdgeInferer(torch.nn.Module): ''' Learns per-function-node linear mappings from channel activations to outputs. Each function node i has weights `W[i]` with shape (C, O), producing per-node predictions that can be compared to ground truth outputs to score candidate edges. ''' def __init__(self, data, channels, lr=1e-2, wd=1e-2, epochs=100, agg='last', use_batchnorm=False, bn_affine=False, tol=1e-6, patience=10): ''' Initialize the edge inferrer. Args: data: Dataset/graph container exposing `node_names_dict` and `edge_index_dict`. channels: Channel dimension C for function node activations. lr: Learning rate for Adam optimizer. wd: Weight decay (L2) for Adam optimizer. epochs: Number of epochs to fit over the provided dataloader. agg: Aggregation key passed to `model.get_node_activations`. use_batchnorm: If True, apply vectorized per-node, per-channel normalization with running mean/variance (BatchNorm-like behavior) using a single fused op. bn_affine: If True, learn a per-node, per-channel scale/shift. tol: Minimum improvement in epoch loss to reset patience (early stopping). patience: Number of epochs without sufficient improvement before stopping. ''' super().__init__() self.data = data self.agg = agg self.epochs = epochs self.use_batchnorm = use_batchnorm self.bn_affine = bn_affine self.tol = float(tol) self.patience = int(patience) num_function_nodes = len(data.node_names_dict['function']) N = num_function_nodes # number of function nodes O = len(data.node_names_dict['output']) C = channels # Each function node i has a weight matrix W[i] of shape (C, O) self.W = torch.nn.Parameter(torch.empty(N, C, O)) for i in range(N): torch.nn.init.xavier_uniform_(self.W[i], gain=1.0) # Optional per-node normalization across channels with running stats (vectorized) if self.use_batchnorm: self.register_buffer('running_mean', torch.zeros(C, N)) # (C, N) self.register_buffer('running_var', torch.ones(C, N)) # (C, N) self.register_buffer('bn_num_batches_tracked', torch.tensor(0, dtype=torch.long)) self.bn_momentum = 0.1 self.bn_eps = 1e-5 if self.bn_affine: self.bn_gamma = torch.nn.Parameter(torch.ones(C, N)) self.bn_beta = torch.nn.Parameter(torch.zeros(C, N)) else: self.register_buffer('bn_gamma', torch.ones(C, N)) self.register_buffer('bn_beta', torch.zeros(C, N)) else: self.running_mean = None self.running_var = None # Optimizer after all parameters are registered (including BN if affine) self.optim = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=wd) # Build existing edge set for reference; handle torch or numpy index types edge_index = data.edge_index_dict['function', 'to', 'output'] if hasattr(edge_index, 'detach'): edge_arr = edge_index.T.detach().cpu().numpy() else: edge_arr = np.asarray(edge_index.T) self.edges = set( [ (data.node_names_dict['function'][i], data.node_names_dict['output'][j]) for (i, j) in edge_arr ] )
[docs] def fit(self, dataloader, model, epochs=None, device='cpu', verbose=True): ''' Fit the per-node linear mappings using batches from a dataloader. Args: dataloader: Iterable yielding tuples (x, y) with shapes x=?, y=(B, O). model: GSNN model exposing `get_node_activations(x, agg=...)`. epochs: Optional override for number of epochs. Defaults to `self.epochs`. Returns: List of average epoch losses. ''' if epochs is None: epochs = self.epochs loss_history = [] best_loss = float('inf') epochs_without_improve = 0 # Ensure BN runs in training mode to update running stats self.train() self = self.to(device) if verbose: print(f"Fitting OutputEdgeInferer on {device}...") if verbose: print('# parameters: ', sum(p.numel() for p in self.parameters())) for _ in range(epochs): epoch_losses = [] for i, (x, y) in enumerate(dataloader): x = x.to(device) y = y.to(device) with torch.no_grad(): a_dict = model.get_node_activations(x, agg=self.agg) # Stack function node activations: list of (B, C) -> (B, C, N) a = torch.stack( [a_dict[node] for node in self.data.node_names_dict['function']], dim=-1, ).to(device) if self.use_batchnorm: a = self._normalize(a, training=True) # Forward: (B, C, N) -> (N, B, O) yhat = self.forward(a) # Expand targets to (N, B, O) to match per-node predictions y_expanded = y.unsqueeze(0).expand_as(yhat) # Mean over batch, sum over nodes and outputs mse = torch.mean((yhat - y_expanded) ** 2, dim=1).mean() self.optim.zero_grad() mse.backward() self.optim.step() epoch_losses.append(mse.detach().item()) if verbose: print(f"[batch {i}/{len(dataloader)} loss: {mse.detach().item()}]", end='\r') epoch_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0 loss_history.append(epoch_loss) # Early stopping using training loss if best_loss - epoch_loss > self.tol: best_loss = epoch_loss epochs_without_improve = 0 else: epochs_without_improve += 1 if epochs_without_improve >= self.patience: break if verbose: print(f'epoch {_} loss: {epoch_loss}') return loss_history
[docs] def forward(self, a): ''' Compute per-function-node linear maps to outputs. Args: a: Activation tensor of shape (B, C, N), where: - B: batch size - C: channels - N: number of function nodes Returns: Tensor of shape (N, B, O): per-node predictions for each output. ''' # Ensure activations are on same device as parameters a = a.to(self.W.device) # (B, C, N) -> (N, B, C) a = a.permute(2, 0, 1) # Batched matmul over nodes: (N, B, C) @ (N, C, O) -> (N, B, O) out = torch.bmm(a, self.W) return out
[docs] def evaluate(self, dataloader, model, device='cpu', verbose=True): ''' Evaluate per-node predictive power across a full dataset using streaming statistics. Args: dataloader: Iterable yielding tuples (x, y) with shapes x=?, y=(B, O). model: GSNN model exposing `get_node_activations(x, agg=...)`. Returns: pandas.DataFrame with columns: - func_node, output_node, mse, r2, r, has_edge - model_mse, model_r2, model_r - r2_gain, r_gain, mse_gain - p_value: one-sided p-value testing improvement (r2_gain > 0), via paired mean-squared-error test with normal approximation over samples. - q_value: Benjamini-Hochberg FDR-adjusted p-value. - snr: Signal-to-Noise Ratio (Var(predictions) / MSE). Higher values indicate stronger signal from function node to output. - l1_norm: L1 norm of weights (sparsity-promoting). Lower values = sparser model. - l2_norm: L2 norm of weights (regularization). Lower values = smaller weights. - sparsity: Fraction of weights close to zero. Higher values = sparser model. - eff_rank: Effective rank measure. Lower values = simpler model. p-value meaning: - Null hypothesis: the edge-specific predictor does not reduce expected MSE vs the baseline model for this output (i.e., r2_gain <= 0). - Alternative: the edge-specific predictor reduces expected MSE (r2_gain > 0). - We compute per-sample squared-error differences and apply a one-sided normal approximation to the mean difference. This is tractable and aligns with r2_gain since r2_gain = (mse_baseline - mse_node) / Var(y). FDR: We report q-values (BH-adjusted p-values) over all (func, output) pairs. ''' # Use running stats for normalization during evaluation self.eval() self = self.to(device) if verbose: print(f"Evaluating OutputEdgeInferer on {device}...") function_nodes = self.data.node_names_dict['function'] output_nodes = self.data.node_names_dict['output'] N = len(function_nodes) O = len(output_nodes) # Initialize vectorized streaming statistics # Shape: (N, O) for node pairs, (O,) for model stats node_n = np.zeros((N, O), dtype=np.int64) node_sum_x = np.zeros((N, O), dtype=np.float64) node_sum_y = np.zeros((N, O), dtype=np.float64) node_sum_x2 = np.zeros((N, O), dtype=np.float64) node_sum_y2 = np.zeros((N, O), dtype=np.float64) node_sum_xy = np.zeros((N, O), dtype=np.float64) node_sum_se = np.zeros((N, O), dtype=np.float64) node_sum_diff = np.zeros((N, O), dtype=np.float64) node_sum_diff2 = np.zeros((N, O), dtype=np.float64) # Model statistics (O,) model_n = np.zeros(O, dtype=np.int64) model_sum_x = np.zeros(O, dtype=np.float64) model_sum_y = np.zeros(O, dtype=np.float64) model_sum_x2 = np.zeros(O, dtype=np.float64) model_sum_y2 = np.zeros(O, dtype=np.float64) model_sum_xy = np.zeros(O, dtype=np.float64) model_sum_se = np.zeros(O, dtype=np.float64) for bi, (x, y) in enumerate(dataloader): x = x.to(device) y = y.to(device) with torch.inference_mode(): a_dict = model.get_node_activations(x, agg=self.agg) a = torch.stack([a_dict[node] for node in function_nodes], dim=-1).to(device) if self.use_batchnorm: a = self._normalize(a, training=False) yhat_nodes = self.forward(a) # (N, B, O) yhat_model = model(x) # (B, O) y_np = y.detach().cpu().numpy() # (B, O) yhat_model_np = yhat_model.detach().cpu().numpy() # (B, O) yhat_nodes_np = yhat_nodes.detach().cpu().numpy() # (N, B, O) B = y_np.shape[0] # Vectorized model statistics update model_n += B model_sum_x += np.sum(yhat_model_np, axis=0) # (O,) model_sum_y += np.sum(y_np, axis=0) # (O,) model_sum_x2 += np.sum(yhat_model_np ** 2, axis=0) # (O,) model_sum_y2 += np.sum(y_np ** 2, axis=0) # (O,) model_sum_xy += np.sum(yhat_model_np * y_np, axis=0) # (O,) model_sum_se += np.sum((yhat_model_np - y_np) ** 2, axis=0) # (O,) # Vectorized node statistics update using broadcasting # Reshape for broadcasting: y_np (B, O) -> (1, B, O), yhat_model_np (B, O) -> (1, B, O) y_broadcast = y_np[np.newaxis, :, :] # (1, B, O) yhat_model_broadcast = yhat_model_np[np.newaxis, :, :] # (1, B, O) # yhat_nodes_np is already (N, B, O) # Compute all statistics using broadcasting node_n += B # (N, O) += scalar node_sum_x += np.sum(yhat_nodes_np, axis=1) # (N, O) node_sum_y += np.sum(y_broadcast, axis=1) # (N, O) node_sum_x2 += np.sum(yhat_nodes_np ** 2, axis=1) # (N, O) node_sum_y2 += np.sum(y_broadcast ** 2, axis=1) # (N, O) node_sum_xy += np.sum(yhat_nodes_np * y_broadcast, axis=1) # (N, O) node_sum_se += np.sum((yhat_nodes_np - y_broadcast) ** 2, axis=1) # (N, O) # Compute paired differences for p-value: (node_pred - true)² - (model_pred - true)² node_errors = (yhat_nodes_np - y_broadcast) ** 2 # (N, B, O) model_errors = (yhat_model_broadcast - y_broadcast) ** 2 # (1, B, O) diff = node_errors - model_errors # (N, B, O) node_sum_diff += np.sum(diff, axis=1) # (N, O) node_sum_diff2 += np.sum(diff ** 2, axis=1) # (N, O) if verbose: print(f"[batch {bi}/{len(dataloader)}]", end='\r') # Compute final metrics from accumulated statistics # Vectorized model metrics computation model_mse = np.where(model_n > 0, model_sum_se / model_n, np.nan) # Vectorized correlation and explained variance for model model_mean_x = np.where(model_n > 0, model_sum_x / model_n, 0) model_mean_y = np.where(model_n > 0, model_sum_y / model_n, 0) model_var_x = np.where(model_n > 0, (model_sum_x2 / model_n) - model_mean_x ** 2, 0) model_var_y = np.where(model_n > 0, (model_sum_y2 / model_n) - model_mean_y ** 2, 0) model_cov_xy = np.where(model_n > 0, (model_sum_xy / model_n) - model_mean_x * model_mean_y, 0) # Correlation coefficient model_r = np.where( (model_n > 1) & (model_var_x > 0) & (model_var_y > 0), model_cov_xy / np.sqrt(model_var_x * model_var_y), np.where(model_n > 0, 0, np.nan) ) # Explained variance score model_r2 = np.where( (model_n > 1) & (model_var_y > 0), np.maximum(0.0, 1.0 - model_mse / model_var_y), np.where(model_n > 0, 0, np.nan) ) # Build model dataframe rdf = pd.DataFrame({ "output_node": output_nodes, "model_r2": model_r2, "model_r": model_r, "model_mse": model_mse }) # Vectorized node metrics computation node_mse = np.where(node_n > 0, node_sum_se / node_n, np.nan) # Vectorized correlation and explained variance for nodes node_mean_x = np.where(node_n > 0, node_sum_x / node_n, 0) node_mean_y = np.where(node_n > 0, node_sum_y / node_n, 0) node_var_x = np.where(node_n > 0, (node_sum_x2 / node_n) - node_mean_x ** 2, 0) node_var_y = np.where(node_n > 0, (node_sum_y2 / node_n) - node_mean_y ** 2, 0) node_cov_xy = np.where(node_n > 0, (node_sum_xy / node_n) - node_mean_x * node_mean_y, 0) # Correlation coefficient node_r = np.where( (node_n > 1) & (node_var_x > 0) & (node_var_y > 0), node_cov_xy / np.sqrt(node_var_x * node_var_y), np.where(node_n > 0, 0, np.nan) ) # Explained variance score node_r2 = np.where( (node_n > 1) & (node_var_y > 0), np.maximum(0.0, 1.0 - node_mse / node_var_y), np.where(node_n > 0, 0, np.nan) ) # Signal-to-Noise Ratio (SNR) computation for model selection # SNR = Var(predicted_output) / Var(residuals) = Var(predicted) / MSE # Higher SNR indicates stronger signal from function to output node_snr = np.where( (node_n > 1) & (node_mse > 0) & (node_var_x > 0), node_var_x / node_mse, 0.0 ) # Model complexity metrics using trained weights # Get weights on CPU for computation: self.W shape is (N, C, O) W_cpu = self.W.detach().cpu().numpy() # (N, C, O) # L1 norm (sparsity-promoting): sum of absolute weights per (function, output) node_l1_norm = np.sum(np.abs(W_cpu), axis=1) # (N, O) # L2 norm (weight magnitude): Euclidean norm per (function, output) node_l2_norm = np.sqrt(np.sum(W_cpu ** 2, axis=1)) # (N, O) # Weight sparsity: fraction of weights close to zero (< 1e-6) sparsity_threshold = 1e-6 node_sparsity = np.mean(np.abs(W_cpu) < sparsity_threshold, axis=1) # (N, O) # Effective rank: number of significant singular values (> 1% of max) node_eff_rank = np.zeros((N, O)) for i in range(N): for j in range(O): w_vec = W_cpu[i, :, j] # (C,) if np.any(np.abs(w_vec) > 1e-12): # avoid zero vectors # For 1D weight vector, effective rank is just whether it's non-zero node_eff_rank[i, j] = 1.0 if np.std(w_vec) > 1e-6 else 0.0 else: node_eff_rank[i, j] = 0.0 # Vectorized p-value computation node_d_mean = np.where(node_n > 0, node_sum_diff / node_n, 0) node_d_var = np.where(node_n > 0, (node_sum_diff2 / node_n) - node_d_mean ** 2, 0) node_d_std = np.sqrt(np.maximum(0, node_d_var)) # Compute z-scores node_z = np.where( (node_n >= 5) & (node_d_std > 0), node_d_mean / (node_d_std / np.sqrt(node_n)), 0 ) # One-sided normal CDF for alternative mean<0 # Using vectorized error function node_pval = np.where( (node_n >= 5) & (node_d_std > 0), 0.5 * (1.0 + scipy.special.erf(node_z / np.sqrt(2.0))), 1.0 ) node_pval = np.where(node_n > 0, node_pval, 1.0) # Build node dataframe using vectorized operations func_nodes_flat = [] output_nodes_flat = [] mse_flat = [] r2_flat = [] r_flat = [] has_edge_flat = [] pval_flat = [] snr_flat = [] l1_norm_flat = [] l2_norm_flat = [] sparsity_flat = [] eff_rank_flat = [] for i, fi in enumerate(function_nodes): for j, oj in enumerate(output_nodes): func_nodes_flat.append(fi) output_nodes_flat.append(oj) mse_flat.append(node_mse[i, j]) r2_flat.append(node_r2[i, j]) r_flat.append(node_r[i, j]) has_edge_flat.append((fi, oj) in self.edges) pval_flat.append(node_pval[i, j]) snr_flat.append(node_snr[i, j]) l1_norm_flat.append(node_l1_norm[i, j]) l2_norm_flat.append(node_l2_norm[i, j]) sparsity_flat.append(node_sparsity[i, j]) eff_rank_flat.append(node_eff_rank[i, j]) res = pd.DataFrame({ "func_node": func_nodes_flat, "output_node": output_nodes_flat, "mse": mse_flat, "r2": r2_flat, "r": r_flat, "has_edge": has_edge_flat, "p_value": pval_flat, "snr": snr_flat, "l1_norm": l1_norm_flat, "l2_norm": l2_norm_flat, "sparsity": sparsity_flat, "eff_rank": eff_rank_flat }) res = res.merge(rdf, on='output_node', how='left') res = res.assign( r2_gain=lambda x: x.r2 - x.model_r2, r_gain=lambda x: x.r - x.model_r, mse_gain=lambda x: x.mse - x.model_mse, ) # Benjamini–Hochberg q-values across all pairs (monotone BH) pvals = res["p_value"].values.astype(float) m = len(pvals) order = np.argsort(pvals) ranked = pvals[order] bh = ranked * m / (np.arange(1, m + 1)) bh = np.minimum.accumulate(bh[::-1])[::-1] bh = np.clip(bh, 0.0, 1.0) q_values = np.empty_like(bh) q_values[order] = bh res["q_value"] = q_values # Add SNR-based ranking within each output (1 = highest SNR/strongest signal) res = res.sort_values(['output_node', 'snr'], ascending=[True, False]).reset_index(drop=True) res = res.assign(snr_rank=lambda x: x.groupby('output_node').cumcount() + 1) # Add sparsity-based ranking within each output (1 = most sparse/simplest) res = res.sort_values(['output_node', 'sparsity'], ascending=[True, False]).reset_index(drop=True) res = res.assign(sparsity_rank=lambda x: x.groupby('output_node').cumcount() + 1) # Add within output rank based on q-value (1 = most significant) res = res.sort_values(['output_node', 'q_value'], ascending=[True, True]).reset_index(drop=True) res = res.assign(within_output_rank=lambda x: x.groupby('output_node').cumcount() + 1) # Sort by r2_gain for final output res = res.sort_values(by='r2_gain', ascending=False).reset_index(drop=True) return res
def _normalize(self, a, training=True): ''' Vectorized per-node, per-channel normalization with running stats. Args: a: Tensor (B, C, N) training: If True, update running stats using batch mean/var; else use running stats. Returns: Normalized tensor (B, C, N). ''' if not self.use_batchnorm: return a device = a.device dtype = a.dtype # Ensure buffers/params on device/dtype self.running_mean = self.running_mean.to(device=device, dtype=dtype) self.running_var = self.running_var.to(device=device, dtype=dtype) self.bn_gamma = self.bn_gamma.to(device=device, dtype=dtype) self.bn_beta = self.bn_beta.to(device=device, dtype=dtype) if training: batch_mean = a.mean(dim=0) # (C, N) batch_var = a.var(dim=0, unbiased=False) # (C, N) momentum = self.bn_momentum # Update running stats in-place self.running_mean.lerp_(batch_mean, momentum) self.running_var.lerp_(batch_var, momentum) if self.bn_num_batches_tracked is not None: self.bn_num_batches_tracked = self.bn_num_batches_tracked.to(device) self.bn_num_batches_tracked += 1 mean = batch_mean var = batch_var else: mean = self.running_mean var = self.running_var a = (a - mean) / torch.sqrt(var + self.bn_eps) a = a * self.bn_gamma + self.bn_beta return a