gsnn.gsnn.models
Neural network model components for GSNN.
Classes
ChannelEMANorm | Applies normalization per individual channel using exponential moving averages. |
GSNN | |
GroupBatchNorm | A batch-norm style module that normalizes within channel groups. |
GroupEMANorm | Applies normalization within each channel group using exponential moving averages. |
GroupLayerNorm | Layer normalization computed separately within each channel group. |
GroupRMSNorm | Applies Root Mean Square normalization within each channel group. |
NN | Fully-connected baseline: Linear blocks with optional norm, activation, dropout. |
NodeAttention | Node-wise channel attention. |
NodeMLP | Small MLP applied independently to each node's channel vector inside a ResBlock. |
PathwayLatentRegularizer | Per-pathway latent-factor auxiliary loss for GSNN. |
ResBlock | |
SignedMessagePassing | Aggregate scalar signals over function-function edges using stored signs (edge weights). |
SoftmaxGroupNorm | Channel-wise softmax normalized within each channel group (stable softmax via per-group max shift). |
SparseLinear | Fixed sparsity pattern linear layer; forward is batched message passing on the COO indices. |
- class gsnn.gsnn.models.ChannelEMANorm(*args: Any, **kwargs: Any)
Bases: Module
Applies normalization per individual channel using exponential moving averages.
This normalization maintains running statistics for each channel independently but doesn’t use current batch statistics for normalization, making it very stable for small batch sizes.
- Parameters:
num_channels (int) – Number of channels to normalize.
eps (float) – Small value to avoid division by zero. Default: 1e-5
momentum (float) – Momentum for updating running statistics. Default: 0.1
affine (bool) – If True, applies learnable scale and bias per channel. Default: True
track_running_stats (bool) – If True, maintains running statistics. Default: True
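The idea of normalizing with running statistics only can be sketched in plain Python. This is an illustration of the arithmetic, not the library's implementation: the class name `ChannelEMANormSketch`, the split into `update` and `normalize`, the biased variance, and the initial statistics (mean 0, variance 1) are all assumptions made for the example.

```python
import math

class ChannelEMANormSketch:
    """Plain-Python illustration of per-channel EMA normalization (hypothetical)."""

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.running_mean = [0.0] * num_channels  # assumed initial value
        self.running_var = [1.0] * num_channels   # assumed initial value

    def update(self, batch):
        # batch: list of rows, each row a list of per-channel values.
        n = len(batch)
        m = self.momentum
        for c in range(len(self.running_mean)):
            col = [row[c] for row in batch]
            mean = sum(col) / n
            var = sum((v - mean) ** 2 for v in col) / n  # biased variance
            self.running_mean[c] = (1 - m) * self.running_mean[c] + m * mean
            self.running_var[c] = (1 - m) * self.running_var[c] + m * var

    def normalize(self, batch):
        # Only the running statistics are used, never the current batch statistics.
        return [
            [(v - self.running_mean[c]) / math.sqrt(self.running_var[c] + self.eps)
             for c, v in enumerate(row)]
            for row in batch
        ]

norm = ChannelEMANormSketch(num_channels=2)
batch = [[1.0, 10.0], [3.0, 30.0]]
norm.update(batch)
out = norm.normalize(batch)
```

Because a single batch moves the running statistics by only a `momentum` fraction, one unusual small batch has a limited effect on the normalization.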
- class gsnn.gsnn.models.GSNN(*args: Any, **kwargs: Any)
Bases: Module
- forward(x, node_mask=None, edge_mask=None, ret_edge_out=False, e0=None, node_errs=None, x_fn=None)
Implements the forward pass of the GSNN model.
The model first converts node features to edge features, then applies a sequence of sparse linear transformations constrained by the graph structure. Each layer consists of:
Input transformation (W_in)
Normalization (optional)
Nonlinearity
Output transformation (W_out)
Residual connection (optional)
- Parameters:
x (Tensor) – Input node features of shape [batch_size, num_input_nodes].
node_mask (Tensor, optional) – Boolean mask for function nodes of shape [batch_size, num_nodes]. If provided, masks out specific function nodes during computation. (default: None)
edge_mask (Tensor, optional) – Boolean mask for edges of shape [batch_size, num_edges]. If provided, masks out specific edges during computation. (default: None)
ret_edge_out (bool, optional) – If set to True, return edge-level features instead of node-level features. (default: False)
e0 (Tensor, optional) – Initial edge features of shape [batch_size, num_edges]. Used for inferring input errors. (default: None)
node_errs (List[Tensor], optional) – List of node errors per layer, each of shape [batch_size, num_nodes]. Length must match number of layers. (default: None)
x_fn (Tensor, optional) – Function node features of shape [batch_size, num_function_nodes]. Used for computing node activity. (default: None)
- Returns:
If ret_edge_out=False, returns node-level output features of shape [batch_size, num_output_nodes]. Otherwise, returns edge-level features of shape [batch_size, num_edges].
- Return type:
Tensor
Example
>>> # Using the model from the class example
>>> x = torch.randn(32, 2)  # batch_size=32, num_input_nodes=2
>>> # Basic forward pass
>>> out = model(x)
>>> print(out.shape)  # [32, 1]
>>> # Get edge-level features
>>> edge_out = model(x, ret_edge_out=True)
>>> print(edge_out.shape)  # [32, 4] (batch_size, num_edges)
>>> # Using masks
>>> node_mask = torch.ones(32, 4)  # [batch_size, num_nodes]
>>> edge_mask = torch.ones(32, 4)  # [batch_size, num_edges]
>>> out = model(x, node_mask=node_mask, edge_mask=edge_mask)
>>> print(out.shape)  # [32, 1]
- get_batch_params(B, device)
Retrieves or computes batch-specific indexing parameters for sparse linear layers.
This method caches the batch parameters to avoid recomputing them for the same batch size. The parameters are used to efficiently perform batched sparse matrix operations.
- Parameters:
B (int) – Batch size.
device (torch.device) – Device on which to place the computed parameters.
- Returns:
- A tuple containing:
batched_indices_in (Tensor): Batched indices for input sparse linear layer
batched_indices_out (Tensor): Batched indices for output sparse linear layer
- Return type:
Example
>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> # Get batch parameters for batch size 32
>>> batch_params = model.get_batch_params(32, torch.device('cuda'))
>>> # Parameters are cached for subsequent calls
>>> same_params = model.get_batch_params(32, torch.device('cuda'))
>>> # Different batch size triggers recomputation
>>> new_params = model.get_batch_params(64, torch.device('cuda'))
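The caching behaviour described for get_batch_params can be sketched as below. The helper `BatchIndexCache` is hypothetical, and the block-diagonal offsetting of the COO indices is an assumption about how batched indices might be built; the sketch only shows how results can be reused per batch size and device.

```python
class BatchIndexCache:
    """Hypothetical helper: caches batched COO indices keyed by (batch size, device)."""

    def __init__(self, indices, num_rows, num_cols):
        self.indices = indices      # list of (row, col) pairs for one sample
        self.num_rows = num_rows
        self.num_cols = num_cols
        self._cache = {}
        self.builds = 0             # counts how often indices were recomputed

    def get(self, B, device="cpu"):
        key = (B, str(device))
        if key not in self._cache:
            self.builds += 1
            # Assumption: shift each sample's indices into its own diagonal block.
            self._cache[key] = [
                (r + b * self.num_rows, c + b * self.num_cols)
                for b in range(B)
                for (r, c) in self.indices
            ]
        return self._cache[key]

cache = BatchIndexCache([(0, 0), (1, 0), (1, 1)], num_rows=2, num_cols=2)
first = cache.get(2)
again = cache.get(2)   # served from the cache, no rebuild
other = cache.get(3)   # a new batch size triggers a rebuild
```

Keying the cache on the batch size means a final, smaller batch in an epoch costs one extra build rather than a rebuild on every step.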
- get_node_attention(x)
Return per-layer node-level attention weights.
- Parameters:
x (Tensor (B, num_input_nodes)) – Input features; typically supply a single sample (B=1).
- Returns:
Mapping from node name to a tensor of shape (L, B) with attention weights per layer (L) and batch element (B).
- Return type:
Dict[str, Tensor]
- prune(threshold=0.01, verbose=False)
Prunes the model by removing channels with small weights.
This method removes channels whose maximum absolute weight value across all layers is below the specified threshold. This can significantly reduce model size while maintaining performance. If you prune during training, reinitialize the optimizer afterwards.
- Parameters:
- Returns:
Number of parameters removed by pruning.
- Return type:
Example
>>> # Create a model with 16 channels per function node
>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> # Train the model...
>>> # Prune channels with small weights
>>> removed_params = model.prune(threshold=1e-2, verbose=True)
>>> print(f'Removed {removed_params} parameters')
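The pruning criterion (maximum absolute weight across all layers compared with a threshold) can be illustrated with a small stand-alone function. The function name `channels_to_keep` and the nested-list weight layout are hypothetical and chosen only for this sketch; they do not reflect how GSNN stores its weights.

```python
def channels_to_keep(layer_weights, threshold=0.01):
    """Return indices of channels whose largest |weight| over all layers is >= threshold.

    layer_weights[l][c] is a list of the weights attached to channel c in layer l
    (a hypothetical layout used only for this illustration).
    """
    num_channels = len(layer_weights[0])
    keep = []
    for c in range(num_channels):
        max_abs = max(abs(w) for layer in layer_weights for w in layer[c])
        if max_abs >= threshold:
            keep.append(c)
    return keep

layers = [
    [[0.5, -0.2], [0.001, 0.002], [0.0, 0.03]],   # layer 0, three channels
    [[0.1, 0.0], [-0.004, 0.0], [0.005, 0.0]],    # layer 1, same three channels
]
kept = channels_to_keep(layers, threshold=0.01)
```

Here channel 1 is dropped because none of its weights reaches the threshold in any layer, while channel 2 survives on the strength of a single weight in layer 0.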
- class gsnn.gsnn.models.GroupBatchNorm(*args: Any, **kwargs: Any)
Bases: Module
A batch-norm style module that:
Partitions the C channels into groups via ‘channel_groups’.
Computes mean/var for each group across the entire batch dimension.
Maintains running stats for inference (if track_running_stats=True).
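The first two steps above can be sketched in plain Python. The function `group_batch_norm` is a hypothetical illustration: it pools every value of a group over both the batch and the group's channels and uses the biased variance, which are assumptions, and it omits running statistics and affine parameters.

```python
import math

def group_batch_norm(batch, channel_groups, eps=1e-5):
    """Normalize each channel with the mean/variance of its group, pooled over the batch."""
    stats = {}
    for g in set(channel_groups):
        vals = [row[c] for row in batch
                for c, cg in enumerate(channel_groups) if cg == g]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)  # biased variance
        stats[g] = (mean, var)
    return [
        [(v - stats[channel_groups[c]][0]) / math.sqrt(stats[channel_groups[c]][1] + eps)
         for c, v in enumerate(row)]
        for row in batch
    ]

batch = [[1.0, 3.0, 10.0], [5.0, 7.0, 30.0]]
out = group_batch_norm(batch, channel_groups=[0, 0, 1])
```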
- class gsnn.gsnn.models.GroupEMANorm(*args: Any, **kwargs: Any)
Bases: Module
Applies normalization within each channel group using exponential moving averages.
This normalization maintains running statistics but doesn’t use current batch statistics for normalization, making it very stable for small batch sizes.
- Parameters:
channel_groups (list or tensor) – Specifies which group each channel belongs to.
eps (float) – Small value to avoid division by zero. Default: 1e-5
momentum (float) – Momentum for updating running statistics. Default: 0.1
affine (bool) – If True, applies learnable scale and bias. Default: True
track_running_stats (bool) – If True, maintains running statistics. Default: True
- class gsnn.gsnn.models.GroupLayerNorm(*args: Any, **kwargs: Any)
Bases: Module
Layer normalization computed separately within each channel group.
- class gsnn.gsnn.models.GroupRMSNorm(*args: Any, **kwargs: Any)
Bases: Module
Applies Root Mean Square normalization within each channel group.
RMSNorm normalizes using only the RMS (root mean square) without mean centering, making it simpler and more stable than layer normalization, especially for small batch sizes.
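A minimal plain-Python sketch of RMS normalization per channel group follows. The function `group_rms_norm` is hypothetical; placing `eps` inside the square root and omitting any learnable scale are assumptions made for the illustration.

```python
import math

def group_rms_norm(x, channel_groups, eps=1e-5):
    """Divide each channel by the RMS of its group, per sample, without mean centering."""
    out = []
    for row in x:
        sq_sum, count = {}, {}
        for v, g in zip(row, channel_groups):
            sq_sum[g] = sq_sum.get(g, 0.0) + v * v
            count[g] = count.get(g, 0) + 1
        # Assumption: eps is added to the mean square before taking the root.
        rms = {g: math.sqrt(sq_sum[g] / count[g] + eps) for g in sq_sum}
        out.append([v / rms[g] for v, g in zip(row, channel_groups)])
    return out

y = group_rms_norm([[3.0, 4.0, 2.0]], channel_groups=[0, 0, 1])
```

Each sample is normalized on its own, so the result does not depend on which other samples share the batch.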
- Parameters:
- class gsnn.gsnn.models.NN(*args: Any, **kwargs: Any)
Bases: Module
Fully-connected baseline: Linear blocks with optional norm, activation, dropout.
- class gsnn.gsnn.models.NodeAttention(*args: Any, **kwargs: Any)
Bases: Module
Node-wise channel attention.
The layer learns a single scalar attention coefficient \(\alpha_{b,n}\) per node n for every sample b in the batch. The coefficient is obtained by first aggregating the (optionally weighted) hidden channels that belong to the node and then passing the aggregated score through a per-node sigmoid gate (no cross-node normalization). The resulting attention weights can be:
Interpreted - \(\alpha_{b,n}\) tells how important node n was for the current forward pass.
Applied - the coefficients are broadcast back to the individual channels that originated from the node and multiplied with the original activations, producing an attention-modulated output.
- Parameters:
channel_groups (Sequence[int] or Tensor) – A 1-D list/array mapping global channel index → node index. Length equals the total number of hidden channels across all nodes.
dropout (float, optional (default=0.0)) – Dropout probability applied to the node-level attention weights.
temperature (float, optional (default=1.0)) – Temperature that scales the attention scores. Lower values produce sharper attention weights.
Examples
>>> # Suppose we have 2 nodes with 3 channels each (total 6 channels)
>>> ch_groups = [0, 0, 0, 1, 1, 1]
>>> attn = NodeAttention(ch_groups, dropout=0.1)
>>> x = torch.randn(8, 6)  # (batch=8, channels=6)
>>> out, alpha = attn(x, return_alpha=True)
>>> out.shape  # same shape as input
torch.Size([8, 6])
>>> alpha.shape  # one scalar per node
torch.Size([8, 2])
- forward(x: torch.Tensor, *, return_alpha: bool = False)
Apply node attention.
- Parameters:
x (Tensor of shape (B, C)) – Input activations ordered so that channels belonging to the same node are indexed according to channel_groups.
return_alpha (bool, optional (default=False)) – If True, the method returns a tuple (out, alpha) where alpha is the attention matrix of shape (B, n_nodes).
- Returns:
The attention-modulated activations (and, optionally, the node coefficients).
- Return type:
Tensor or Tuple[Tensor, Tensor]
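The gating described for NodeAttention (aggregate per node, apply an independent sigmoid, broadcast back to the channels) can be sketched in plain Python. The function `node_attention` is hypothetical: the unweighted mean aggregation and dividing the score by `temperature` before the sigmoid are assumptions, and dropout is left out.

```python
import math

def node_attention(x, channel_groups, temperature=1.0):
    """Gate each node's channels by a sigmoid of the node's mean activation.

    Mean aggregation and the placement of the temperature are assumptions
    made for this sketch; the layer's own aggregation may be weighted.
    """
    n_nodes = max(channel_groups) + 1
    out, alpha = [], []
    for row in x:
        sums = [0.0] * n_nodes
        counts = [0] * n_nodes
        for v, g in zip(row, channel_groups):
            sums[g] += v
            counts[g] += 1
        # One independent sigmoid gate per node: no normalization across nodes.
        a = [1.0 / (1.0 + math.exp(-(s / c) / temperature))
             for s, c in zip(sums, counts)]
        alpha.append(a)
        # Broadcast each node's coefficient back to its channels.
        out.append([v * a[g] for v, g in zip(row, channel_groups)])
    return out, alpha

x = [[1.0, -1.0, 0.0, 2.0, 2.0, 2.0]]
out, alpha = node_attention(x, channel_groups=[0, 0, 0, 1, 1, 1])
```

Since each gate is independent, the coefficients of one sample need not sum to one, unlike a softmax over nodes.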
- class gsnn.gsnn.models.NodeMLP(*args: Any, **kwargs: Any)
Bases: Module
Small MLP applied independently to each node’s channel vector inside a ResBlock.
- forward(x: torch.Tensor) → torch.Tensor
x has shape (batch, num_nodes, channels_per_node); the output has the same shape.
- class gsnn.gsnn.models.PathwayLatentRegularizer(*args: Any, **kwargs: Any)
Bases: Module
Per-pathway latent-factor auxiliary loss for GSNN.
For each ResBlock \(\ell\) and minibatch of size \(B\):
Reduce the cached activation \(A_\ell \in \mathbb{R}^{B \times N_{\text{func}} \times C_{pn}}\) to per-node scalars \(s_\ell \in \mathbb{R}^{B \times N_{\text{func}}}\) via \(\phi\).
Compute per-pathway scores \(S_\ell = \mathrm{normalize}(M\, s_\ell^\top)^\top \in \mathbb{R}^{B \times P}\).
Standardize across the batch dimension and compute the member-by-pathway correlation matrix \(C \in \mathbb{R}^{N_{\text{func}} \times P}\).
Add the negative member-side correlation to \(L_{\text{sim}}\), and (if dissim_pairs is provided) the squared score-correlation of dissimilar pairs to \(L_{\text{dis}}\).
Cost is \(O(L \cdot B \cdot (N_{\text{func}} \cdot C_{pn} + P \cdot N_{\text{func}}))\) per minibatch and parameter count is either zero (phi='mean') or \(C_{pn}\) (phi='learned').
- Parameters:
model (GSNN) – Reference model. Used only for shape introspection (ResBlocks[0].channel_groups) and to toggle _store_activations.
pathway_membership (Tensor) – Float tensor of shape (P, N_func). Real-valued entries are allowed and treated as soft membership weights; binary 0/1 entries give classical hard membership. Rows correspond to pathways, columns to function-node indices in the same order as the model’s function node names.
dissim_pairs (Tensor, optional) – (M, 2) LongTensor of pathway index pairs whose scores should be encouraged to be uncorrelated. (default: None)
lambda_sim (float, optional) – Scaling for the similarity term. (default: 0.1)
lambda_dis (float, optional) – Scaling for the dissimilarity term. (default: 0.0)
phi (str or nn.Module, optional) – Per-node scalar reduction. 'mean' averages across channels (no parameters). 'learned' uses a single nn.Linear projection Linear(C_pn, 1) shared across layers. A custom nn.Module may also be passed; it must accept a tensor of shape (B, N_func, C_pn) and return (B, N_func, 1) or (B, N_func). (default: 'mean')
eps (float, optional) – Small constant for numerical stability in the standardization step. (default: 1e-6)
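The similarity term for a single ResBlock can be sketched in plain Python under several assumptions: phi='mean', the normalize step is omitted, the correlation matrix is averaged over the batch, and the "member-side" reduction is taken to be a membership-weighted mean of the correlations. The function names `standardize` and `similarity_loss` are hypothetical, and the lambda_sim scaling is left out.

```python
import math

def standardize(rows, eps=1e-6):
    """Scale each column to zero mean and unit variance across the batch dimension."""
    B, n = len(rows), len(rows[0])
    out = [[0.0] * n for _ in range(B)]
    for j in range(n):
        mean = sum(row[j] for row in rows) / B
        std = math.sqrt(sum((row[j] - mean) ** 2 for row in rows) / B + eps)
        for b in range(B):
            out[b][j] = (rows[b][j] - mean) / std
    return out

def similarity_loss(A, M, eps=1e-6):
    """A: activations [B][N_func][C_pn]; M: pathway membership [P][N_func]."""
    B, N, P = len(A), len(A[0]), len(M)
    # Step 1: per-node scalars via phi='mean'.
    s = [[sum(node) / len(node) for node in sample] for sample in A]
    # Step 2: per-pathway scores (normalize step omitted in this sketch).
    S = [[sum(M[p][n] * s[b][n] for n in range(N)) for p in range(P)]
         for b in range(B)]
    # Step 3: standardize over the batch, then member-by-pathway correlation.
    zs, zS = standardize(s, eps), standardize(S, eps)
    corr = [[sum(zs[b][n] * zS[b][p] for b in range(B)) / B for p in range(P)]
            for n in range(N)]
    # Step 4 (assumed reduction): negative membership-weighted mean correlation.
    total = sum(M[p][n] for p in range(P) for n in range(N))
    return -sum(M[p][n] * corr[n][p] for p in range(P) for n in range(N)) / total

# Two nodes whose activations move together, both members of one pathway.
A = [[[1.0], [2.0]], [[2.0], [4.0]], [[3.0], [6.0]], [[4.0], [8.0]]]
M = [[1.0, 1.0]]
loss = similarity_loss(A, M)
```

When the members of a pathway co-vary perfectly with the pathway score, as in this toy input, the term approaches its minimum of -1.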
- disable(model)
Disable activation caching and clear cached tensors on model.
Useful at evaluation time to avoid retaining graph references. Returns self for chaining.
- enable(model)
Enable activation caching on every ResBlock of model.
Idempotent. Called automatically at construction time. Returns self for chaining.
- loss(model)
Compute the auxiliary similarity / dissimilarity losses.
Must be called after a training-mode forward pass so that each ResBlock has populated its _last_activation.
- Parameters:
model (GSNN) – The same model passed to __init__().
- Returns:
(L_sim, L_dis), both already scaled by their respective lambda_*. L_dis is 0.0 if no dissim_pairs were provided.
- Return type:
- class gsnn.gsnn.models.ResBlock(*args: Any, **kwargs: Any)
Bases: Module
- forward(x, batch_params, node_err=None, fn_activity=None)
Implements the forward pass of the residual block.
- The forward pass consists of:
Edge batch normalization (if configured)
Input sparse linear transformation
Optional node error addition
Normalization and nonlinearity
Node masking (if configured)
Output sparse linear transformation
Dropout
Residual connection (if enabled)
- Parameters:
x (Tensor) – Edge features of shape [batch_size, num_edges] or [batch_size, num_edges, 1].
batch_params (tuple) – A tuple containing:
- batched_indices_in (Tensor): Batched indices for input sparse linear layer
- batched_indices_out (Tensor): Batched indices for output sparse linear layer
node_err (Tensor, optional) – Node-level error terms to be added after input transformation. Shape [batch_size, num_nodes]. (default: None)
fn_activity (Tensor, optional) – Function node activity of shape [batch_size, num_function_nodes]. Used for computing node activity. (default: None)
- Returns:
Transformed edge features of shape [batch_size, num_edges].
- Return type:
Tensor
Example
>>> # Using the block from the class example
>>> x = torch.randn(32, 2)  # [batch_size, num_edges]
>>> # Create batched indices (normally done by GSNN)
>>> batch_in = torch.tensor([[0, 0], [0, 1]])
>>> batch_out = torch.tensor([[0, 1], [0, 0]])
>>> batch_params = (batch_in, batch_out)
>>> # Forward pass
>>> out = block(x, batch_params)
>>> print(out.shape)  # [32, 2]
>>> # With node errors
>>> node_err = torch.randn(32, 1)  # [batch_size, num_nodes]
>>> out = block(x, batch_params, node_err=node_err)
>>> print(out.shape)  # [32, 2]
- set_node_mask(mask)
Set a mask to restrict which channels or nodes are active in the computation.
- Parameters:
mask (torch.Tensor) – A boolean mask indicating which positions remain active.
- class gsnn.gsnn.models.SignedMessagePassing(*args: Any, **kwargs: Any)
Bases: MessagePassing
Aggregate scalar signals over function-function edges using stored signs (edge weights).
- class gsnn.gsnn.models.SoftmaxGroupNorm(*args: Any, **kwargs: Any)
Bases: Module
Channel-wise softmax normalized within each channel group (stable softmax via per-group max shift).
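The per-group max shift can be illustrated for a single sample in plain Python. The function `group_softmax` is a hypothetical sketch of the technique, not the module's code.

```python
import math

def group_softmax(row, channel_groups):
    """Softmax over the channels of each group, shifted by the group's max for stability."""
    group_max = {}
    for v, g in zip(row, channel_groups):
        group_max[g] = max(v, group_max.get(g, v))
    # Subtracting the group max keeps every exponent <= 0, so exp() cannot overflow.
    exps = [math.exp(v - group_max[g]) for v, g in zip(row, channel_groups)]
    group_sum = {}
    for e, g in zip(exps, channel_groups):
        group_sum[g] = group_sum.get(g, 0.0) + e
    return [e / group_sum[g] for e, g in zip(exps, channel_groups)]

probs = group_softmax([1000.0, 1001.0, 0.0, 0.0], channel_groups=[0, 0, 1, 1])
```

Without the shift, `math.exp(1000.0)` would raise an overflow error; with it, the first group evaluates the same softmax as the inputs `[-1.0, 0.0]`.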
- class gsnn.gsnn.models.SparseLinear(*args: Any, **kwargs: Any)
Bases: Module
Fixed sparsity pattern linear layer; forward is batched message passing on the COO indices.
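A linear layer with a fixed sparsity pattern can be read as message passing over its stored (row, col) entries. The plain-Python sketch below assumes one weight per stored entry and no bias; the function `sparse_linear` and its argument layout are hypothetical.

```python
def sparse_linear(x, indices, weights, num_out):
    """Accumulate y[row] += w * x[col] for every stored (row, col) entry, per sample."""
    out = []
    for sample in x:
        y = [0.0] * num_out
        for (row, col), w in zip(indices, weights):
            y[row] += w * sample[col]   # message from input col to output row
        out.append(y)
    return out

indices = [(0, 0), (1, 0), (1, 2)]   # fixed sparsity pattern as (row, col) pairs
weights = [2.0, -1.0, 0.5]           # one weight per stored entry
y = sparse_linear([[1.0, 5.0, 4.0], [0.0, 1.0, 2.0]], indices, weights, num_out=2)
```

Input column 1 has no stored entry, so it never influences the output, whatever the weights are.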