gsnn.optim.OutputEdgeInferer

Lightweight optimizer to infer output edges from intermediate GSNN node activations.

This module estimates a per-function-node linear mapping from channel activations to each output node using a simple batched regression. The learned weights can be interpreted as evidence for candidate edges from function nodes to output nodes.

Assumptions:

  • The GSNN model exposes get_node_activations(x, agg=…) returning a dict mapping function node names to tensors of shape (B, C), where B is the batch size and C is the channel dimension for that node.

  • The target y has shape (B, O), where O is the number of output nodes.
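forward (documented below) takes a single (B, C, N) activation tensor, while get_node_activations returns a dict of (B, C) tensors. A minimal NumPy sketch of bridging the two, assuming every function node has the same channel count C; the helper stack_node_activations and the node names are hypothetical:

```python
import numpy as np


def stack_node_activations(acts, node_order):
    """Stack a {node_name: (B, C)} dict into a (B, C, N) array.

    Hypothetical helper. Assumes every function node shares the same channel
    dimension C; `node_order` fixes which index along N each node receives.
    """
    return np.stack([acts[name] for name in node_order], axis=-1)


# Hypothetical activations: 2 function nodes, batch of 4, 3 channels each.
rng = np.random.default_rng(0)
acts = {"f0": rng.normal(size=(4, 3)), "f1": rng.normal(size=(4, 3))}
a = stack_node_activations(acts, ["f0", "f1"])
print(a.shape)  # (4, 3, 2)
```

Keeping node_order explicit matters because the position of each node along N is what ties W[i] back to a function node name.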

Notes:

  • Provides fit(dataloader, model, epochs=…) for training.

Classes

OutputEdgeInferer(*args, **kwargs)

Learns per-function-node linear mappings from channel activations to outputs.

class gsnn.optim.OutputEdgeInferer.OutputEdgeInferer(*args: Any, **kwargs: Any)

Bases: Module

Learns per-function-node linear mappings from channel activations to outputs.

Each function node i has weights W[i] with shape (C, O), producing per-node predictions that can be compared to the ground-truth outputs to score candidate edges.
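The weight-based columns that evaluate reports (l1_norm, l2_norm, sparsity, eff_rank) summarize these W[i] matrices. The NumPy sketch below computes them for one (C, O) matrix. The zero threshold tol and the entropy-based effective rank (exp of the entropy of the normalized singular values) are assumptions for illustration; the library may define them differently, and whether it summarizes the full W[i] or a single output column per edge is not stated here.

```python
import numpy as np


def weight_summary(W_i, tol=1e-3):
    """Summarize one function node's weight matrix W_i of shape (C, O).

    `tol` (the cutoff for "close to zero") and the entropy-based effective
    rank are illustrative assumptions, not the library's definitions.
    """
    l1 = float(np.abs(W_i).sum())
    l2 = float(np.sqrt((W_i ** 2).sum()))
    sparsity = float(np.mean(np.abs(W_i) < tol))
    # Effective rank: exp of the Shannon entropy of normalized singular values.
    s = np.linalg.svd(W_i, compute_uv=False)
    if s.sum() == 0:
        eff_rank = 0.0
    else:
        p = s / s.sum()
        p = p[p > 0]
        eff_rank = float(np.exp(-(p * np.log(p)).sum()))
    return {"l1_norm": l1, "l2_norm": l2, "sparsity": sparsity, "eff_rank": eff_rank}


W_i = np.array([[1.0, 0.0], [0.0, 0.0], [0.0, 2.0]])  # C=3 channels, O=2 outputs
print(weight_summary(W_i))  # singular values are 2 and 1, so eff_rank is about 1.89
```

An effective rank close to 1 would mean the node drives all outputs through essentially one direction in channel space.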

evaluate(dataloader, model, device='cpu', verbose=True)

Evaluate per-node predictive power across a full dataset using streaming statistics.

Parameters:
  • dataloader – Iterable yielding tuples (x, y) with shapes x=?, y=(B, O).

  • model – GSNN model exposing get_node_activations(x, agg=…).

Returns:

A pandas.DataFrame with one row per (function node, output node) pair and the following columns:

  • func_node, output_node, mse, r2, r, has_edge

  • model_mse, model_r2, model_r

  • r2_gain, r_gain, mse_gain

  • p_value: one-sided p-value for improvement (r2_gain > 0), from a paired test on per-sample squared errors using a normal approximation.

  • q_value: Benjamini-Hochberg FDR-adjusted p-value.

  • snr: Signal-to-noise ratio, Var(predictions) / MSE. Higher values indicate a stronger signal from the function node to the output.

  • l1_norm: L1 norm of the weights. Lower values indicate smaller and typically sparser weights.

  • l2_norm: L2 norm of the weights. Lower values indicate smaller weights.

  • sparsity: Fraction of weights close to zero. Higher values indicate a sparser mapping.

  • eff_rank: Effective-rank measure of the weights. Lower values indicate a simpler (lower-rank) mapping.

Return type:

pandas.DataFrame
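evaluate is described as using streaming statistics over the full dataset. One way to do that, sketched here in NumPy under the assumption that per-node predictions arrive as (N, B, O) batches, is to keep running sums per (node, output) pair and derive mse, r2, r and snr at the end. The class StreamingEdgeStats and its particular set of sums are hypothetical:

```python
import numpy as np


class StreamingEdgeStats:
    """Running per-(node, output) regression statistics, one batch at a time.

    Hypothetical helper: the name and the exact running sums are assumptions.
    MSE, R^2, Pearson r and SNR = Var(pred) / MSE are all recovered from the
    sums, so the full dataset never has to be held in memory.
    """

    def __init__(self, n_nodes, n_outputs):
        self.n = 0
        self.sum_y = np.zeros(n_outputs)
        self.sum_y2 = np.zeros(n_outputs)
        self.sum_p = np.zeros((n_nodes, n_outputs))
        self.sum_p2 = np.zeros((n_nodes, n_outputs))
        self.sum_py = np.zeros((n_nodes, n_outputs))
        self.sse = np.zeros((n_nodes, n_outputs))

    def update(self, preds, y):
        """preds: (N, B, O) per-node predictions; y: (B, O) targets."""
        self.n += y.shape[0]
        self.sum_y += y.sum(axis=0)
        self.sum_y2 += (y ** 2).sum(axis=0)
        self.sum_p += preds.sum(axis=1)
        self.sum_p2 += (preds ** 2).sum(axis=1)
        self.sum_py += (preds * y[None]).sum(axis=1)
        self.sse += ((preds - y[None]) ** 2).sum(axis=1)

    def finalize(self):
        n = self.n
        mean_y, mean_p = self.sum_y / n, self.sum_p / n
        var_y = self.sum_y2 / n - mean_y ** 2       # (O,)
        var_p = self.sum_p2 / n - mean_p ** 2       # (N, O)
        cov = self.sum_py / n - mean_p * mean_y     # (N, O)
        mse = self.sse / n
        return {
            "mse": mse,
            "r2": 1.0 - mse / var_y,
            "r": cov / np.sqrt(var_p * var_y),
            "snr": var_p / mse,
        }


# Toy data: 2 function nodes, 3 outputs, 10 samples streamed in two batches.
rng = np.random.default_rng(0)
y = rng.normal(size=(10, 3))
preds = y[None] + 0.5 * rng.normal(size=(2, 10, 3))
stats = StreamingEdgeStats(n_nodes=2, n_outputs=3)
for batch in (slice(0, 6), slice(6, 10)):
    stats.update(preds[:, batch], y[batch])
out = stats.finalize()
```

Feeding the GSNN model's own predictions through the same accumulator as a single pseudo-node would give baseline values such as model_mse and model_r2; r2_gain is then r2 minus model_r2.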

p-value meaning:

  • Null hypothesis: the edge-specific predictor does not reduce expected MSE relative to the baseline model for this output (i.e., r2_gain <= 0).

  • Alternative: the edge-specific predictor reduces expected MSE (r2_gain > 0).

  • We compute per-sample squared-error differences and apply a one-sided normal approximation to their mean. This is tractable and aligns with r2_gain, since r2_gain = (mse_baseline - mse_node) / Var(y).

FDR: We report q-values (BH-adjusted p-values) over all (func, output) pairs.
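The test and the FDR adjustment described above can be sketched in NumPy as follows. This works on in-memory arrays for clarity (a streaming version would accumulate the sum of d and of d squared instead), and the function names are hypothetical rather than part of gsnn:

```python
import math

import numpy as np


def one_sided_paired_pvalue(sq_err_baseline, sq_err_node):
    """P-value for H1: the node predictor has lower expected squared error.

    Inputs are per-sample squared errors on the same samples. The statistic
    is z = mean(d) / (std(d) / sqrt(n)) with d = baseline - node, and the
    p-value is the upper normal tail P(Z >= z).
    """
    d = np.asarray(sq_err_baseline, dtype=float) - np.asarray(sq_err_node, dtype=float)
    se = d.std(ddof=1) / math.sqrt(d.size)
    if se == 0.0:
        # Degenerate case: every per-sample difference is identical.
        return 0.0 if d.mean() > 0 else (1.0 if d.mean() < 0 else 0.5)
    z = d.mean() / se
    return 0.5 * math.erfc(z / math.sqrt(2.0))


def benjamini_hochberg(pvals):
    """Benjamini-Hochberg q-values (FDR-adjusted p-values)."""
    p = np.asarray(pvals, dtype=float)
    m = p.size
    order = np.argsort(p)
    ranked = p[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity: running minimum from the largest p-value downwards.
    q_sorted = np.minimum.accumulate(ranked[::-1])[::-1]
    q = np.empty(m)
    q[order] = np.minimum(q_sorted, 1.0)
    return q


p = one_sided_paired_pvalue([1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.5])
q = benjamini_hochberg([0.01, 0.04, 0.03, 0.2])
```

Because mean(d) / Var(y) equals r2_gain, testing mean(d) > 0 is the same as testing r2_gain > 0, which is why the squared-error test lines up with the r2_gain column.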

fit(dataloader, model, epochs=None, device='cpu', verbose=True)

Fit the per-node linear mappings using batches from a dataloader.

Parameters:
  • dataloader – Iterable yielding tuples (x, y) with shapes x=?, y=(B, O).

  • model – GSNN model exposing get_node_activations(x, agg=…).

  • epochs – Optional override for the number of epochs; defaults to self.epochs.

Returns:

List of average epoch losses.
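A minimal sketch of what fitting the per-node maps involves, written in NumPy with plain minibatch gradient descent on the mean squared error. The optimizer, learning rate, initialization, absence of a bias term and the helper name fit_per_node_linear are all assumptions; the actual fit works on torch tensors taken from the model's activations. Because each W[n] only affects node n's predictions, the nodes are effectively fitted independently:

```python
import numpy as np


def fit_per_node_linear(batches, n_nodes, n_channels, n_outputs,
                        epochs=200, lr=0.1, seed=0):
    """Fit W (N, C, O) by minibatch gradient descent on per-node MSE.

    Hypothetical helper. `batches` is a list of (a, y) pairs with
    a: (B, C, N) and y: (B, O). Plain gradient descent, no bias term and
    the default hyperparameters are illustrative assumptions.
    """
    rng = np.random.default_rng(seed)
    W = 0.01 * rng.normal(size=(n_nodes, n_channels, n_outputs))
    history = []
    for _ in range(epochs):
        losses = []
        for a, y in batches:
            pred = np.einsum("bcn,nco->nbo", a, W)   # (N, B, O)
            resid = pred - y[None]                    # same target for every node
            losses.append(float(np.mean(resid ** 2)))
            # Gradient of the mean squared residual w.r.t. W; node n only sees its own residuals.
            grad = 2.0 * np.einsum("bcn,nbo->nco", a, resid) / resid.size
            W -= lr * grad
        history.append(sum(losses) / len(losses))   # average loss for this epoch
    return W, history


# Toy check: the outputs depend only on node 0's activations.
rng = np.random.default_rng(1)
a_all = rng.normal(size=(64, 3, 2))
W_true = np.array([[1.0, -2.0], [0.5, 0.0], [0.0, 3.0]])
y_all = a_all[:, :, 0] @ W_true
batches = [(a_all[i:i + 16], y_all[i:i + 16]) for i in range(0, 64, 16)]
W, history = fit_per_node_linear(batches, n_nodes=2, n_channels=3, n_outputs=2)
```

In this toy setup node 0 should recover W_true almost exactly, while node 1, whose activations are unrelated to y, keeps a high residual error; that contrast is what the per-edge scores in evaluate are meant to pick up.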

forward(a)

Compute per-function-node linear maps to outputs.

Parameters:

a – Activation tensor of shape (B, C, N), where:

  • B: batch size

  • C: number of channels

  • N: number of function nodes

Returns:

Per-node predictions for each output.

Return type:

Tensor of shape (N, B, O)
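Assuming the weights are held as one array W of shape (N, C, O), so that W[i] is the (C, O) matrix described for the class, forward reduces to a single batched contraction. A minimal NumPy sketch (torch.einsum accepts the same subscript string; any bias term the module may have is omitted here):

```python
import numpy as np


def forward(a, W):
    """Per-node linear maps: a (B, C, N) and W (N, C, O) -> predictions (N, B, O)."""
    # For each node n: a[:, :, n] (B, C) @ W[n] (C, O) -> (B, O), stacked over n.
    return np.einsum("bcn,nco->nbo", a, W)


rng = np.random.default_rng(0)
a = rng.normal(size=(5, 4, 3))   # B=5, C=4, N=3
W = rng.normal(size=(3, 4, 2))   # N=3, C=4, O=2
preds = forward(a, W)
print(preds.shape)  # (3, 5, 2)
```

Putting the node axis first in the output makes preds[i] directly comparable with the (B, O) target y for function node i.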