gsnn.models.GSNN
Functions
- apply_norm_and_nonlin – Apply normalization and nonlinearity to the input tensor.
- Create batched edge_index/edge_weight tensors for bipartite graphs.
- edge2node – Convert edge-level features back to node-level features, focusing on output nodes.
- get_Win_indices – Build sparse COO indices for the input weight matrix \(W_{in}\).
- get_Wout_indices – Build sparse COO indices for the output weight matrix \(W_{out}\).
- get_conv_indices – Compute indexing structures for convolutional (sparse linear) layers.
- hetero2homo – Convert a heterogeneous GSNN graph into a homogeneous graph representation.
- node2edge – Convert node-level features to edge-level features.
Classes
- Applies normalization per individual channel using exponential moving averages.
- A batch-norm style module that:
- Applies normalization within each channel group using exponential moving averages.
- Applies Root Mean Square normalization within each channel group.
- NodeAttention – Node-wise channel attention.
- class gsnn.models.GSNN.GSNN(*args: Any, **kwargs: Any)
Bases: Module
- forward(x, node_mask=None, edge_mask=None, ret_edge_out=False, e0=None, node_errs=None)
Implements the forward pass of the GSNN model.
The model first converts node features to edge features, then applies a sequence of sparse linear transformations constrained by the graph structure. Each layer consists of:
Input transformation (W_in)
Normalization (optional)
Nonlinearity
Output transformation (W_out)
Residual connection (optional)
- Parameters:
x (Tensor) – Input node features of shape [batch_size, num_input_nodes].
node_mask (Tensor, optional) – Boolean mask for function nodes of shape [batch_size, num_nodes]. If provided, masks out specific function nodes during computation. (default: None)
edge_mask (Tensor, optional) – Boolean mask for edges of shape [batch_size, num_edges]. If provided, masks out specific edges during computation. (default: None)
ret_edge_out (bool, optional) – If set to True, return edge-level features instead of node-level features. (default: False)
e0 (Tensor, optional) – Initial edge features of shape [batch_size, num_edges]. Used for inferring input errors. (default: None)
node_errs (List[Tensor], optional) – List of node errors per layer, each of shape [batch_size, num_nodes]. Its length must match the number of layers. (default: None)
- Returns:
If ret_edge_out=False, returns node-level output features of shape [batch_size, num_output_nodes]. Otherwise, returns edge-level features of shape [batch_size, num_edges].
- Return type:
Tensor
Example
>>> # Using the model from the class example
>>> x = torch.randn(32, 2)  # batch_size=32, num_input_nodes=2
>>> # Basic forward pass
>>> out = model(x)
>>> print(out.shape)  # [32, 1]
>>> # Get edge-level features
>>> edge_out = model(x, ret_edge_out=True)
>>> print(edge_out.shape)  # [32, 4] (batch_size, num_edges)
>>> # Using boolean masks
>>> node_mask = torch.ones(32, 4, dtype=torch.bool)  # [batch_size, num_nodes]
>>> edge_mask = torch.ones(32, 4, dtype=torch.bool)  # [batch_size, num_edges]
>>> out = model(x, node_mask=node_mask, edge_mask=edge_mask)
>>> print(out.shape)  # [32, 1]
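The layer sequence described for forward() can be sketched in plain Python for a single sample. This is an illustration under stated assumptions, not the library's implementation: \(W_{in}\) and \(W_{out}\) are stored as hypothetical COO triples (row, col, value), ReLU stands in for the nonlinearity, and normalization is omitted. The helpers sparse_matvec and gsnn_layer are hypothetical names.

```python
from typing import List, Tuple

# Hypothetical COO entry: (output_index, input_index, weight).
COO = List[Tuple[int, int, float]]


def sparse_matvec(entries: COO, x: List[float], out_size: int) -> List[float]:
    """Compute y = W @ x, where W is given by its nonzero COO entries."""
    y = [0.0] * out_size
    for row, col, weight in entries:
        y[row] += weight * x[col]
    return y


def gsnn_layer(edge_x: List[float], w_in: COO, w_out: COO,
               num_channels: int, residual: bool = True) -> List[float]:
    """One GSNN-style layer for a single sample (normalization omitted)."""
    # Input transformation (W_in): edge features -> hidden channels.
    hidden = sparse_matvec(w_in, edge_x, num_channels)
    # Nonlinearity (ReLU assumed for illustration).
    hidden = [max(0.0, h) for h in hidden]
    # Output transformation (W_out): hidden channels -> edge features.
    out = sparse_matvec(w_out, hidden, len(edge_x))
    # Optional residual connection on the edge features.
    if residual:
        out = [o + e for o, e in zip(out, edge_x)]
    return out


# Two edges around one function node with two hidden channels:
# edge 0 enters the node, edge 1 leaves it.
w_in = [(0, 0, 1.0), (1, 0, -1.0)]   # both channels read edge 0
w_out = [(1, 0, 2.0), (1, 1, 3.0)]   # edge 1 is written from both channels
print(gsnn_layer([1.0, 0.0], w_in, w_out, num_channels=2))  # [1.0, 2.0]
```

Storing only the (row, col) pairs that the graph permits is what constrains each transformation to the graph structure.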
- get_batch_params(B, device)
Retrieves or computes batch-specific indexing parameters for sparse linear layers.
This method caches the batch parameters to avoid recomputing them for the same batch size. The parameters are used to efficiently perform batched sparse matrix operations.
- Parameters:
B (int) – Batch size.
device (torch.device) – Device on which to place the computed parameters.
- Returns:
- A tuple containing:
batched_indices_in (Tensor): Batched indices for input sparse linear layer
batched_indices_out (Tensor): Batched indices for output sparse linear layer
- Return type:
Tuple[Tensor, Tensor]
Example
>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> # Get batch parameters for batch size 32
>>> batch_params = model.get_batch_params(32, torch.device('cuda'))
>>> # Parameters are cached for subsequent calls
>>> same_params = model.get_batch_params(32, torch.device('cuda'))
>>> # Different batch size triggers recomputation
>>> new_params = model.get_batch_params(64, torch.device('cuda'))
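The layout of the batched indices is not spelled out above. One common approach, assumed here purely for illustration, is to replicate a single sample's COO pattern along the block diagonal of a larger sparse matrix and to cache the result per (batch size, device) pair. The names batch_indices and BatchParamCache are hypothetical, and devices are plain strings in this sketch.

```python
from typing import Dict, List, Tuple

Indices = List[Tuple[int, int]]  # COO (row, col) pairs for one sample


def batch_indices(indices: Indices, n_rows: int, n_cols: int, B: int) -> Indices:
    """Replicate one sample's COO pattern B times along the block diagonal."""
    return [(row + b * n_rows, col + b * n_cols)
            for b in range(B) for row, col in indices]


class BatchParamCache:
    """Hypothetical cache of batched indices keyed on (batch size, device)."""

    def __init__(self, idx_in: Indices, in_shape: Tuple[int, int],
                 idx_out: Indices, out_shape: Tuple[int, int]):
        self.idx_in, self.in_shape = idx_in, in_shape
        self.idx_out, self.out_shape = idx_out, out_shape
        self._cache: Dict[Tuple[int, str], Tuple[Indices, Indices]] = {}
        self.misses = 0  # number of times indices had to be recomputed

    def get(self, B: int, device: str) -> Tuple[Indices, Indices]:
        key = (B, device)
        if key not in self._cache:
            self.misses += 1
            self._cache[key] = (
                batch_indices(self.idx_in, self.in_shape[0], self.in_shape[1], B),
                batch_indices(self.idx_out, self.out_shape[0], self.out_shape[1], B),
            )
        return self._cache[key]


cache = BatchParamCache([(0, 0), (1, 0)], (2, 1), [(0, 1)], (1, 2))
first = cache.get(2, "cpu")
again = cache.get(2, "cpu")          # served from the cache
print(first[0])                      # [(0, 0), (1, 0), (2, 1), (3, 1)]
print(again is first, cache.misses)  # True 1
```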
- get_node_attention(x)
Return per-layer node-level attention weights.
- Parameters:
x (Tensor (B, num_input_nodes)) – Input features; typically supply a single sample (B=1).
- Returns:
Mapping from node name to a tensor of shape (L, B) with attention weights per layer (L) and batch element (B).
- Return type:
Dict[str, Tensor]
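- prune(threshold=0.01, verbose=False)
Prunes the model by removing channels with small weights.
This method removes channels whose maximum absolute weight value across all layers is below the specified threshold. This can substantially reduce model size while largely preserving performance. If you prune during training, reinitialize the optimizer afterwards, because pruning changes the model's parameters.
- Parameters:
threshold (float, optional) – Channels whose maximum absolute weight across all layers is below this value are removed. (default: 0.01)
verbose (bool, optional) – If True, enable verbose output during pruning. (default: False)
- Returns:
Number of parameters removed by pruning.
- Return type:
int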
Example
>>> # Create a model with 16 channels per function node
>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> # Train the model...
>>> # Prune channels with small weights
>>> removed_params = model.prune(threshold=1e-2, verbose=True)
>>> print(f'Removed {removed_params} parameters')
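The channel-selection rule described for prune() can be sketched in plain Python. The weight layout (layer_weights[l][c] lists every weight attached to channel c in layer l) and the function name channels_to_keep are assumptions made for illustration; the sketch only decides which channels survive and does not rebuild any layers.

```python
from typing import List


def channels_to_keep(layer_weights: List[List[List[float]]],
                     threshold: float = 0.01) -> List[bool]:
    """Return a keep-mask over channels (hypothetical weight layout).

    layer_weights[l][c] lists every weight attached to channel c in layer l.
    A channel is removed when its maximum absolute weight across all layers
    is below threshold.
    """
    num_channels = len(layer_weights[0])
    keep = []
    for c in range(num_channels):
        max_abs = max((abs(w) for layer in layer_weights for w in layer[c]),
                      default=0.0)
        keep.append(max_abs >= threshold)
    return keep


layers = [
    [[0.5, -0.2], [0.001, 0.002]],  # layer 0: weights of channels 0 and 1
    [[0.0], [-0.009]],              # layer 1
]
print(channels_to_keep(layers))  # [True, False]
```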
- class gsnn.models.GSNN.NodeAttention(*args: Any, **kwargs: Any)
Bases: Module
Node-wise channel attention.
The layer learns a single scalar attention coefficient \(\alpha_{b,n}\) per node \(n\) for every sample \(b\) in the batch. The coefficient is obtained by first aggregating the (optionally weighted) hidden channels that belong to the node and then passing each node's aggregated score through its own sigmoid gate; there is no normalization across nodes. The resulting attention weights can be:
Interpreted: \(\alpha_{b,n}\) indicates how important node \(n\) was for the current forward pass.
Applied: the coefficients are broadcast back to the individual channels that originated from the node and multiplied with the original activations, producing an attention-modulated output.
- Parameters:
channel_groups (Sequence[int] or Tensor) – A 1-D list/array mapping global channel index → node index. Length equals the total number of hidden channels across all nodes.
dropout (float, optional (default=0.0)) – Dropout probability applied to the node-level attention weights.
temperature (float, optional (default=1.0)) – Temperature that scales the attention scores. Lower values produce sharper attention weights.
Examples
>>> # Suppose we have 2 nodes with 3 channels each (total 6 channels)
>>> ch_groups = [0, 0, 0, 1, 1, 1]
>>> attn = NodeAttention(ch_groups, dropout=0.1)
>>> x = torch.randn(8, 6)  # (batch=8, channels=6)
>>> out, alpha = attn(x, return_alpha=True)
>>> out.shape  # same shape as input
torch.Size([8, 6])
>>> alpha.shape  # one scalar per node
torch.Size([8, 2])
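The per-node sigmoid gating described above can be sketched for a single sample in plain Python. The aggregation rule (a weighted sum of each node's channels, with all weights defaulting to 1), the division of scores by temperature before the sigmoid, and the name node_attention are assumptions; dropout is omitted.

```python
import math
from typing import List, Optional, Sequence, Tuple


def node_attention(x: Sequence[float], channel_groups: Sequence[int],
                   weights: Optional[Sequence[float]] = None,
                   temperature: float = 1.0) -> Tuple[List[float], List[float]]:
    """Hypothetical single-sample node attention with per-node sigmoid gates."""
    n_nodes = max(channel_groups) + 1
    weights = weights if weights is not None else [1.0] * len(x)
    # Aggregate (weighted) channel activations into one score per node.
    scores = [0.0] * n_nodes
    for value, w, node in zip(x, weights, channel_groups):
        scores[node] += w * value
    # Independent sigmoid gate per node: no normalization across nodes.
    alpha = [1.0 / (1.0 + math.exp(-s / temperature)) for s in scores]
    # Broadcast each node's coefficient back to its channels.
    out = [value * alpha[node] for value, node in zip(x, channel_groups)]
    return out, alpha


out, alpha = node_attention([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], [0, 0, 0, 1, 1, 1])
print(len(out), len(alpha))  # 6 2
print(alpha[0])              # 0.5, because node 0's aggregated score is 0
```

Because each gate is an independent sigmoid, the coefficients do not sum to 1 across nodes.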
- forward(x: torch.Tensor, *, return_alpha: bool = False)
Apply node attention.
- Parameters:
x (Tensor of shape (B, C)) – Input activations ordered so that channels belonging to the same node are indexed according to channel_groups.
return_alpha (bool, optional (default=False)) – If True, the method returns a tuple (out, alpha), where alpha is the attention matrix of shape (B, n_nodes).
- Returns:
The attention-modulated activations (and, optionally, the node coefficients).
- Return type:
Tensor or Tuple[Tensor, Tensor]
- class gsnn.models.GSNN.NodeMLP(*args: Any, **kwargs: Any)
Bases: Module
- forward(x: torch.Tensor) → torch.Tensor
- class gsnn.models.GSNN.ResBlock(*args: Any, **kwargs: Any)
Bases: Module
- forward(x, batch_params, node_err=None)
Implements the forward pass of the residual block.
- The forward pass consists of:
Edge batch normalization (if configured)
Input sparse linear transformation
Optional node error addition
Normalization and nonlinearity
Node masking (if configured)
Output sparse linear transformation
Dropout
Residual connection (if enabled)
- Parameters:
x (Tensor) – Edge features of shape [batch_size, num_edges] or [batch_size, num_edges, 1].
batch_params (tuple) – A tuple containing:
batched_indices_in (Tensor): Batched indices for the input sparse linear layer
batched_indices_out (Tensor): Batched indices for the output sparse linear layer
node_err (Tensor, optional) – Node-level error terms to be added after the input transformation, of shape [batch_size, num_nodes]. (default: None)
- Returns:
Transformed edge features of shape [batch_size, num_edges].
- Return type:
Tensor
Example
>>> # Using the block from the class example
>>> x = torch.randn(32, 2)  # [batch_size, num_edges]
>>> # Create batched indices (normally done by GSNN)
>>> batch_in = torch.tensor([[0, 0], [0, 1]])
>>> batch_out = torch.tensor([[0, 1], [0, 0]])
>>> batch_params = (batch_in, batch_out)
>>> # Forward pass
>>> out = block(x, batch_params)
>>> print(out.shape)  # [32, 2]
>>> # With node errors
>>> node_err = torch.randn(32, 1)  # [batch_size, num_nodes]
>>> out = block(x, batch_params, node_err=node_err)
>>> print(out.shape)  # [32, 2]
- set_node_mask(mask)
Set a mask to restrict which channels or nodes are active in the computation.
- Parameters:
mask (torch.Tensor) – A boolean mask indicating which positions remain active.
- class gsnn.models.GSNN.SignedMessagePassing(*args: Any, **kwargs: Any)
Bases: MessagePassing
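- gsnn.models.GSNN.apply_norm_and_nonlin(norm, nonlin, out, norm_first)
Apply normalization and nonlinearity to the input tensor.
- Parameters:
norm – Normalization module applied to out.
nonlin – Nonlinearity (activation) module.
out (Tensor) – Input tensor to transform.
norm_first (bool) – If True, apply normalization before the nonlinearity; if False, apply the nonlinearity first.
- Returns:
The transformed tensor.
- Return type:
Tensor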
Example
>>> norm = torch.nn.BatchNorm1d(32)
>>> nonlin = torch.nn.ReLU()
>>> x = torch.randn(16, 32)  # [batch_size, num_features]
>>> # Apply normalization first
>>> out = apply_norm_and_nonlin(norm, nonlin, x, norm_first=True)
>>> print(out.shape)  # [16, 32]
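The effect of norm_first can be illustrated with plain Python callables standing in for the normalization and nonlinearity modules. The function name apply_norm_and_nonlin_sketch and the toy mean-centering normalization are hypothetical, and the sketch assumes that norm_first=False applies the nonlinearity before normalization.

```python
from typing import Callable, List

Vector = List[float]


def apply_norm_and_nonlin_sketch(norm: Callable[[Vector], Vector],
                                 nonlin: Callable[[Vector], Vector],
                                 out: Vector, norm_first: bool) -> Vector:
    """Apply norm and nonlin in the order chosen by norm_first."""
    if norm_first:
        return nonlin(norm(out))
    return norm(nonlin(out))


def center(v: Vector) -> Vector:
    """Toy normalization: subtract the mean."""
    mean = sum(v) / len(v)
    return [x - mean for x in v]


def relu(v: Vector) -> Vector:
    return [max(0.0, x) for x in v]


print(apply_norm_and_nonlin_sketch(center, relu, [1.0, 3.0], norm_first=True))   # [0.0, 1.0]
print(apply_norm_and_nonlin_sketch(center, relu, [1.0, 3.0], norm_first=False))  # [-1.0, 1.0]
```

Normalizing first lets the nonlinearity act on centered values, while the reverse order normalizes already-activated outputs, so the two settings generally give different results.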
- gsnn.models.GSNN.edge2node(x, edge_index, output_node_mask)
Convert edge-level features back to node-level features, focusing on output nodes.
Output nodes are typically designed to have an in-degree of 1. When an output node has multiple incoming edges, their features are summed and normalized by the square root of the node's in-degree.
- Parameters:
x (Tensor) – Edge features of shape [batch_size, num_edges].
edge_index (Tensor) – Edge indices of shape [2, num_edges].
output_node_mask (Tensor) – Boolean mask of shape [num_nodes] indicating output nodes.
- Returns:
Node features of shape [batch_size, num_output_nodes].
- Return type:
Tensor
Example
>>> x = torch.randn(32, 3)  # [batch_size, num_edges]
>>> edge_index = torch.tensor([[0, 1, 1], [2, 2, 3]])  # 3 edges
>>> output_mask = torch.tensor([False, False, True, True])  # Nodes 2 and 3 are outputs
>>> node_features = edge2node(x, edge_index, output_mask)
>>> print(node_features.shape)  # [32, 2]
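The aggregation rule for output nodes can be sketched in plain Python on nested lists. The function name edge2node_sketch is hypothetical, and returning 0.0 for an output node with no incoming edges is an assumption the description above does not cover.

```python
import math
from typing import List, Sequence


def edge2node_sketch(x: Sequence[Sequence[float]],
                     edge_index: Sequence[Sequence[int]],
                     output_node_mask: Sequence[bool]) -> List[List[float]]:
    """Sum incoming edge features per node, divide by sqrt(in-degree), keep outputs."""
    dst = edge_index[1]  # target node of every edge
    num_nodes = len(output_node_mask)
    in_degree = [0] * num_nodes
    for d in dst:
        in_degree[d] += 1
    outputs = [n for n in range(num_nodes) if output_node_mask[n]]
    result = []
    for row in x:  # one row of edge features per batch element
        sums = [0.0] * num_nodes
        for e, d in enumerate(dst):
            sums[d] += row[e]
        # Assumption: an output node without incoming edges yields 0.0.
        result.append([sums[n] / math.sqrt(in_degree[n]) if in_degree[n] else 0.0
                       for n in outputs])
    return result


# Same graph as the example: node 2 has in-degree 2, node 3 has in-degree 1.
feats = edge2node_sketch([[1.0, 3.0, 5.0]], [[0, 1, 1], [2, 2, 3]],
                         [False, False, True, True])
print(feats)  # node 2: (1 + 3) / sqrt(2); node 3: 5 / sqrt(1)
```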
- gsnn.models.GSNN.get_Win_indices(edge_index, channels, function_nodes)
Build sparse COO indices for the input weight matrix \(W_{in}\).
- Parameters:
edge_index (Tensor) – Homogeneous edge index of shape [2, num_edges].
channels (int or Tensor) – If int, every function node gets the same number of hidden channels. If a 1-D tensor/array, it must contain the per-node channel counts and have length num_nodes.
function_nodes (Tensor) – Index list of nodes that represent functions.
- Returns:
- A tuple containing:
indices (Tensor): COO indices of shape [2, nnz] for sparse tensor construction
channel_count (numpy.ndarray): Per-node channel counts for later reuse
- Return type:
Tuple[Tensor, numpy.ndarray]
Example
>>> edge_index = torch.tensor([[0, 1], [1, 0]])  # 2 edges
>>> channels = 3  # 3 channels per function node
>>> function_nodes = torch.tensor([0])  # Node 0 is a function node
>>> indices, counts = get_Win_indices(edge_index, channels, function_nodes)
>>> print(indices.shape)  # [2, 6] (2 edges * 3 channels)
>>> print(counts)  # [3, 0] (3 channels for node 0, 0 for node 1)
- gsnn.models.GSNN.get_Wout_indices(edge_index, function_nodes, channels)
Build sparse COO indices for the output weight matrix \(W_{out}\).
- Parameters:
edge_index (Tensor) – Homogeneous edge index of shape [2, num_edges].
function_nodes (Tensor) – Index list of nodes that represent functions.
channels (numpy.ndarray) – Array indicating the number of channels for each node.
- Returns:
COO indices of shape [2, nnz] for sparse tensor construction.
- Return type:
Tensor
Example
>>> edge_index = torch.tensor([[0, 1], [1, 0]])  # 2 edges
>>> function_nodes = torch.tensor([0])  # Node 0 is a function node
>>> channels = np.array([3, 0])  # 3 channels for node 0, 0 for node 1
>>> indices = get_Wout_indices(edge_index, function_nodes, channels)
>>> print(indices.shape)  # [2, 6] (3 channels * 2 edges)
- gsnn.models.GSNN.get_conv_indices(edge_index, channels, function_nodes)
Compute indexing structures for convolutional (sparse linear) layers.
- Parameters:
edge_index (Tensor) – Homogeneous edge indices of shape [2, num_edges].
channels (int) – Number of channels per function node.
function_nodes (Tensor) – Indices of function nodes.
- Returns:
- A tuple containing:
w_in_indices (Tensor): Indexing for \(W_{in}\)
w_out_indices (Tensor): Indexing for \(W_{out}\)
w_in_size (tuple): Size specification for \(W_{in}\)
w_out_size (tuple): Size specification for \(W_{out}\)
channel_groups (List[int]): List mapping each channel to its node
- Return type:
tuple
Example
>>> edge_index = torch.tensor([[0, 1], [1, 0]])  # 2 edges
>>> channels = 3  # 3 channels per function node
>>> function_nodes = torch.tensor([0])  # Node 0 is a function node
>>> indices = get_conv_indices(edge_index, channels, function_nodes)
>>> print(len(indices))  # 5 (w_in_indices, w_out_indices, w_in_size, w_out_size, channel_groups)
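The channel_groups entry returned by get_conv_indices maps each hidden channel to the node that owns it. Assuming channels are laid out contiguously, node by node, in the order of function_nodes (an assumption, as is the name build_channel_groups), the mapping can be sketched as follows; it accepts either a shared channel count or per-node counts indexed by node id.

```python
from typing import List, Sequence, Union


def build_channel_groups(function_nodes: Sequence[int],
                         channels: Union[int, Sequence[int]]) -> List[int]:
    """Map each hidden channel to the node that owns it (hypothetical layout)."""
    groups: List[int] = []
    for node in function_nodes:
        # Shared count for every function node, or a per-node count by node id.
        count = channels if isinstance(channels, int) else channels[node]
        groups.extend([node] * count)
    return groups


print(build_channel_groups([0, 1], 3))  # [0, 0, 0, 1, 1, 1]
```

This produces the same layout as the channel_groups argument in the NodeAttention example ([0, 0, 0, 1, 1, 1]).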
- gsnn.models.GSNN.hetero2homo(edge_index_dict, node_names_dict, edge_weight_dict=None)
Convert a heterogeneous GSNN graph into a homogeneous graph representation.
- The GSNN pipeline distinguishes three edge types:
('input', 'to', 'function')
('function', 'to', 'function')
('function', 'to', 'output')
This function stacks these edge sets into one homogeneous graph and returns boolean masks that let you recover the original node semantics.
- Parameters:
edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Edge-type mapping where each value is a LongTensor with shape [2, num_edges_of_type].
node_names_dict (Dict[str, List[str]]) – Mapping of node types ('input', 'function', 'output') to their respective node names.
edge_weight_dict (Dict[Tuple[str, str, str], Tensor], optional) – Dictionary mapping edge types to edge weights. Expected keys are ('input', 'to', 'function'), ('function', 'to', 'function'), and ('function', 'to', 'output'). Values should be tensors of shape [num_edges]. (default: None)
- Returns:
- A tuple containing:
edge_index (Tensor): Homogeneous edge indices of shape [2, num_edges]
input_mask (Tensor): Boolean mask for input nodes of shape [num_nodes]
output_mask (Tensor): Boolean mask for output nodes of shape [num_nodes]
num_nodes (int): Total number of nodes in the homogeneous graph
homo_names (List[str]): Node names in homogeneous ordering
edge_weight (Optional[Tensor]): Homogeneous edge weights of shape [num_edges], or None if edge_weight_dict was None.
- Return type:
tuple
Example
>>> edge_index_dict = {
...     ('input', 'to', 'function'): torch.tensor([[0, 1], [0, 0]]),
...     ('function', 'to', 'function'): torch.tensor([[0], [0]]),
...     ('function', 'to', 'output'): torch.tensor([[0], [0]])
... }
>>> node_names_dict = {
...     'input': ['in1', 'in2'],
...     'function': ['func1'],
...     'output': ['out1']
... }
>>> edge_index, in_mask, out_mask, n_nodes, names, edge_weight = hetero2homo(
...     edge_index_dict, node_names_dict
... )
>>> print(edge_index.shape)  # [2, 4]
>>> print(in_mask.sum())  # 2 (number of input nodes)
>>> print(out_mask.sum())  # 1 (number of output nodes)
>>> print(edge_weight)  # None (no edge_weight_dict given)
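The stacking step can be sketched in plain Python with nested lists instead of tensors: each typed edge index is shifted by the offset of its source and target node types. The node ordering used here (all input nodes, then function nodes, then output nodes) is an assumption, as is the name hetero2homo_sketch; edge weights are omitted.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def hetero2homo_sketch(edge_index_dict: Dict[EdgeType, List[List[int]]],
                       node_names_dict: Dict[str, List[str]]):
    """Stack typed edges into one graph (hypothetical node ordering)."""
    order = ['input', 'function', 'output']  # assumed homogeneous ordering
    offsets: Dict[str, int] = {}
    names: List[str] = []
    start = 0
    for node_type in order:
        offsets[node_type] = start
        names.extend(node_names_dict[node_type])
        start += len(node_names_dict[node_type])
    num_nodes = start
    src: List[int] = []
    dst: List[int] = []
    for (src_type, _, dst_type), (s, d) in edge_index_dict.items():
        # Shift local per-type indices into the shared index space.
        src.extend(i + offsets[src_type] for i in s)
        dst.extend(j + offsets[dst_type] for j in d)
    n_in = len(node_names_dict['input'])
    n_out = len(node_names_dict['output'])
    input_mask = [i < n_in for i in range(num_nodes)]
    output_mask = [i >= num_nodes - n_out for i in range(num_nodes)]
    return [src, dst], input_mask, output_mask, num_nodes, names


edges = {
    ('input', 'to', 'function'): [[0, 1], [0, 0]],
    ('function', 'to', 'function'): [[0], [0]],
    ('function', 'to', 'output'): [[0], [0]],
}
node_names = {'input': ['in1', 'in2'], 'function': ['func1'], 'output': ['out1']}
ei, in_mask, out_mask, n_nodes, homo_names = hetero2homo_sketch(edges, node_names)
print(ei)  # [[0, 1, 2, 2], [2, 2, 2, 3]]
```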
- gsnn.models.GSNN.node2edge(x, edge_index)
Convert node-level features to edge-level features. Each outgoing edge receives the feature of its source node.
- Parameters:
x (Tensor) – Node features of shape [batch_size, num_nodes].
edge_index (Tensor) – Edge indices of shape [2, num_edges].
- Returns:
Edge features of shape [batch_size, num_edges].
- Return type:
Tensor
Example
>>> x = torch.randn(32, 4)  # [batch_size, num_nodes]
>>> edge_index = torch.tensor([[0, 1], [1, 2]])  # 2 edges
>>> edge_features = node2edge(x, edge_index)
>>> print(edge_features.shape)  # [32, 2]
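The gather described for node2edge can be sketched in plain Python on nested lists; the name node2edge_sketch is hypothetical.

```python
from typing import List, Sequence


def node2edge_sketch(x: Sequence[Sequence[float]],
                     edge_index: Sequence[Sequence[int]]) -> List[List[float]]:
    """Copy each edge's source-node feature, for every batch element."""
    src = edge_index[0]  # source node of every edge
    return [[row[s] for s in src] for row in x]


print(node2edge_sketch([[10.0, 20.0, 30.0]], [[0, 1], [1, 2]]))  # [[10.0, 20.0]]
```

With PyTorch tensors, the same gather can be expressed as x[:, edge_index[0]], which yields a tensor of shape [batch_size, num_edges].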