gsnn.gsnn.models.GSNN

Functions

apply_norm_and_nonlin(norm, nonlin, out, ...)

Apply normalization and nonlinearity to the input tensor.

batch_graphs(N, M, edge_index, B, device)

Create batched edge_index/edge_weight tensors for bipartite graphs.

edge2node(x, edge_index, output_node_mask)

Convert edge-level features back to node-level features, focusing on output nodes.

get_Win_indices(edge_index, channels, ...)

Build sparse COO indices for the input weight matrix \(W_{in}\).

get_Wout_indices(edge_index, function_nodes, ...)

Build sparse COO indices for the output weight matrix \(W_{out}\).

get_conv_indices(edge_index, channels, ...)

Compute indexing structures for convolutional (sparse linear) layers.

hetero2homo(edge_index_dict, node_names_dict)

Convert a heterogeneous GSNN graph into a homogeneous graph representation.

node2edge(x, edge_index)

Convert node-level features to edge-level features.

Classes

ChannelEMANorm(*args, **kwargs)

Applies normalization per individual channel using exponential moving averages.

GSNN(*args, **kwargs)

GroupBatchNorm(*args, **kwargs)

A batch-norm style module that:

GroupEMANorm(*args, **kwargs)

Applies normalization within each channel group using exponential moving averages.

GroupLayerNorm(*args, **kwargs)

GroupRMSNorm(*args, **kwargs)

Applies Root Mean Square normalization within each channel group.

NodeAttention(*args, **kwargs)

Node-wise channel attention.

NodeMLP(*args, **kwargs)

ResBlock(*args, **kwargs)

SignedMessagePassing(*args, **kwargs)

SoftmaxGroupNorm(*args, **kwargs)

SparseLinear(*args, **kwargs)

class gsnn.gsnn.models.GSNN.GSNN(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(x, node_mask=None, edge_mask=None, ret_edge_out=False, e0=None, node_errs=None)[source]

Implements the forward pass of the GSNN model.

The model first converts node features to edge features, then applies a sequence of sparse linear transformations constrained by the graph structure. Each layer consists of:

  1. Input transformation (W_in)

  2. Normalization (optional)

  3. Nonlinearity

  4. Output transformation (W_out)

  5. Residual connection (optional)
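The five steps above can be sketched for a single sample, with plain Python lists standing in for tensors. This is only an illustration under stated assumptions, not the library's implementation: `matvec`, `layer_sketch`, the dense weight matrices, and the choice of ReLU are all hypothetical, and the optional normalization step is skipped.

```python
def matvec(W, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def layer_sketch(e, W_in, W_out, residual=True):
    """One layer applied to edge features e (hypothetical helper, one sample).

    W_in maps edge features to hidden channels and W_out maps channels back
    to edges. In GSNN both are sparse and constrained by the graph; dense
    lists are used here only for readability.
    """
    h = matvec(W_in, e)                    # 1. input transformation
    # 2. normalization is optional and omitted in this sketch
    h = [max(0.0, z) for z in h]           # 3. nonlinearity (ReLU assumed)
    out = matvec(W_out, h)                 # 4. output transformation
    if residual:                           # 5. residual connection
        out = [o + x for o, x in zip(out, e)]
    return out


e = [1.0, -2.0]                                  # two edge features
W_in = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # 3 channels x 2 edges
W_out = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]       # 2 edges x 3 channels
with_residual = layer_sketch(e, W_in, W_out)
without_residual = layer_sketch(e, W_in, W_out, residual=False)
```

Because the output has the same length as the input edge features, layers of this form can be stacked and the residual added without any reshaping.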

Parameters:
  • x (Tensor) – Input node features of shape [batch_size, num_input_nodes].

  • node_mask (Tensor, optional) – Boolean mask for function nodes of shape [batch_size, num_nodes]. If provided, masks out specific function nodes during computation. (default: None)

  • edge_mask (Tensor, optional) – Boolean mask for edges of shape [batch_size, num_edges]. If provided, masks out specific edges during computation. (default: None)

  • ret_edge_out (bool, optional) – If set to True, return edge-level features instead of node-level features. (default: False)

  • e0 (Tensor, optional) – Initial edge features of shape [batch_size, num_edges]. Used for inferring input errors. (default: None)

  • node_errs (List[Tensor], optional) – List of node errors per layer, each of shape [batch_size, num_nodes]. Length must match number of layers. (default: None)

Returns:

If ret_edge_out=False, returns node-level output features of shape [batch_size, num_output_nodes]. Otherwise, returns edge-level features of shape [batch_size, num_edges].

Return type:

Tensor

Example

>>> # Using the model from the class example
>>> x = torch.randn(32, 2)  # batch_size=32, num_input_nodes=2
>>> # Basic forward pass
>>> out = model(x)
>>> print(out.shape)  # [32, 1]
>>> # Get edge-level features
>>> edge_out = model(x, ret_edge_out=True)
>>> print(edge_out.shape)  # [32, 4] (batch_size, num_edges)
>>> # Using masks
>>> node_mask = torch.ones(32, 4)  # [batch_size, num_nodes]
>>> edge_mask = torch.ones(32, 4)  # [batch_size, num_edges]
>>> out = model(x, node_mask=node_mask, edge_mask=edge_mask)
>>> print(out.shape)  # [32, 1]
get_batch_params(B, device)[source]

Retrieves or computes batch-specific indexing parameters for sparse linear layers.

This method caches the batch parameters to avoid recomputing them for the same batch size. The parameters are used to efficiently perform batched sparse matrix operations.

Parameters:
  • B (int) – Batch size.

  • device (torch.device) – Device on which to place the computed parameters.

Returns:

A tuple containing:
  • batched_indices_in (Tensor): Batched indices for input sparse linear layer

  • batched_indices_out (Tensor): Batched indices for output sparse linear layer

Return type:

tuple
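The caching behaviour described above can be illustrated with a small dictionary-backed cache. This is a sketch, not the method's actual code: `BatchParamCache` and `fake_compute` are hypothetical names, and keying the cache on both batch size and device is an assumption (the source only states that results are reused for the same batch size).

```python
class BatchParamCache:
    """Hypothetical stand-in for the caching described above."""

    def __init__(self, compute_fn):
        self._compute_fn = compute_fn    # builds (indices_in, indices_out)
        self._cache = {}                 # (batch size, device) -> params

    def get(self, B, device="cpu"):
        key = (B, str(device))
        if key not in self._cache:       # compute only on a cache miss
            self._cache[key] = self._compute_fn(B, device)
        return self._cache[key]


calls = []


def fake_compute(B, device):
    """Placeholder for the real index construction; records each call."""
    calls.append(B)
    return (list(range(B)), list(range(B)))


cache = BatchParamCache(fake_compute)
first = cache.get(32)
second = cache.get(32)     # served from the cache, no recomputation
third = cache.get(64)      # new batch size triggers a second computation
```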

Example

>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> # Get batch parameters for batch size 32
>>> batch_params = model.get_batch_params(32, torch.device('cuda'))
>>> # Parameters are cached for subsequent calls
>>> same_params = model.get_batch_params(32, torch.device('cuda'))
>>> # Different batch size triggers recomputation
>>> new_params = model.get_batch_params(64, torch.device('cuda'))
get_node_activations(x, agg='sum', inference=True)[source]
get_node_attention(x)[source]

Return per-layer node-level attention weights.

Parameters:

x (Tensor (B, num_input_nodes)) – Input features; typically supply a single sample (B=1).

Returns:

Mapping from node name to a tensor of shape (L, B) with attention weights per layer (L) and batch element (B).

Return type:

Dict[str, Tensor]

prune(threshold=0.01, verbose=False)[source]

Prunes the model by removing channels with small weights.

This method removes channels whose maximum absolute weight value across all layers is below the specified threshold. This can significantly reduce model size while maintaining performance. If you prune during training, reinitialize the optimizer afterwards.

Parameters:
  • threshold (float, optional) – The threshold below which weights are considered insignificant. (default: 1e-2)

  • verbose (bool, optional) – If set to True, print pruning statistics. (default: False)

Returns:

Number of parameters removed by pruning.

Return type:

int
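The pruning criterion described above (maximum absolute weight across all layers compared with the threshold) can be sketched as follows. The helper `channels_to_keep` and the nested-list weight layout are assumptions made for illustration; the real method operates on the model's sparse weight tensors and returns the number of removed parameters rather than a list of flags.

```python
def channels_to_keep(weights_per_layer, threshold=0.01):
    """Return one keep-flag per channel.

    weights_per_layer[l][c] lists the weights attached to channel c in
    layer l (hypothetical layout chosen for this sketch).
    """
    n_channels = len(weights_per_layer[0])
    keep = []
    for c in range(n_channels):
        # largest magnitude of any weight on channel c, over all layers
        max_abs = max(abs(w) for layer in weights_per_layer for w in layer[c])
        keep.append(max_abs >= threshold)
    return keep


weights = [
    [[0.5, -0.2], [0.001, 0.002], [0.0, 0.03]],   # layer 0
    [[0.1, 0.0], [-0.004, 0.0], [0.0, 0.0]],      # layer 1
]
keep = channels_to_keep(weights, threshold=0.01)
```

A channel survives if it has at least one weight at or above the threshold in any layer, so a channel that is small in one layer but large in another is kept.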

Example

>>> # Create a model with 16 channels per function node
>>> model = GSNN(edge_index_dict, node_names_dict, channels=16, layers=3)
>>> # Train the model...
>>> # Prune channels with small weights
>>> removed_params = model.prune(threshold=1e-2, verbose=True)
>>> print(f'Removed {removed_params} parameters')
class gsnn.gsnn.models.GSNN.NodeAttention(*args: Any, **kwargs: Any)[source]

Bases: Module

Node-wise channel attention.

The layer learns a single scalar attention coefficient \(\alpha_{b,n}\) per node n for every sample b in the batch. The coefficient is obtained by first aggregating the (optionally weighted) hidden channels that belong to the node and then passing the aggregated score through a sigmoid gate, applied independently per node (no cross-node normalization). The resulting attention weights can be:

  1. Interpreted - \(\alpha_{b,n}\) tells how important node n was for the current forward pass.

  2. Applied - the coefficients are broadcast back to the individual channels that originated from the node and multiplied with the original activations, producing an attention-modulated output.

Parameters:
  • channel_groups (Sequence[int] or Tensor) – A 1-D list/array mapping global channel index → node index. Length equals the total number of hidden channels across all nodes.

  • dropout (float, optional (default=0.0)) – Dropout probability applied to the node-level attention weights.

  • temperature (float, optional (default=1.0)) – Softmax temperature. Lower values produce sharper distributions.

Examples

>>> # Suppose we have 2 nodes with 3 channels each (total 6 channels)
>>> ch_groups = [0, 0, 0, 1, 1, 1]
>>> attn = NodeAttention(ch_groups, dropout=0.1)
>>> x = torch.randn(8, 6)  # (batch=8, channels=6)
>>> out, alpha = attn(x, return_alpha=True)
>>> out.shape          # same shape as input
torch.Size([8, 6])
>>> alpha.shape        # one scalar per node
torch.Size([8, 2])
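
The aggregate, gate, and broadcast steps described for this class can be sketched with plain Python lists. This is an illustration under assumptions, not the layer's implementation: `node_attention_sketch` is a hypothetical name, the aggregation is taken to be an unweighted mean, every node is assumed to own at least one channel, and dropout and temperature are left out.

```python
import math


def node_attention_sketch(x, channel_groups):
    """x: rows of shape (batch, channels); channel_groups[c] = node of channel c."""
    n_nodes = max(channel_groups) + 1
    out, alpha = [], []
    for row in x:
        # aggregate the channels that belong to each node (mean is assumed)
        sums = [0.0] * n_nodes
        counts = [0] * n_nodes
        for value, node in zip(row, channel_groups):
            sums[node] += value
            counts[node] += 1
        scores = [s / c for s, c in zip(sums, counts)]
        # independent sigmoid gate per node, no normalization across nodes
        a = [1.0 / (1.0 + math.exp(-s)) for s in scores]
        alpha.append(a)
        # broadcast each node's coefficient back to its own channels
        out.append([value * a[node] for value, node in zip(row, channel_groups)])
    return out, alpha


x = [[0.0, 0.0, 0.0, 2.0, 2.0, 2.0]]      # one sample, two nodes x three channels
out, alpha = node_attention_sketch(x, [0, 0, 0, 1, 1, 1])
```

Since each gate is computed on its own, the coefficients of one sample do not have to sum to 1 across nodes.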
forward(x: torch.Tensor, *, return_alpha: bool = False)[source]

Apply node attention.

Parameters:
  • x (Tensor of shape (B, C)) – Input activations ordered so that channels belonging to the same node are indexed according to channel_groups.

  • return_alpha (bool, optional (default=False)) – If True, the method returns a tuple (out, alpha) where alpha is the attention matrix of shape (B, n_nodes).

Returns:

The attention-modulated activations (and, optionally, the node coefficients).

Return type:

Tensor or Tuple[Tensor, Tensor]

class gsnn.gsnn.models.GSNN.NodeMLP(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(x: torch.Tensor) → torch.Tensor[source]
class gsnn.gsnn.models.GSNN.ResBlock(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(x, batch_params, node_err=None)[source]

Implements the forward pass of the residual block.

The forward pass consists of:
  1. Edge batch normalization (if configured)

  2. Input sparse linear transformation

  3. Optional node error addition

  4. Normalization and nonlinearity

  5. Node masking (if configured)

  6. Output sparse linear transformation

  7. Dropout

  8. Residual connection (if enabled)

Parameters:
  • x (Tensor) – Edge features of shape [batch_size, num_edges] or [batch_size, num_edges, 1].

  • batch_params (tuple) – A tuple containing:
      • batched_indices_in (Tensor): Batched indices for input sparse linear layer
      • batched_indices_out (Tensor): Batched indices for output sparse linear layer

  • node_err (Tensor, optional) – Node-level error terms to be added after input transformation. Shape [batch_size, num_nodes]. (default: None)

Returns:

Transformed edge features of shape [batch_size, num_edges].

Return type:

Tensor

Example

>>> # Using the block from the class example
>>> x = torch.randn(32, 2)  # [batch_size, num_edges]
>>> # Create batched indices (normally done by GSNN)
>>> batch_in = torch.tensor([[0, 0], [0, 1]])
>>> batch_out = torch.tensor([[0, 1], [0, 0]])
>>> batch_params = (batch_in, batch_out)
>>> # Forward pass
>>> out = block(x, batch_params)
>>> print(out.shape)  # [32, 2]
>>> # With node errors
>>> node_err = torch.randn(32, 1)  # [batch_size, num_nodes]
>>> out = block(x, batch_params, node_err=node_err)
>>> print(out.shape)  # [32, 2]
set_node_mask(mask)[source]

Set a mask to restrict which channels or nodes are active in the computation.

Parameters:

mask (torch.Tensor) – A boolean mask indicating which positions remain active.

class gsnn.gsnn.models.GSNN.SignedMessagePassing(*args: Any, **kwargs: Any)[source]

Bases: MessagePassing

forward(x)[source]
message(x_i, x_j, edge_weight)[source]
gsnn.gsnn.models.GSNN.apply_norm_and_nonlin(norm, nonlin, out, norm_first)[source]

Apply normalization and nonlinearity to the input tensor.

Parameters:
  • norm (callable) – Normalization layer or operation.

  • nonlin (callable) – Nonlinear activation function.

  • out (Tensor) – Input tensor to be normalized and activated.

  • norm_first (bool) – If True, apply normalization before nonlinearity.

Returns:

The transformed tensor.

Return type:

Tensor
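The effect of norm_first can be illustrated with two toy callables. This is a sketch under an assumption: the source only states that normalization comes before the nonlinearity when norm_first is True, so applying the nonlinearity first in the other case is a guess. The names `apply_norm_and_nonlin_sketch`, `center`, and `relu` are hypothetical.

```python
def apply_norm_and_nonlin_sketch(norm, nonlin, out, norm_first):
    """Order the two callables according to norm_first (assumed behaviour)."""
    if norm_first:
        return nonlin(norm(out))
    return norm(nonlin(out))


def center(values):
    """Toy 'normalization': subtract the mean."""
    mean = sum(values) / len(values)
    return [v - mean for v in values]


def relu(values):
    return [max(0.0, v) for v in values]


norm_then_act = apply_norm_and_nonlin_sketch(center, relu, [-1.0, 3.0], norm_first=True)
act_then_norm = apply_norm_and_nonlin_sketch(center, relu, [-1.0, 3.0], norm_first=False)
```

The two orderings give different results on the same input, which is why the flag matters when comparing model configurations.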

Example

>>> norm = torch.nn.BatchNorm1d(32)
>>> nonlin = torch.nn.ReLU()
>>> x = torch.randn(16, 32)  # [batch_size, num_features]
>>> # Apply normalization first
>>> out = apply_norm_and_nonlin(norm, nonlin, x, norm_first=True)
>>> print(out.shape)  # [16, 32]
gsnn.gsnn.models.GSNN.edge2node(x, edge_index, output_node_mask)[source]

Convert edge-level features back to node-level features, focusing on output nodes.

Output nodes should typically be designed to have an in-degree of 1. If an output node has multiple incoming edges, the edge features are summed and normalized by the square root of the in-degree.

Parameters:
  • x (Tensor) – Edge features of shape [batch_size, num_edges].

  • edge_index (Tensor) – Edge indices of shape [2, num_edges].

  • output_node_mask (Tensor) – Boolean mask of shape [num_nodes] indicating output nodes.

Returns:

Node features of shape [batch_size, num_output_nodes].

Return type:

Tensor
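The sum-and-normalize rule described above can be sketched with plain Python lists. `edge2node_sketch` is a hypothetical helper, not the library function, and it assumes every output node has at least one incoming edge.

```python
import math


def edge2node_sketch(x, edge_index, output_node_mask):
    """x: (batch, num_edges); edge_index: [sources, targets]; mask: bool per node."""
    targets = edge_index[1]
    output_nodes = [n for n, is_out in enumerate(output_node_mask) if is_out]
    in_degree = {n: sum(1 for t in targets if t == n) for n in output_nodes}
    result = []
    for row in x:
        node_row = []
        for n in output_nodes:
            # sum the features of all edges that point at output node n
            total = sum(v for v, t in zip(row, targets) if t == n)
            # normalize by the square root of the in-degree
            node_row.append(total / math.sqrt(in_degree[n]))
        result.append(node_row)
    return result


x = [[1.0, 3.0, 5.0]]                     # one sample, three edges
edge_index = [[0, 1, 1], [2, 2, 3]]       # node 2 has in-degree 2, node 3 has 1
mask = [False, False, True, True]
nodes = edge2node_sketch(x, edge_index, mask)
```

With an in-degree of 1 the divisor is 1, so the output node simply takes the value of its single incoming edge.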

Example

>>> x = torch.randn(32, 3)  # [batch_size, num_edges]
>>> edge_index = torch.tensor([[0, 1, 1], [2, 2, 3]])  # 3 edges
>>> output_mask = torch.tensor([False, False, True, True])  # Boolean mask: nodes 2,3 are outputs
>>> node_features = edge2node(x, edge_index, output_mask)
>>> print(node_features.shape)  # [32, 2]
gsnn.gsnn.models.GSNN.get_Win_indices(edge_index, channels, function_nodes)[source]

Build sparse COO indices for the input weight matrix \(W_{in}\).

Parameters:
  • edge_index (Tensor) – Homogeneous edge index of shape [2, num_edges].

  • channels (int or Tensor) – If an int, every function node gets the same number of hidden channels. If a 1-D tensor/array, it must have length num_nodes and give the channel count for each node.

  • function_nodes (Tensor) – Index list of nodes that represent functions.

Returns:

A tuple containing:
  • indices (Tensor): COO indices of shape [2, nnz] for sparse tensor construction

  • channel_count (numpy.ndarray): Per-node channel counts for later reuse

Return type:

tuple
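How the two accepted forms of channels relate to the returned channel_count can be sketched as below. `expand_channel_counts` is a hypothetical helper, and passing num_nodes explicitly is an assumption of the sketch (the real function presumably derives it from edge_index).

```python
def expand_channel_counts(channels, function_nodes, num_nodes):
    """Return a per-node channel count list (hypothetical helper)."""
    if isinstance(channels, int):
        # same width for every function node, zero for all other nodes
        function_set = set(function_nodes)
        return [channels if n in function_set else 0 for n in range(num_nodes)]
    counts = list(channels)
    if len(counts) != num_nodes:
        raise ValueError("per-node channels must have length num_nodes")
    return counts


uniform = expand_channel_counts(3, function_nodes=[0], num_nodes=2)
custom = expand_channel_counts([4, 0, 2], function_nodes=[0, 2], num_nodes=3)
```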

Example

>>> edge_index = torch.tensor([[0, 1], [1, 0]])  # 2 edges
>>> channels = 3  # 3 channels per function node
>>> function_nodes = torch.tensor([0])  # Node 0 is a function node
>>> indices, counts = get_Win_indices(edge_index, channels, function_nodes)
>>> print(indices.shape)  # [2, 6] (2 edges * 3 channels)
>>> print(counts)  # [3, 0] (3 channels for node 0, 0 for node 1)
gsnn.gsnn.models.GSNN.get_Wout_indices(edge_index, function_nodes, channels)[source]

Build sparse COO indices for the output weight matrix \(W_{out}\).

Parameters:
  • edge_index (Tensor) – Homogeneous edge index of shape [2, num_edges].

  • function_nodes (Tensor) – Index list of nodes that represent functions.

  • channels (numpy.ndarray) – Array indicating the number of channels for each node.

Returns:

COO indices of shape [2, nnz] for sparse tensor construction.

Return type:

Tensor

Example

>>> edge_index = torch.tensor([[0, 1], [1, 0]])  # 2 edges
>>> function_nodes = torch.tensor([0])  # Node 0 is a function node
>>> channels = np.array([3, 0])  # 3 channels for node 0, 0 for node 1
>>> indices = get_Wout_indices(edge_index, function_nodes, channels)
>>> print(indices.shape)  # [2, 6] (3 channels * 2 edges)
gsnn.gsnn.models.GSNN.get_conv_indices(edge_index, channels, function_nodes)[source]

Compute indexing structures for convolutional (sparse linear) layers.

Parameters:
  • edge_index (Tensor) – Homogeneous edge indices of shape [2, num_edges].

  • channels (int) – Number of channels per function node.

  • function_nodes (Tensor) – Indices of function nodes.

Returns:

A tuple containing:
  • w_in_indices (Tensor): Indexing for \(W_{in}\)

  • w_out_indices (Tensor): Indexing for \(W_{out}\)

  • w_in_size (tuple): Size specification for \(W_{in}\)

  • w_out_size (tuple): Size specification for \(W_{out}\)

  • channel_groups (List[int]): List mapping each channel to its node

Return type:

tuple
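The channel_groups entry can be pictured as the per-node channel counts unrolled into one node label per channel. The sketch below assumes the labels are node indices in the homogeneous graph; `channel_groups_from_counts` is a hypothetical helper and the library may index nodes differently.

```python
def channel_groups_from_counts(channel_count):
    """Unroll per-node channel counts into one node label per channel."""
    groups = []
    for node, count in enumerate(channel_count):
        groups.extend([node] * count)
    return groups


# node 0 owns three channels, node 1 none, node 2 two
groups = channel_groups_from_counts([3, 0, 2])
```

A list of this form has the same layout as the channel_groups argument expected by NodeAttention.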

Example

>>> edge_index = torch.tensor([[0, 1], [1, 0]])  # 2 edges
>>> channels = 3  # 3 channels per function node
>>> function_nodes = torch.tensor([0])  # Node 0 is a function node
>>> indices = get_conv_indices(edge_index, channels, function_nodes)
>>> print(len(indices))  # 5 (w_in_indices, w_out_indices, w_in_size, w_out_size, channel_groups)
gsnn.gsnn.models.GSNN.hetero2homo(edge_index_dict, node_names_dict, edge_weight_dict=None)[source]

Convert a heterogeneous GSNN graph into a homogeneous graph representation.

The GSNN pipeline distinguishes three edge types:
  1. ('input', 'to', 'function')

  2. ('function', 'to', 'function')

  3. ('function', 'to', 'output')

This function stacks these edge sets into one homogeneous graph and returns boolean masks that let you recover the original node semantics.

Parameters:
  • edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Edge-type mapping where each value is a LongTensor with shape [2, num_edges_of_type].

  • node_names_dict (Dict[str, List[str]]) – Mapping of node types ('input', 'function', 'output') to their respective node names.

  • edge_weight_dict (Dict[Tuple[str, str, str], Tensor], optional) – Dictionary mapping edge types to edge weights. Expected keys are ('input', 'to', 'function'), ('function', 'to', 'function'), and ('function', 'to', 'output'). Values should be tensors of shape [num_edges]. (default: None)

Returns:

A tuple containing:
  • edge_index (Tensor): Homogeneous edge indices of shape [2, num_edges]

  • input_mask (Tensor): Boolean mask for input nodes of shape [num_nodes]

  • output_mask (Tensor): Boolean mask for output nodes of shape [num_nodes]

  • num_nodes (int): Total number of nodes in the homogeneous graph

  • homo_names (List[str]): Node names in homogeneous ordering

  • edge_weight (Optional[Tensor]): Homogeneous edge weights of shape [num_edges], or None if edge_weight_dict was None.

Return type:

tuple
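The stacking step can be sketched by shifting each node type's local indices by an offset. This is an illustration under assumptions, not the library code: `hetero2homo_sketch` is a hypothetical name, the homogeneous ordering (inputs, then functions, then outputs) is assumed, plain lists replace tensors, and edge weights are left out, so only five values are returned.

```python
def hetero2homo_sketch(edge_index_dict, node_names_dict):
    n_in = len(node_names_dict["input"])
    n_fn = len(node_names_dict["function"])
    n_out = len(node_names_dict["output"])
    # assumed homogeneous ordering: inputs, then functions, then outputs
    offset = {"input": 0, "function": n_in, "output": n_in + n_fn}
    src, dst = [], []
    for (src_type, _, dst_type), (rows, cols) in edge_index_dict.items():
        src.extend(r + offset[src_type] for r in rows)
        dst.extend(c + offset[dst_type] for c in cols)
    num_nodes = n_in + n_fn + n_out
    input_mask = [i < n_in for i in range(num_nodes)]
    output_mask = [i >= n_in + n_fn for i in range(num_nodes)]
    names = (node_names_dict["input"] + node_names_dict["function"]
             + node_names_dict["output"])
    return [src, dst], input_mask, output_mask, num_nodes, names


edge_index_dict = {
    ("input", "to", "function"): [[0, 1], [0, 0]],
    ("function", "to", "function"): [[0], [0]],
    ("function", "to", "output"): [[0], [0]],
}
node_names_dict = {"input": ["in1", "in2"], "function": ["func1"], "output": ["out1"]}
edge_index, in_mask, out_mask, n_nodes, names = hetero2homo_sketch(
    edge_index_dict, node_names_dict
)
```

The masks are what let later steps recover which homogeneous node indices are inputs and which are outputs.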

Example

>>> edge_index_dict = {
...     ('input', 'to', 'function'): torch.tensor([[0, 1], [0, 0]]),
...     ('function', 'to', 'function'): torch.tensor([[0], [0]]),
...     ('function', 'to', 'output'): torch.tensor([[0], [0]])
... }
>>> node_names_dict = {
...     'input': ['in1', 'in2'],
...     'function': ['func1'],
...     'output': ['out1']
... }
>>> edge_index, in_mask, out_mask, n_nodes, names, edge_weight = hetero2homo(
...     edge_index_dict, node_names_dict
... )
>>> print(edge_index.shape)  # [2, 4]
>>> print(in_mask.sum())  # 2 (number of input nodes)
>>> print(out_mask.sum())  # 1 (number of output nodes)
gsnn.gsnn.models.GSNN.node2edge(x, edge_index)[source]

Convert node-level features to edge-level features. Every out-going edge receives the feature of the source node.

Parameters:
  • x (Tensor) – Node features of shape [batch_size, num_nodes].

  • edge_index (Tensor) – Edge indices of shape [2, num_edges].

Returns:

Edge features of shape [batch_size, num_edges].

Return type:

Tensor
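The source-node gather described above can be sketched in a few lines, with nested lists standing in for tensors. `node2edge_sketch` is a hypothetical helper written for illustration only.

```python
def node2edge_sketch(x, edge_index):
    """Give every edge the feature of its source node.

    x: (batch, num_nodes) nested list; edge_index: [sources, targets].
    """
    sources = edge_index[0]
    return [[row[s] for s in sources] for row in x]


x = [[10.0, 20.0, 30.0, 40.0]]                 # one sample, four nodes
edge_features = node2edge_sketch(x, [[0, 1, 1], [1, 2, 3]])
```

Node 1 has two outgoing edges here, so its feature appears twice in the result.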

Example

>>> x = torch.randn(32, 4)  # [batch_size, num_nodes]
>>> edge_index = torch.tensor([[0, 1], [1, 2]])  # 2 edges
>>> edge_features = node2edge(x, edge_index)
>>> print(edge_features.shape)  # [32, 2]