gsnn.models.GSNN
- class gsnn.models.GSNN(*args: Any, **kwargs: Any)
Bases: Module
- __init__(edge_index_dict, node_names_dict, channels, layers, dropout=0.0, nonlin=torch.nn.ELU, bias=True, share_layers=True, add_function_self_edges=True, norm='layer', init='degree_normalized', verbose=False, edge_channels=1, checkpoint=False, residual=True, norm_first=True, node_attn=False, attn_mlp_hidden=16, node_mlp=False, node_mlp_hidden=16, node_activity=False, node_activity_hidden=16, node_activity_mode='per-node', node_activity_dim=1, node_activity_temperature=1.0, node_activity_dropout=0.0, edge_weight_dict=None)
Graph Structured Neural Network (GSNN) that constrains neural network architecture using a predefined graph structure. Unlike traditional GNNs that learn from graph structure, GSNN uses the graph to constrain which variables can directly influence each other. The model operates on edge features rather than node features and supports cyclic graphs.
- The architecture uses three types of nodes:
Input nodes: Represent observed variables
Function nodes: Represent latent variables parameterized by neural networks
Output nodes: Represent target variables
Only function nodes are trainable; input and output nodes pass/receive information unchanged.
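The pass-through roles of input and output nodes can be illustrated with a parameter-free sketch. It assumes that each input-to-function edge simply carries its source input's value and that an output node sums its incoming function-to-output edges. `inputs_to_edges` and `edges_to_outputs` are hypothetical helpers, not part of gsnn, and the aggregation gsnn actually uses may differ.

```python
from typing import List


def inputs_to_edges(x: List[float], src: List[int]) -> List[float]:
    """Copy each input node's value onto its outgoing edges (no parameters)."""
    # src[k] is the index of the input node at the tail of edge k
    return [x[s] for s in src]


def edges_to_outputs(e: List[float], dst: List[int], num_outputs: int) -> List[float]:
    """Sum edge values into their destination output nodes (assumed aggregation)."""
    out = [0.0] * num_outputs
    for value, d in zip(e, dst):
        out[d] += value  # output nodes only accumulate; nothing is learned here
    return out


# Edges in1 -> func1 and in2 -> func1 carry the raw input values
print(inputs_to_edges([0.5, -1.0], src=[0, 1]))  # [0.5, -1.0]
# Two function -> output edges that both end at output node 0
print(edges_to_outputs([0.25, 0.75], dst=[0, 0], num_outputs=1))  # [1.0]
```

All learned computation therefore happens inside function nodes, between these two parameter-free steps.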
- Parameters:
edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Dictionary mapping edge types to edge indices. Expected keys are ('input', 'to', 'function'), ('function', 'to', 'function'), and ('function', 'to', 'output'). Values should be tensors of shape [2, num_edges].
node_names_dict (Dict[str, List[str]]) – Dictionary mapping node types ('input', 'function', 'output') to their respective node names.
channels (int) – Number of hidden channels per function node.
layers (int) – Number of sequential sparse linear layers used to propagate information across the graph.
dropout (float, optional) – Dropout probability. (default: 0.0)
nonlin (torch.nn.Module, optional) – Activation function. (default: torch.nn.ELU)
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
share_layers (bool, optional) – If set to True, reuse the same layer parameters across all layers. (default: True)
add_function_self_edges (bool, optional) – If set to True, add self-connections to function nodes. (default: True)
norm (str, optional) – Normalization type ('layer', 'batch', 'softmax', 'groupbatch', 'edgebatch', 'rms', 'ema', 'channelema' or 'none'). (default: 'layer')
init (str, optional) – Weight initialization strategy (e.g. 'xavier', 'kaiming', or 'degree_normalized'). (default: 'degree_normalized')
verbose (bool, optional) – If set to True, print debugging information. (default: False)
edge_channels (int, optional) – Number of latent edge feature channels to replicate. (default: 1)
checkpoint (bool, optional) – If set to True, use gradient checkpointing to reduce memory usage. (default: False)
residual (bool, optional) – If set to True, add residual connections. (default: True)
norm_first (bool, optional) – If set to True, apply normalization before the nonlinearity. (default: True)
node_attn (bool, optional) – If set to True, apply node attention. (default: False)
attn_mlp_hidden (int, optional) – Hidden dimension of the attention MLP. (default: 16)
node_mlp (bool, optional) – If set to True, apply an additional MLP per node to increase representational capacity while preserving the graph-structure constraints. (default: False)
node_mlp_hidden (int, optional) – Hidden dimension of the node MLP when enabled. (default: 16)
node_activity (bool, optional) – If set to True, enable per-function-node gating driven by external features (e.g., mutation/expression status). The gate is computed once per forward pass and, in 'per-node' mode, broadcast to every channel of the corresponding node at every layer. (default: False)
node_activity_hidden (int, optional) – Hidden dimension of the node-activity MLP. (default: 16)
node_activity_mode (str, optional) – Mode of node-activity computation ('per-node' or 'per-channel'). If 'per-node', one activity value is computed for each function node; if 'per-channel', one is computed for each channel of each function node. In both cases the node-activity function is shared across all function nodes and layers. (default: 'per-node')
node_activity_dim (int, optional) – Number of external feature channels per function node expected as input to the node-activity MLP. When 1, x_fn may be passed as [B, Nf] and is unsqueezed internally; otherwise x_fn must have shape [B, Nf, node_activity_dim]. (default: 1)
node_activity_temperature (float, optional) – Sigmoid temperature applied to the node-activity logits. Lower values produce sharper gates (closer to 0 or 1); higher values produce softer gates (closer to 0.5). (default: 1.0)
node_activity_dropout (float, optional) – Dropout probability applied in the node-activity MLP. (default: 0.0)
edge_weight_dict (Dict[Tuple[str, str, str], Tensor], optional) – Dictionary mapping edge types to edge weights. Expected keys are the same as for edge_index_dict: ('input', 'to', 'function'), ('function', 'to', 'function'), and ('function', 'to', 'output'). Values should be tensors of shape [num_edges]. (default: None)
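The gating controlled by node_activity_temperature and node_activity_mode can be illustrated with a small sketch. It assumes the gate is sigmoid(logit / temperature) and that it multiplies the function node's channels; the logits would come from the node-activity MLP, which is not shown. `sigmoid_gate` and `apply_gate` are hypothetical helpers, not part of gsnn.

```python
import math
from typing import List


def sigmoid_gate(logit: float, temperature: float = 1.0) -> float:
    """Temperature-scaled sigmoid, sigma(logit / temperature) (assumed form)."""
    z = logit / temperature
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)  # z < 0, so this cannot overflow
    return ez / (1.0 + ez)


def apply_gate(h: List[float], gates: List[float]) -> List[float]:
    """Scale one function node's channels by its gate(s).

    'per-node' mode: a single gate is broadcast to every channel.
    'per-channel' mode: one gate per channel.
    """
    if len(gates) == 1:
        gates = gates * len(h)  # broadcast the single per-node gate
    if len(gates) != len(h):
        raise ValueError("expected one gate, or one gate per channel")
    return [v * g for v, g in zip(h, gates)]


# Lower temperature pushes the same logit further toward 0 or 1
for t in (0.25, 1.0, 4.0):
    print(t, round(sigmoid_gate(2.0, t), 3))

h = [1.0, 2.0]                    # two channels of one function node
print(apply_gate(h, [0.5]))       # per-node gate: [0.5, 1.0]
print(apply_gate(h, [0.0, 1.0]))  # per-channel gates: [0.0, 2.0]
```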
Example
>>> import torch
>>> from gsnn.models import GSNN
>>> # Define a simple graph with 2 input nodes, 1 function node, and 1 output node
>>> edge_index_dict = {
...     ('input', 'to', 'function'): torch.tensor([[0, 1], [0, 0]]),  # 2 input edges
...     ('function', 'to', 'function'): torch.tensor([[0], [0]]),  # 1 self edge
...     ('function', 'to', 'output'): torch.tensor([[0], [0]]),  # 1 output edge
... }
>>> node_names_dict = {
...     'input': ['in1', 'in2'],
...     'function': ['func1'],
...     'output': ['out1'],
... }
>>> model = GSNN(
...     edge_index_dict=edge_index_dict,
...     node_names_dict=node_names_dict,
...     channels=16,
...     layers=3,
... )
>>> x = torch.randn(32, 2)  # batch_size=32, num_input_nodes=2
>>> out = model(x)
>>> print(out.shape)  # [32, 1] (batch_size, num_output_nodes)
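Writing index tensors by hand becomes error-prone for larger graphs. A minimal sketch, assuming the graph is available as a list of named (source, target) pairs, converts it into the edge_index_dict layout used above. `build_edge_index_dict` is a hypothetical helper that returns nested lists rather than tensors.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def build_edge_index_dict(
    edges: List[Tuple[str, str]],
    node_names_dict: Dict[str, List[str]],
) -> Dict[EdgeType, List[List[int]]]:
    """Group named (src, dst) edges by type and convert names to positional indices."""
    # Map every node name to (node_type, position within that type)
    lookup = {}
    for node_type, names in node_names_dict.items():
        for i, name in enumerate(names):
            if name in lookup:
                raise ValueError(f"duplicate node name: {name}")
            lookup[name] = (node_type, i)

    allowed = {("input", "function"), ("function", "function"), ("function", "output")}
    out: Dict[EdgeType, List[List[int]]] = {}
    for src, dst in edges:
        src_type, src_idx = lookup[src]
        dst_type, dst_idx = lookup[dst]
        if (src_type, dst_type) not in allowed:
            raise ValueError(f"unsupported edge {src} -> {dst} ({src_type} -> {dst_type})")
        rows = out.setdefault((src_type, "to", dst_type), [[], []])
        rows[0].append(src_idx)  # row 0: source indices
        rows[1].append(dst_idx)  # row 1: destination indices
    return out


node_names_dict = {"input": ["in1", "in2"], "function": ["func1"], "output": ["out1"]}
edges = [("in1", "func1"), ("in2", "func1"), ("func1", "func1"), ("func1", "out1")]
print(build_edge_index_dict(edges, node_names_dict))
```

For this input the result matches the hand-written dictionary in the example above. Each value can then be wrapped with torch.tensor(...) to obtain a [2, num_edges] index tensor.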
Methods
__init__(edge_index_dict, node_names_dict, ...) – Graph Structured Neural Network (GSNN) that constrains neural network architecture using a predefined graph structure.
forward(x[, node_mask, edge_mask, ...]) – Implements the forward pass of the GSNN model.
get_batch_params(B, device) – Retrieves or computes batch-specific indexing parameters for sparse linear layers.
get_node_activations(x[, agg, inference])
get_node_attention(x) – Return per-layer node-level attention weights.
prune([threshold, verbose]) – Prunes the model by removing channels with small weights.
- forward(x, node_mask=None, edge_mask=None, ret_edge_out=False, e0=None, node_errs=None, x_fn=None)
Implements the forward pass of the GSNN model.
The model first converts node features to edge features, then applies a sequence of sparse linear transformations constrained by the graph structure. Each layer consists of:
Input transformation (W_in)
Normalization (optional)
Nonlinearity
Output transformation (W_out)
Residual connection (optional)
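The five steps above can be sketched for a single sample with plain Python lists. This is a simplified illustration, not the gsnn implementation: it assumes one scalar per edge (edge_channels=1), layer normalization, ELU, no bias or dropout, and explicit loops in place of the batched sparse linear layers. `gsnn_layer_sketch` and its arguments are hypothetical.

```python
import math
from typing import Dict, List


def layer_norm(h: List[float], eps: float = 1e-5) -> List[float]:
    """Normalize one node's channels to zero mean and unit variance."""
    mean = sum(h) / len(h)
    var = sum((v - mean) ** 2 for v in h) / len(h)
    return [(v - mean) / math.sqrt(var + eps) for v in h]


def elu(v: float) -> float:
    """ELU with alpha=1, matching the default torch.nn.ELU."""
    return v if v > 0 else math.expm1(v)


def gsnn_layer_sketch(
    e: List[float],                       # one scalar feature per edge
    in_edges: Dict[int, List[int]],       # function node -> incoming edge ids
    out_edges: Dict[int, List[int]],      # function node -> outgoing edge ids
    w_in: Dict[int, List[List[float]]],   # node -> [channels][len(in_edges[node])]
    w_out: Dict[int, List[List[float]]],  # node -> [len(out_edges[node])][channels]
) -> List[float]:
    delta = [0.0] * len(e)
    for f, ins in in_edges.items():
        x_f = [e[k] for k in ins]
        # 1. W_in: mix only this node's incoming edges into its hidden channels
        h = [sum(w * x for w, x in zip(row, x_f)) for row in w_in[f]]
        # 2-3. Normalization, then nonlinearity (the norm_first=True ordering)
        h = [elu(v) for v in layer_norm(h)]
        # 4. W_out: write hidden channels back onto this node's outgoing edges
        for k, row in zip(out_edges[f], w_out[f]):
            delta[k] += sum(w * v for w, v in zip(row, h))
    # 5. Residual connection; input edges get no delta, so they keep their values
    return [old + d for old, d in zip(e, delta)]


# Graph from the class example: edge 0 = in1->func1, 1 = in2->func1,
# 2 = func1->func1 (self edge), 3 = func1->out1; func1 has 2 channels
in_edges = {0: [0, 1, 2]}
out_edges = {0: [2, 3]}
w_in = {0: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]}
w_out = {0: [[1.0, 0.0], [0.0, 1.0]]}

e = [1.0, -1.0, 0.0, 0.0]
for _ in range(3):  # three layers reusing the same weights, as with share_layers=True
    e = gsnn_layer_sketch(e, in_edges, out_edges, w_in, w_out)
print([round(v, 3) for v in e])
```

Because each node's w_in only reads its own incoming edges and w_out only writes its own outgoing edges, the graph limits which variables can directly influence each other, which is the constraint described for the class.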
- Parameters:
x (Tensor) – Input node features of shape [batch_size, num_input_nodes].
node_mask (Tensor, optional) – Boolean mask for function nodes of shape [batch_size, num_nodes]. If provided, masks out specific function nodes during computation. (default: None)
edge_mask (Tensor, optional) – Boolean mask for edges of shape [batch_size, num_edges]. If provided, masks out specific edges during computation. (default: None)
ret_edge_out (bool, optional) – If set to True, return edge-level features instead of node-level features. (default: False)
e0 (Tensor, optional) – Initial edge features of shape [batch_size, num_edges]. Used for inferring input errors. (default: None)
node_errs (List[Tensor], optional) – List of per-layer node errors, each of shape [batch_size, num_nodes]. Its length must match the number of layers. (default: None)
x_fn (Tensor, optional) – Function node features of shape [batch_size, num_function_nodes], or [batch_size, num_function_nodes, node_activity_dim] when node_activity_dim > 1. Used for computing node activity. (default: None)
- Returns:
If ret_edge_out=False, returns node-level output features of shape [batch_size, num_output_nodes]. Otherwise, returns edge-level features of shape [batch_size, num_edges].
- Return type:
Tensor
Example
>>> # Using the model from the class example
>>> x = torch.randn(32, 2)  # batch_size=32, num_input_nodes=2
>>> # Basic forward pass
>>> out = model(x)
>>> print(out.shape)  # [32, 1]
>>> # Get edge-level features
>>> edge_out = model(x, ret_edge_out=True)
>>> print(edge_out.shape)  # [32, 4] (batch_size, num_edges)
>>> # Using masks
>>> node_mask = torch.ones(32, 4, dtype=torch.bool)  # [batch_size, num_nodes]
>>> edge_mask = torch.ones(32, 4, dtype=torch.bool)  # [batch_size, num_edges]
>>> out = model(x, node_mask=node_mask, edge_mask=edge_mask)
>>> print(out.shape)  # [32, 1]
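A common use of node_mask is an in-silico knockout of particular nodes. The sketch below builds such a mask by name. It assumes that True means "keep" and False means "mask out", and that the mask's columns follow the order of the names list you pass (here inputs, then functions, then outputs, to match the 4-node example); both are assumptions, and `knockout_mask` is a hypothetical helper.

```python
from typing import List, Sequence


def knockout_mask(
    names: Sequence[str], knocked_out: Sequence[str], batch_size: int
) -> List[List[bool]]:
    """Return a [batch_size][len(names)] mask that is False at knocked-out nodes."""
    targets = set(knocked_out)
    unknown = targets - set(names)
    if unknown:
        raise KeyError(f"unknown node names: {sorted(unknown)}")
    row = [name not in targets for name in names]
    # Copy the row per sample so the rows can be edited independently
    return [list(row) for _ in range(batch_size)]


names = ["in1", "in2", "func1", "out1"]
mask = knockout_mask(names, ["func1"], batch_size=2)
print(mask)  # [[True, True, False, True], [True, True, False, True]]
```

torch.tensor(mask) turns the nested list into a boolean tensor of shape [batch_size, num_nodes], which could then be passed as node_mask.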
- get_batch_params(B, device)
Retrieves or computes batch-specific indexing parameters for sparse linear layers.
This method caches the batch parameters to avoid recomputing them for the same batch size. The parameters are used to efficiently perform batched sparse matrix operations.
- Parameters:
B (int) – Batch size.
device (torch.device) – Device on which to place the computed parameters.
- Returns:
- A tuple containing:
batched_indices_in (Tensor): Batched indices for input sparse linear layer
batched_indices_out (Tensor): Batched indices for output sparse linear layer
- Return type:
Tuple[Tensor, Tensor]
Example
>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> # Get batch parameters for batch size 32
>>> batch_params = model.get_batch_params(32, device)
>>> # Parameters are cached for subsequent calls
>>> same_params = model.get_batch_params(32, device)
>>> # A different batch size triggers recomputation
>>> new_params = model.get_batch_params(64, device)
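Per-batch-size caching of this kind is commonly done by offsetting each sample's indices so that a whole batch becomes one block-diagonal sparse problem. The sketch below shows that technique under the assumption that get_batch_params does something similar; `BatchIndexCache` is a hypothetical name, and it works on plain lists rather than tensors placed on a device.

```python
from typing import Dict, List, Tuple


class BatchIndexCache:
    """Memoize block-diagonal batched edge indices per batch size (illustrative only)."""

    def __init__(self, src: List[int], dst: List[int], num_src: int, num_dst: int):
        self.src, self.dst = src, dst
        self.num_src, self.num_dst = num_src, num_dst
        self._cache: Dict[int, Tuple[List[int], List[int]]] = {}

    def get(self, batch_size: int) -> Tuple[List[int], List[int]]:
        if batch_size not in self._cache:
            # Shift sample b's indices by b * (number of nodes) so the B copies
            # of the graph never overlap, i.e. one large block-diagonal matrix
            src_b = [s + b * self.num_src for b in range(batch_size) for s in self.src]
            dst_b = [d + b * self.num_dst for b in range(batch_size) for d in self.dst]
            self._cache[batch_size] = (src_b, dst_b)
        return self._cache[batch_size]


# Edges in1->func1 and in2->func1 (2 source nodes, 1 destination node)
cache = BatchIndexCache(src=[0, 1], dst=[0, 0], num_src=2, num_dst=1)
print(cache.get(2))  # ([0, 1, 2, 3], [0, 0, 1, 1])
print(cache.get(2) is cache.get(2))  # True: the second call hits the cache
```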
- get_node_attention(x)
Return per-layer node-level attention weights.
- Parameters:
x (Tensor) – Input features of shape (B, num_input_nodes); typically a single sample (B=1).
- Returns:
Mapping from node name to a tensor of shape (L, B) with attention weights per layer (L) and batch element (B).
- Return type:
Dict[str, Tensor]
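One way to summarize this output is to average each node's weights over layers and batch elements and rank the nodes. The sketch assumes the tensors have first been converted to nested lists, for example with {name: t.tolist() for name, t in model.get_node_attention(x).items()}; `rank_nodes_by_attention` is a hypothetical helper.

```python
from typing import Dict, List, Tuple


def rank_nodes_by_attention(
    attn: Dict[str, List[List[float]]], k: int = 5
) -> List[Tuple[str, float]]:
    """Average each node's (L, B) attention weights and return the k highest."""
    means = {}
    for name, per_layer in attn.items():
        values = [w for layer in per_layer for w in layer]  # flatten L x B
        means[name] = sum(values) / len(values)
    return sorted(means.items(), key=lambda item: item[1], reverse=True)[:k]


# Toy attention for two function nodes, L=2 layers, B=1 sample
attn = {"func1": [[0.2], [0.4]], "func2": [[0.9], [0.7]]}
for name, score in rank_nodes_by_attention(attn, k=2):
    print(name, round(score, 2))
```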
- prune(threshold=0.01, verbose=False)
Prunes the model by removing channels with small weights.
This method removes channels whose maximum absolute weight across all layers is below the specified threshold. This can substantially reduce model size while largely preserving performance. If you prune during training, reinitialize the optimizer afterward, because pruning changes the model's parameters.
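- Parameters:
threshold (float, optional) – Channels whose maximum absolute weight across all layers is below this value are removed. (default: 0.01)
verbose (bool, optional) – If set to True, print pruning details. (default: False)
- Returns:
Number of parameters removed by pruning.
- Return type:
int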
Example
>>> # Create a model with 16 channels per function node
>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> # Train the model...
>>> # Prune channels with small weights
>>> removed_params = model.prune(threshold=1e-2, verbose=True)
>>> print(f'Removed {removed_params} parameters')
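The selection rule described for prune can be sketched as follows. The weight layout (weights[l][c] holding every weight attached to channel c in layer l) and the helper `prune_plan` are assumptions for illustration; the actual method operates on the model's sparse parameter tensors and may count parameters differently.

```python
from typing import List, Tuple


def prune_plan(
    weights: List[List[List[float]]], threshold: float = 0.01
) -> Tuple[List[int], int]:
    """Pick channels to drop and count the weights that would be removed.

    A channel is dropped when its maximum absolute weight over all layers
    is below `threshold`.
    """
    num_channels = len(weights[0])
    max_abs = [0.0] * num_channels
    for layer in weights:
        for c, ws in enumerate(layer):
            max_abs[c] = max(max_abs[c], max((abs(w) for w in ws), default=0.0))
    pruned = [c for c in range(num_channels) if max_abs[c] < threshold]
    removed = sum(len(layer[c]) for layer in weights for c in pruned)
    return pruned, removed


# Two layers, three channels; only channel 1 is small in every layer
weights = [
    [[0.5, -0.2], [0.001, -0.004], [0.0, 0.03]],
    [[0.1], [0.002], [0.005]],
]
print(prune_plan(weights, threshold=1e-2))  # ([1], 3)
```

Channel 2 survives because its largest weight (0.03) exceeds the threshold, even though one of its weights is zero. After pruning a real model the parameter tensors change shape, so an optimizer created earlier should be rebuilt, e.g. optimizer = torch.optim.Adam(model.parameters()).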