gsnn.optim.OutputEdgeInferer
Lightweight optimizer to infer output edges from intermediate GSNN node activations.
This module estimates a per-function-node linear mapping from channel activations to each output node using a simple batched regression. The learned weights can be interpreted as evidence for candidate edges from function nodes to output nodes.
Assumptions:
- The GSNN model exposes get_node_activations(x, agg=…), returning a dict mapping function node names to tensors of shape (B, C), where B is the batch size and C is the channel dimension for that node.
- The target y has shape (B, O), where O is the number of output nodes.
Notes:
- Provides fit(dataloader, model, epochs=…) for training.
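The interface assumed above can be illustrated with a small stand-in model. This is only a sketch: StubModel, its constructor arguments, the (B, in_dim) input shape, and the use of NumPy arrays in place of tensors are hypothetical choices for illustration, and the agg argument is accepted but ignored because its allowed values are not specified here.

```python
import numpy as np


class StubModel:
    """Hypothetical stand-in that mimics only the assumed GSNN interface."""

    def __init__(self, node_names, in_dim, channels, seed=0):
        rng = np.random.default_rng(seed)
        # One fixed random projection per function node, shape (in_dim, channels).
        self.proj = {name: rng.normal(size=(in_dim, channels)) for name in node_names}

    def get_node_activations(self, x, agg=None):
        # Returns a dict: function node name -> activations of shape (B, C).
        # `agg` is ignored in this sketch.
        return {name: np.tanh(x @ P) for name, P in self.proj.items()}


model = StubModel(["f0", "f1", "f2"], in_dim=5, channels=4)
x = np.random.default_rng(1).normal(size=(8, 5))  # assumed input shape (B, in_dim)
y = np.zeros((8, 2))                              # targets, shape (B, O) with O = 2

acts = model.get_node_activations(x)
for name, a in acts.items():
    print(name, a.shape)  # each entry has shape (B, C)
```

Any object with this method and these shapes should satisfy the stated assumptions, regardless of how the activations are produced internally.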
Classes
OutputEdgeInferer: Learns per-function-node linear mappings from channel activations to outputs.
- class gsnn.optim.OutputEdgeInferer.OutputEdgeInferer(*args: Any, **kwargs: Any)[source]
Bases: Module
Learns per-function-node linear mappings from channel activations to outputs.
Each function node i has weights W[i] with shape (C, O), producing per-node predictions that can be compared to ground truth outputs to score candidate edges.
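The per-node mapping can be sketched in NumPy as follows. This is a minimal illustration under assumptions, not the class's actual training loop: the synthetic data, the plain gradient-descent update, the learning rate, and stacking activations into a single (B, N, C) array are all choices made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, C, O = 256, 3, 4, 2            # batch, function nodes, channels, outputs

acts = rng.normal(size=(B, N, C))    # stacked activations, one (B, C) slice per node
true_W = rng.normal(size=(C, O))
y = acts[:, 1, :] @ true_W           # synthetic targets: only node 1 carries signal

W = np.zeros((N, C, O))              # W[i] has shape (C, O)
lr = 0.05                            # assumed step size for this sketch
for _ in range(500):
    pred = np.einsum("bnc,nco->bno", acts, W)              # per-node predictions (B, N, O)
    err = pred - y[:, None, :]                             # compare every node to the same y
    grad = 2.0 * np.einsum("bnc,bno->nco", acts, err) / B  # gradient of the squared error
    W -= lr * grad

pred = np.einsum("bnc,nco->bno", acts, W)
node_mse = ((pred - y[:, None, :]) ** 2).mean(axis=(0, 2))  # one MSE per function node
best_node = int(node_mse.argmin())
print(node_mse, best_node)
```

Because every node is regressed against the same targets independently, a node whose activations actually determine the output ends up with a much lower error than the others, which is what makes the per-node fit usable as evidence for a candidate edge.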
- evaluate(dataloader, model, device='cpu', verbose=True)[source]
Evaluate per-node predictive power across a full dataset using streaming statistics.
- Parameters:
dataloader – Iterable yielding tuples (x, y) with shapes x=?, y=(B, O).
model – GSNN model exposing get_node_activations(x, agg=…).
- Returns:
pandas.DataFrame with columns:
func_node, output_node, mse, r2, r, has_edge
model_mse, model_r2, model_r
r2_gain, r_gain, mse_gain
p_value: one-sided p-value testing improvement (r2_gain > 0), via paired mean-squared-error test with normal approximation over samples.
q_value: Benjamini-Hochberg FDR-adjusted p-value.
snr: Signal-to-Noise Ratio (Var(predictions) / MSE). Higher values indicate stronger signal from function node to output.
l1_norm: L1 norm of weights (sparsity-promoting). Lower values = sparser model.
l2_norm: L2 norm of weights (regularization). Lower values = smaller weights.
sparsity: Fraction of weights close to zero. Higher values = sparser model.
eff_rank: Effective rank measure. Lower values = simpler model.
- Return type:
pandas.DataFrame
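Several of these columns follow directly from their one-line definitions. The sketch below computes them for a single (func_node, output_node) pair. The helper name edge_metrics, the zero_tol threshold for "close to zero", and the choice to compute norms over the (C,) weight column for one output are assumptions; eff_rank is omitted because its exact definition is not given here.

```python
import numpy as np


def edge_metrics(pred, y, w, zero_tol=1e-3):
    """Hypothetical helper. pred, y: shape (B,); w: shape (C,) weights for one pair."""
    mse = float(np.mean((pred - y) ** 2))      # assumes mse > 0 for the snr ratio below
    return {
        "mse": mse,
        "r2": 1.0 - mse / float(np.var(y)),
        "snr": float(np.var(pred)) / mse,      # Var(predictions) / MSE
        "l1_norm": float(np.abs(w).sum()),
        "l2_norm": float(np.sqrt((w ** 2).sum())),
        "sparsity": float(np.mean(np.abs(w) < zero_tol)),  # fraction of near-zero weights
    }


y = np.array([1.0, 2.0, 3.0, 4.0])
pred = np.array([1.5, 2.5, 2.5, 3.5])
w = np.array([0.0, 2.0, -1.0, 0.0005])
metrics = edge_metrics(pred, y, w)
print(metrics)
```

For this toy input the squared errors are all 0.25, so mse is 0.25, r2 is 0.8, and snr is 2.0.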
p-value meaning:
- Null hypothesis: the edge-specific predictor does not reduce expected MSE vs. the baseline model for this output (i.e., r2_gain <= 0).
- Alternative: the edge-specific predictor reduces expected MSE (r2_gain > 0).
- We compute per-sample squared-error differences and apply a one-sided normal approximation to the mean difference. This is tractable and aligns with r2_gain since r2_gain = (mse_baseline - mse_node) / Var(y).
- FDR: We report q-values (BH-adjusted p-values) over all (func, output) pairs.
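The test and the FDR adjustment described above can be sketched with the standard library alone. This is an illustration of the general technique rather than the method's actual code: the function names are hypothetical, and the use of the unbiased sample variance (n - 1 denominator) is an assumption.

```python
import math


def one_sided_p(se_baseline, se_node):
    """P-value for H1: mean(se_baseline - se_node) > 0, using a normal approximation."""
    d = [b - n for b, n in zip(se_baseline, se_node)]  # per-sample squared-error differences
    n = len(d)
    mean = sum(d) / n
    var = sum((v - mean) ** 2 for v in d) / (n - 1)    # assumed: unbiased sample variance
    if var == 0.0:
        return 0.0 if mean > 0 else 1.0                # degenerate case: no spread
    z = mean / math.sqrt(var / n)
    return 0.5 * math.erfc(z / math.sqrt(2.0))         # upper tail, 1 - Phi(z)


def bh_qvalues(pvals):
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    m = len(pvals)
    order = sorted(range(m), key=lambda i: pvals[i])
    q = [0.0] * m
    running = 1.0
    for rank in range(m, 0, -1):                       # walk from the largest p-value down
        i = order[rank - 1]
        running = min(running, pvals[i] * m / rank)    # enforce monotone q-values
        q[i] = running
    return q


p_improved = one_sided_p([1.0, 1.2, 0.9, 1.1], [0.5, 0.6, 0.4, 0.7])
p_no_gain = one_sided_p([1.0, 2.0, 1.0, 2.0], [2.0, 1.0, 2.0, 1.0])
qs = bh_qvalues([0.01, 0.04, 0.03, 0.5])
print(p_improved, p_no_gain, qs)
```

In the second call the differences average to zero, so the statistic is 0 and the p-value is 0.5. A consistently lower node error, as in the first call, pushes the p-value toward 0.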
- fit(dataloader, model, epochs=None, device='cpu', verbose=True)[source]
Fit the per-node linear mappings using batches from a dataloader.
- Parameters:
dataloader – Iterable yielding tuples (x, y) with shapes x=?, y=(B, O).
model – GSNN model exposing get_node_activations(x, agg=…).
epochs – Optional override for number of epochs. Defaults to self.epochs.
- Returns:
List of average epoch losses.