gsnn.models.GSNN

class gsnn.models.GSNN(*args: Any, **kwargs: Any)

Bases: Module

__init__(edge_index_dict, node_names_dict, channels, layers, dropout=0.0, nonlin=torch.nn.ELU, bias=True, share_layers=True, add_function_self_edges=True, norm='layer', init='degree_normalized', verbose=False, edge_channels=1, checkpoint=False, residual=True, norm_first=True, node_attn=False, attn_mlp_hidden=16, node_mlp=False, node_mlp_hidden=16, node_activity=False, node_activity_hidden=16, node_activity_mode='per-node', node_activity_dim=1, node_activity_temperature=1.0, node_activity_dropout=0.0, edge_weight_dict=None)

Graph Structured Neural Network (GSNN) that constrains neural network architecture using a predefined graph structure. Unlike traditional GNNs that learn from graph structure, GSNN uses the graph to constrain which variables can directly influence each other. The model operates on edge features rather than node features and supports cyclic graphs.

The architecture uses three types of nodes:
  1. Input nodes: Represent observed variables

  2. Function nodes: Represent latent variables parameterized by neural networks

  3. Output nodes: Represent target variables

Only function nodes are trainable; input nodes pass their values on unchanged, and output nodes receive information without further transformation.
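Because the graph restricts which variables can influence one another, it can help to check which outputs each input can actually reach before choosing `layers`. The sketch below is a stand-alone illustration, not part of the gsnn API: `hop_distances` and the toy `EDGES` graph are hypothetical, and plain nested lists stand in for the edge-index tensors. It assumes information travels only along listed edges, so any node missing from the result is unreachable from that input.

```python
from collections import deque

# Toy graph in the same (src_type, 'to', dst_type) -> [[src ids], [dst ids]]
# layout as edge_index_dict, with plain lists standing in for tensors.
EDGES = {
    ("input", "to", "function"): [[0, 1], [0, 1]],
    ("function", "to", "function"): [[0], [1]],
    ("function", "to", "output"): [[1], [0]],
}


def hop_distances(edge_index_dict, start):
    """Breadth-first search: fewest edges from `start` to every reachable node."""
    adjacency = {}
    for (src_type, _, dst_type), (src, dst) in edge_index_dict.items():
        for s, d in zip(src, dst):
            adjacency.setdefault((src_type, s), []).append((dst_type, d))
    dist = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for nxt in adjacency.get(node, []):
            if nxt not in dist:  # first visit is the shortest path in BFS
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return dist


for i in range(2):
    d = hop_distances(EDGES, ("input", i))
    print(f"input {i} -> output 0: {d.get(('output', 0), 'unreachable')} hops")
```

Here input 0 reaches output 0 in 3 hops and input 1 in 2 hops, while function node 0 is unreachable from input 1. Longer paths need more layers; the exact mapping from hops to `layers` depends on the implementation, so treat this as a diagnostic only.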

Parameters:
  • edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Dictionary mapping edge types to edge indices. Expected keys are (‘input’, ‘to’, ‘function’), (‘function’, ‘to’, ‘function’), and (‘function’, ‘to’, ‘output’). Values should be tensors of shape [2, num_edges].

  • node_names_dict (Dict[str, List[str]]) – Dictionary mapping node types (‘input’, ‘function’, ‘output’) to their respective node names.

  • channels (int) – Number of hidden channels per function node.

  • layers (int) – Number of sequential sparse linear layers to propagate information across the graph.

  • dropout (float, optional) – Dropout probability. (default: 0.0)

  • nonlin (type[torch.nn.Module], optional) – Activation function class. (default: torch.nn.ELU)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • share_layers (bool, optional) – If set to True, reuse layer parameters across all layers. (default: True)

  • add_function_self_edges (bool, optional) – If set to True, add self-connections to function nodes. (default: True)

  • norm (str, optional) – Normalization type ('layer', 'batch', 'softmax', 'groupbatch', 'edgebatch', 'rms', 'ema', 'channelema' or 'none'). (default: 'layer')

  • init (str, optional) – Weight initialization strategy (e.g., 'degree_normalized', 'xavier' or 'kaiming'). (default: 'degree_normalized')

  • verbose (bool, optional) – If set to True, print debugging information. (default: False)

  • edge_channels (int, optional) – Number of latent edge feature channels to replicate. (default: 1)

  • checkpoint (bool, optional) – If set to True, use gradient checkpointing to reduce memory usage. (default: False)

  • residual (bool, optional) – If set to True, add residual connections. (default: True)

  • norm_first (bool, optional) – If set to True, apply normalization before nonlinearity. (default: True)

  • node_attn (bool, optional) – If set to True, apply node attention. (default: False)

  • attn_mlp_hidden (int, optional) – Hidden dimension of the attention MLP used when node_attn=True. (default: 16)

  • node_mlp (bool, optional) – If set to True, apply an additional MLP per node to increase representational capacity while preserving the graph structure constraints. (default: False)

  • node_mlp_hidden (int, optional) – Hidden dimension of the node MLP when node_mlp=True. (default: 16)

  • node_activity (bool, optional) – If set to True, enable per-function-node gating driven by external features (e.g., mutation/expression status). The gate is computed once per forward pass and broadcast to every channel of the corresponding node at every layer. (default: False)

  • node_activity_hidden (int, optional) – Hidden dimension of the node-activity MLP. (default: 16)

  • node_activity_mode (str, optional) – Mode of node activity computation ('per-node' or 'per-channel'). If 'per-node', a single activity value is computed for each function node. If 'per-channel', a separate activity value is computed for each channel of each function node. In both cases the node activity function is shared across all function nodes and layers. (default: 'per-node')

  • node_activity_dim (int, optional) – Number of external feature channels per function node expected as input to the node-activity MLP. When 1, the user may pass x_fn as [B, Nf] and it will be unsqueezed internally; otherwise x_fn must be [B, Nf, node_activity_dim]. (default: 1)

  • node_activity_temperature (float, optional) – Sigmoid temperature applied to the node-activity logits. Lower values produce sharper (closer to 0/1) gates, higher values produce softer (closer to 0.5) gates. (default: 1.0)

  • node_activity_dropout (float, optional) – Dropout probability applied to the node-activity MLP. (default: 0.0)

  • edge_weight_dict (Dict[Tuple[str, str, str], Tensor], optional) – Dictionary mapping edge types to edge weights. Expected keys are (‘input’, ‘to’, ‘function’), (‘function’, ‘to’, ‘function’), and (‘function’, ‘to’, ‘output’). Values should be tensors of shape [num_edges]. (default: None)
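The node-activity parameters above combine roughly as follows. This is a minimal sketch assuming the gate in 'per-node' mode is sigmoid(logit / node_activity_temperature), broadcast to every channel of its node; a single shared linear scorer (`weight`, `bias`) stands in for the node-activity MLP, and `activity_gates` and `apply_gates` are hypothetical helpers, not gsnn functions.

```python
import math


def activity_gates(x_fn, weight, bias, temperature=1.0):
    """Per-node gates g[b][n] = sigmoid((w . x_fn[b][n] + bias) / temperature).

    x_fn may be [B][Nf] (node_activity_dim == 1) or [B][Nf][D]. A shared
    linear scorer stands in for the node-activity MLP (assumption).
    """
    gates = []
    for sample in x_fn:
        row = []
        for feats in sample:
            # Unsqueeze the [B, Nf] case to [B, Nf, 1].
            feats = feats if isinstance(feats, list) else [feats]
            logit = sum(w * f for w, f in zip(weight, feats)) + bias
            row.append(1.0 / (1.0 + math.exp(-logit / temperature)))
        gates.append(row)
    return gates


def apply_gates(h, gates):
    """Scale every channel of node n by its gate: h[b][n][c] * g[b][n]."""
    return [[[g * v for v in node] for node, g in zip(hb, gb)]
            for hb, gb in zip(h, gates)]


x_fn = [[1.0, 0.0]]  # one sample, two function nodes, one feature each
for t in (0.25, 1.0, 4.0):
    print(t, [round(g, 3) for g in activity_gates(x_fn, [2.0], 0.0, t)[0]])
```

Running it shows node 0's gate moving toward 1 as the temperature drops, while node 1 (logit 0) stays at 0.5 at every temperature. In 'per-channel' mode the scorer would instead emit one logit per channel rather than a single broadcast value.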

Example

>>> import torch
>>> from gsnn.models import GSNN
>>> # Define a simple graph with 2 input nodes, 1 function node, and 1 output node
>>> edge_index_dict = {
...     ('input', 'to', 'function'): torch.tensor([[0, 1], [0, 0]]),  # 2 input edges
...     ('function', 'to', 'function'): torch.tensor([[0], [0]]),     # 1 self edge
...     ('function', 'to', 'output'): torch.tensor([[0], [0]])        # 1 output edge
... }
>>> node_names_dict = {
...     'input': ['in1', 'in2'],
...     'function': ['func1'],
...     'output': ['out1']
... }
>>> model = GSNN(
...     edge_index_dict=edge_index_dict,
...     node_names_dict=node_names_dict,
...     channels=16,
...     layers=3
... )
>>> x = torch.randn(32, 2)  # batch_size=32, num_input_nodes=2
>>> out = model(x)
>>> print(out.shape)  # [32, 1] (batch_size, num_output_nodes)
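Index mistakes in edge_index_dict are easy to make by hand, so a quick sanity check before constructing the model can save debugging time. The helper below is hypothetical (not part of gsnn) and works on nested lists; for tensors, pass `{k: v.tolist() for k, v in edge_index_dict.items()}`. It counts only the edges you supply.

```python
def check_graph(edge_index_dict, node_names_dict):
    """Validate edge indices against node names; return the number of edges.

    Each edge index is a [[src ids], [dst ids]] nested list.
    """
    num_edges = 0
    for key, edge_index in edge_index_dict.items():
        src_type, _, dst_type = key
        if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
            raise ValueError(f"{key}: expected shape [2, num_edges]")
        # Row 0 indexes source nodes, row 1 indexes target nodes.
        for ids, node_type in zip(edge_index, (src_type, dst_type)):
            n = len(node_names_dict[node_type])
            bad = [i for i in ids if not 0 <= i < n]
            if bad:
                raise ValueError(f"{key}: {node_type} ids out of range [0, {n}): {bad}")
        num_edges += len(edge_index[0])
    return num_edges


edge_index_dict = {
    ("input", "to", "function"): [[0, 1], [0, 0]],
    ("function", "to", "function"): [[0], [0]],
    ("function", "to", "output"): [[0], [0]],
}
node_names_dict = {"input": ["in1", "in2"], "function": ["func1"], "output": ["out1"]}
print(check_graph(edge_index_dict, node_names_dict))  # 4
```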

Methods

__init__(edge_index_dict, node_names_dict, ...)

Graph Structured Neural Network (GSNN) that constrains neural network architecture using a predefined graph structure.

forward(x[, node_mask, edge_mask, ...])

Implements the forward pass of the GSNN model.

get_batch_params(B, device)

Retrieves or computes batch-specific indexing parameters for sparse linear layers.

get_node_activations(x[, agg, inference])

get_node_attention(x)

Return per-layer node-level attention weights.

prune([threshold, verbose])

Prunes the model by removing channels with small weights.

forward(x, node_mask=None, edge_mask=None, ret_edge_out=False, e0=None, node_errs=None, x_fn=None)

Implements the forward pass of the GSNN model.

The model first converts node features to edge features, then applies a sequence of sparse linear transformations constrained by the graph structure. Each layer consists of:

  1. Input transformation (W_in)

  2. Normalization (optional)

  3. Nonlinearity

  4. Output transformation (W_out)

  5. Residual connection (optional)
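The five steps above can be sketched for a single sample as follows. This is an illustration under simplifying assumptions: dense `W_in`/`W_out` matrices stand in for GSNN's graph-masked sparse weights, layer normalization and ELU correspond to the defaults norm='layer' and nonlin=torch.nn.ELU, and `gsnn_layer` is a hypothetical helper rather than a gsnn function. Setting norm_first=False swaps steps 2 and 3, matching the `norm_first` parameter.

```python
import math


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def layer_norm(v, eps=1e-5):
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def elu(v, alpha=1.0):
    return [x if x > 0 else alpha * (math.exp(x) - 1.0) for x in v]


def gsnn_layer(e, W_in, W_out, norm_first=True, residual=True):
    """One layer: W_in -> norm -> nonlinearity -> W_out -> residual.

    e: edge features (length E); W_in: C x E; W_out: E x C.
    Dense matrices stand in for the sparse, graph-constrained weights.
    """
    h = matvec(W_in, e)                        # 1. input transformation
    if norm_first:
        h = elu(layer_norm(h))                 # 2. norm, then 3. nonlinearity
    else:
        h = layer_norm(elu(h))                 # nonlinearity first
    out = matvec(W_out, h)                     # 4. output transformation
    return [a + b for a, b in zip(e, out)] if residual else out  # 5. residual


identity = [[1.0, 0.0], [0.0, 1.0]]
print(gsnn_layer([1.0, -1.0], identity, identity, residual=False))
```

With a zero `W_out` and residual=True, the layer reduces to the identity map on the edge features.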

Parameters:
  • x (Tensor) – Input node features of shape [batch_size, num_input_nodes].

  • node_mask (Tensor, optional) – Boolean mask of shape [batch_size, num_nodes], where num_nodes counts all input, function and output nodes. If provided, masks out specific function nodes during computation. (default: None)

  • edge_mask (Tensor, optional) – Boolean mask for edges of shape [batch_size, num_edges]. If provided, masks out specific edges during computation. (default: None)

  • ret_edge_out (bool, optional) – If set to True, return edge-level features instead of node-level features. (default: False)

  • e0 (Tensor, optional) – Initial edge features of shape [batch_size, num_edges]. Used for inferring input errors. (default: None)

  • node_errs (List[Tensor], optional) – List of node errors per layer, each of shape [batch_size, num_nodes]. Length must match number of layers. (default: None)

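  • x_fn (Tensor, optional) – Function node features of shape [batch_size, num_function_nodes], or [batch_size, num_function_nodes, node_activity_dim] when node_activity_dim > 1. Used for computing node activity when node_activity=True. (default: None)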

Returns:

If ret_edge_out=False, returns node-level output features of shape [batch_size, num_output_nodes]. Otherwise, returns edge-level features of shape [batch_size, num_edges].

Return type:

Tensor

Example

>>> # Using the model from the class example
>>> x = torch.randn(32, 2)  # batch_size=32, num_input_nodes=2
>>> # Basic forward pass
>>> out = model(x)
>>> print(out.shape)  # [32, 1]
>>> # Get edge-level features
>>> edge_out = model(x, ret_edge_out=True)
>>> print(edge_out.shape)  # [32, 4] (batch_size, num_edges)
>>> # Using masks
>>> node_mask = torch.ones(32, 4, dtype=torch.bool)  # [batch_size, num_nodes]
>>> edge_mask = torch.ones(32, 4, dtype=torch.bool)  # [batch_size, num_edges]
>>> out = model(x, node_mask=node_mask, edge_mask=edge_mask)
>>> print(out.shape)  # [32, 1]

get_batch_params(B, device)

Retrieves or computes batch-specific indexing parameters for sparse linear layers.

This method caches the batch parameters to avoid recomputing them for the same batch size. The parameters are used to efficiently perform batched sparse matrix operations.
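One common way to batch a sparse layer, assumed here for illustration, is to tile the per-sample index pattern B times along a block diagonal, offsetting each copy by the matrix size. Since the result depends only on the batch size (and, in GSNN, the device), it can be cached. `batched_indices` and the toy `ROWS`/`COLS` pattern are hypothetical; this is not GSNN's internal code.

```python
from functools import lru_cache

# Toy per-sample sparsity pattern: (row, col) of each nonzero weight
# in an N_ROWS x N_COLS matrix.
ROWS = (0, 1, 1)
COLS = (0, 0, 1)
N_ROWS, N_COLS = 2, 2


@lru_cache(maxsize=None)
def batched_indices(B):
    """Tile the indices B times along a block diagonal (cached per batch size)."""
    rows = tuple(r + b * N_ROWS for b in range(B) for r in ROWS)
    cols = tuple(c + b * N_COLS for b in range(B) for c in COLS)
    return rows, cols


print(batched_indices(2))  # ((0, 1, 1, 2, 3, 3), (0, 0, 1, 2, 2, 3))
```

A second call with B=2 returns the cached tuple, while B=3 computes a new entry, mirroring the caching behaviour described above.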

Parameters:
  • B (int) – Batch size.

  • device (torch.device) – Device on which to place the computed parameters.

Returns:

A tuple containing:
  • batched_indices_in (Tensor): Batched indices for input sparse linear layer

  • batched_indices_out (Tensor): Batched indices for output sparse linear layer

Return type:

tuple

Example

>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> # Get batch parameters for batch size 32
>>> batch_params = model.get_batch_params(32, device)
>>> # Parameters are cached for subsequent calls
>>> same_params = model.get_batch_params(32, device)
>>> # A different batch size triggers recomputation
>>> new_params = model.get_batch_params(64, device)

get_node_activations(x, agg='sum', inference=True)

get_node_attention(x)

Return per-layer node-level attention weights.

Parameters:

x (Tensor (B, num_input_nodes)) – Input features; typically supply a single sample (B=1).

Returns:

Mapping from node name to a tensor of shape (L, B) with attention weights per layer (L) and batch element (B).

Return type:

Dict[str, Tensor]
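Because the result maps each node name to an (L, B) tensor, a natural follow-up is to rank nodes by attention at one layer. The sketch below assumes the tensors have been converted with `.tolist()`; `rank_nodes` and the toy `attn` values are hypothetical and only illustrate the return layout.

```python
def rank_nodes(attn, layer=-1):
    """Rank node names by mean attention (over the batch) at one layer.

    attn maps node name -> L x B nested list, i.e. an (L, B) tensor after .tolist().
    """
    scores = {name: sum(w[layer]) / len(w[layer]) for name, w in attn.items()}
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


# Toy values: two layers (L=2), one sample (B=1).
attn = {"func1": [[0.2], [0.9]], "func2": [[0.7], [0.4]]}
print(rank_nodes(attn))           # last layer: func1 ranked first
print(rank_nodes(attn, layer=0))  # first layer: func2 ranked first
```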

prune(threshold=0.01, verbose=False)

Prunes the model by removing channels with small weights.

This method removes channels whose maximum absolute weight across all layers is below the specified threshold, which can substantially reduce model size with little effect on performance. If you prune during training, reinitialize the optimizer afterwards, since the set of model parameters changes.
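The pruning rule can be illustrated in isolation. This sketch assumes the weights are grouped per layer and per channel; `channels_to_keep` and the toy `layer_weights` are hypothetical. Actually removing the channels from the model's parameter tensors (the reason the optimizer needs reinitializing) is not shown.

```python
def channels_to_keep(layer_weights, threshold=1e-2):
    """Keep a channel if its max |w| over all layers is >= threshold.

    layer_weights[l][c] lists the weights attached to channel c in layer l
    (a hypothetical flattened layout). Returns (kept channel ids, weights removed).
    """
    n_channels = len(layer_weights[0])
    keep, removed = [], 0
    for c in range(n_channels):
        weights = [w for layer in layer_weights for w in layer[c]]
        if max(abs(w) for w in weights) >= threshold:
            keep.append(c)
        else:
            removed += len(weights)
    return keep, removed


# Toy weights: two layers, three channels.
layer_weights = [
    [[0.5, -0.2], [0.001, -0.003], [0.0, 0.004]],
    [[0.1], [0.002], [0.05]],
]
print(channels_to_keep(layer_weights))  # ([0, 2], 3)
```

Channel 2 survives even though its first-layer weights are tiny, because the criterion uses the maximum over all layers.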

Parameters:
  • threshold (float, optional) – The threshold below which weights are considered insignificant. (default: 1e-2)

  • verbose (bool, optional) – If set to True, print pruning statistics. (default: False)

Returns:

Number of parameters removed by pruning.

Return type:

int

Example

>>> # Create a model with 16 channels per function node
>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> # Train the model...
>>> # Prune channels with small weights
>>> removed_params = model.prune(threshold=1e-2, verbose=True)
>>> print(f'Removed {removed_params} parameters')