gsnn.gsnn.models.ResBlock

class gsnn.gsnn.models.ResBlock(*args: Any, **kwargs: Any)

Bases: Module

__init__(bias, nonlin, indices_params, dropout=0.0, norm='layer', init='xavier', lin_in=None, lin_out=None, residual=True, norm_first=True, node_attn=False, attn_mlp_hidden=32, learn_residual=True, affine=True, node_mlp=True, node_mlp_hidden=32, edge_index=None, edge_weight=None)

A residual block for GSNN that applies sparse linear transformations with optional normalization.

Each ResBlock consists of:
  1. Input sparse linear transformation (W_in)

  2. Normalization (optional)

  3. Nonlinearity

  4. Output sparse linear transformation (W_out)

  5. Residual connection (optional)

The block operates on edge features and uses sparse linear layers to maintain the graph structure constraints.
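The structure above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the gsnn implementation: `MaskedLinear` and `TinyResBlock` are hypothetical names, a dense weight multiplied by a fixed 0/1 mask stands in for the library's sparse linear layer, and `nn.LayerNorm` over all channels stands in for the configurable normalization.

```python
import torch
from torch import nn


class MaskedLinear(nn.Module):
    """Dense stand-in for a sparse linear layer: entries outside `mask` stay zero."""

    def __init__(self, mask: torch.Tensor, bias: bool = True):
        super().__init__()
        # mask has shape [in_features, out_features]; 1 marks an allowed connection.
        self.register_buffer("mask", mask.float())
        self.weight = nn.Parameter(torch.empty_like(self.mask))
        nn.init.xavier_uniform_(self.weight)
        self.bias = nn.Parameter(torch.zeros(mask.shape[1])) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Masking the weight on every call keeps disallowed connections at zero.
        out = x @ (self.weight * self.mask)
        return out if self.bias is None else out + self.bias


class TinyResBlock(nn.Module):
    """Hypothetical illustration of the five steps listed above."""

    def __init__(self, mask_in, mask_out, nonlin=nn.ELU, residual=True):
        super().__init__()
        self.lin_in = MaskedLinear(mask_in)          # 1. W_in
        self.norm = nn.LayerNorm(mask_in.shape[1])   # 2. normalization
        self.nonlin = nonlin()                       # 3. nonlinearity
        self.lin_out = MaskedLinear(mask_out)        # 4. W_out
        self.residual = residual                     # 5. residual connection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.lin_out(self.nonlin(self.norm(self.lin_in(x))))
        return x + h if self.residual else h


mask_in = torch.tensor([[1, 1, 0], [0, 1, 1]])     # 2 edges -> 3 channels
mask_out = torch.tensor([[1, 0], [1, 1], [0, 1]])  # 3 channels -> 2 edges
block_sketch = TinyResBlock(mask_in, mask_out)
out_sketch = block_sketch(torch.randn(32, 2))      # shape [32, 2]
```

Because input and output both live on the edges, the residual sum `x + h` is well defined without any projection.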

Parameters:
  • bias (bool) – If set to False, the layers will not learn an additive bias.

  • nonlin (torch.nn.Module) – Activation function class (e.g., torch.nn.ELU).

  • indices_params (tuple) – A tuple containing:
      - w_in_indices (Tensor): Indices for the input sparse linear layer
      - w_out_indices (Tensor): Indices for the output sparse linear layer
      - w_in_size (tuple): Size specification for the input layer
      - w_out_size (tuple): Size specification for the output layer
      - channel_groups (list): Mapping of channels to their respective nodes

  • dropout (float, optional) – Dropout probability. (default: 0.0)

  • norm (str, optional) – Normalization type ('layer', 'batch', 'softmax', 'groupbatch', 'edgebatch', 'rms', 'ema', 'channelema' or 'none'). (default: 'layer')

  • init (str, optional) – Weight initialization strategy ('xavier' or 'kaiming'). (default: 'xavier')

  • lin_in (SparseLinear, optional) – Predefined input linear layer. If None, constructed from indices_params. (default: None)

  • lin_out (SparseLinear, optional) – Predefined output linear layer. If None, constructed from indices_params. (default: None)

  • residual (bool, optional) – If set to True, adds residual connections. (default: True)

  • norm_first (bool, optional) – If set to True, apply normalization before nonlinearity. (default: True)

  • learn_residual (bool, optional) – If set to True, learn the residual connection. (default: True)

  • affine (bool, optional) – If set to True, the normalization layers learn a scale and an additive bias. (default: True)

  • node_mlp (bool, optional) – If set to True, apply additional MLP processing per node to enhance representational capacity while maintaining graph structure constraints. (default: True)

  • node_mlp_hidden (int, optional) – Hidden dimension size for the node MLP when enabled. (default: 32)
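One way to read the contents of indices_params is as a coordinate-format description of which connections are allowed. The sketch below densifies the input indices from the class example into a 0/1 matrix. It assumes that row 0 of the index tensor holds edge indices and row 1 holds channel indices; the library's actual layout may differ, and `dense_mask` and `channels_per_node` are names introduced here for illustration.

```python
from collections import Counter

import torch

# Values copied from the class example.
w_in_indices = torch.tensor([[0, 1], [0, 1]])
w_in_size = (2, 3)  # (num_edges, num_channels)
channel_groups = [0, 0, 0]

# ASSUMPTION: row 0 holds edge indices, row 1 holds channel indices (COO layout).
dense_mask = torch.zeros(w_in_size)
dense_mask[w_in_indices[0], w_in_indices[1]] = 1.0

# Number of channels assigned to each node.
channels_per_node = Counter(channel_groups)

print(dense_mask.tolist())      # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(dict(channels_per_node))  # {0: 3}
```

Under this reading, edge 0 feeds channel 0, edge 1 feeds channel 1, and all three channels belong to node 0.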

Example

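>>> import torch
>>> from gsnn.gsnn.models import ResBlock  # module path as shown in the class signature
>>> # Create indices for a simple graph with 2 edges and 1 function node with 3 channels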
>>> w_in_indices = torch.tensor([[0, 1], [0, 1]])  # 2 edges, 2 channels
>>> w_out_indices = torch.tensor([[0, 1], [0, 1]])
>>> w_in_size = (2, 3)  # (num_edges, num_channels)
>>> w_out_size = (3, 2)  # (num_channels, num_edges)
>>> channel_groups = [0, 0, 0]  # All channels belong to node 0
>>> indices_params = (w_in_indices, w_out_indices, w_in_size, w_out_size, channel_groups)
>>> # Create ResBlock
>>> block = ResBlock(
...     bias=True,
...     nonlin=torch.nn.ELU,
...     indices_params=indices_params
... )
>>> # Forward pass
>>> x = torch.randn(32, 2)  # [batch_size, num_edges]
>>> batch_params = (None, None)  # Normally computed by GSNN
>>> out = block(x, batch_params)
>>> print(out.shape)  # [32, 2]

Methods

__init__(bias, nonlin, indices_params[, ...])

A residual block for GSNN that applies sparse linear transformations with optional normalization.

forward(x, batch_params[, node_err, fn_activity])

Implements the forward pass of the residual block.

set_node_mask(mask)

Set a mask to restrict which channels or nodes are active in the computation.

forward(x, batch_params, node_err=None, fn_activity=None)

Implements the forward pass of the residual block.

The forward pass consists of:
  1. Edge batch normalization (if configured)

  2. Input sparse linear transformation

  3. Optional node error addition

  4. Normalization and nonlinearity

  5. Node masking (if configured)

  6. Output sparse linear transformation

  7. Dropout

  8. Residual connection (if enabled)
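The ordering of these steps can be sketched as a single function on dense tensors. This is a hypothetical walk-through, not the library's code: `forward_sketch` and its arguments are names introduced here, dense matrices replace the sparse layers, layer normalization followed by ELU stands in for the configurable normalization and nonlinearity, step 1 is omitted, and the way node errors are broadcast to channels is an assumption.

```python
import torch
import torch.nn.functional as F


def forward_sketch(x, w_in, w_out, channel_groups, node_err=None,
                   channel_mask=None, p_drop=0.0, residual=True, training=False):
    """Hypothetical dense walk-through of steps 2-8 (step 1 is omitted)."""
    if x.dim() == 3:
        x = x.squeeze(-1)                      # accept [batch, num_edges, 1]
    h = x @ w_in                               # 2. input transformation -> [batch, channels]
    if node_err is not None:
        # 3. ASSUMPTION: each node's error is added to all of that node's channels.
        h = h + node_err[:, channel_groups]
    h = F.elu(F.layer_norm(h, h.shape[-1:]))   # 4. normalization, then nonlinearity
    if channel_mask is not None:
        h = h * channel_mask                   # 5. zero out inactive channels
    out = h @ w_out                            # 6. output transformation -> [batch, num_edges]
    out = F.dropout(out, p=p_drop, training=training)  # 7. dropout
    return x + out if residual else out        # 8. residual connection


torch.manual_seed(0)
x = torch.randn(32, 2)                         # [batch_size, num_edges]
w_in, w_out = torch.randn(2, 3), torch.randn(3, 2)
channel_groups = torch.tensor([0, 0, 0])       # all three channels belong to node 0
node_err = torch.randn(32, 1)                  # [batch_size, num_nodes]
out = forward_sketch(x, w_in, w_out, channel_groups, node_err=node_err)
```

With every channel masked out and no bias, the transformed branch contributes nothing and the residual connection returns the input unchanged.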

Parameters:
  • x (Tensor) – Edge features of shape [batch_size, num_edges] or [batch_size, num_edges, 1].

  • batch_params (tuple) – A tuple containing:
      - batched_indices_in (Tensor): Batched indices for the input sparse linear layer
      - batched_indices_out (Tensor): Batched indices for the output sparse linear layer

  • node_err (Tensor, optional) – Node-level error terms to be added after input transformation. Shape [batch_size, num_nodes]. (default: None)

  • fn_activity (Tensor, optional) – Function node activity of shape [batch_size, num_function_nodes]. Used for computing node activity. (default: None)

Returns:

Transformed edge features of shape [batch_size, num_edges].

Return type:

Tensor

Example

>>> # Using the block from the class example
>>> x = torch.randn(32, 2)  # [batch_size, num_edges]
>>> # Create batched indices (normally done by GSNN)
>>> batch_in = torch.tensor([[0, 0], [0, 1]])
>>> batch_out = torch.tensor([[0, 1], [0, 0]])
>>> batch_params = (batch_in, batch_out)
>>> # Forward pass
>>> out = block(x, batch_params)
>>> print(out.shape)  # [32, 2]
>>> # With node errors
>>> node_err = torch.randn(32, 1)  # [batch_size, num_nodes]
>>> out = block(x, batch_params, node_err=node_err)
>>> print(out.shape)  # [32, 2]

set_node_mask(mask)

Set a mask to restrict which channels or nodes are active in the computation.

Parameters:
  • mask (torch.Tensor) – A boolean mask indicating which positions remain active.
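The effect of such a mask can be illustrated on a plain tensor. This is only a sketch: it assumes the mask has one entry per channel and is applied by elementwise multiplication, which the description above does not confirm.

```python
import torch

activations = torch.ones(4, 3)             # [batch_size, num_channels]
mask = torch.tensor([True, False, True])   # keep channels 0 and 2

# ASSUMPTION: masking multiplies activations by the mask, broadcasting over the batch.
masked = activations * mask
print(masked[0])  # tensor([1., 0., 1.])
```

Positions where the mask is False contribute zeros to whatever is computed downstream.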