gsnn.gsnn.models.ResBlock
- class gsnn.gsnn.models.ResBlock(*args: Any, **kwargs: Any)[source]
Bases: Module
- __init__(bias, nonlin, indices_params, dropout=0.0, norm='layer', init='xavier', lin_in=None, lin_out=None, residual=True, norm_first=True, node_attn=False, attn_mlp_hidden=32, learn_residual=True, affine=True, node_mlp=True, node_mlp_hidden=32, edge_index=None, edge_weight=None)[source]
A residual block for GSNN that applies sparse linear transformations with optional normalization.
- Each ResBlock consists of:
1. Input sparse linear transformation (W_in)
2. Normalization (optional)
3. Nonlinearity
4. Output sparse linear transformation (W_out)
5. Residual connection (optional)
The block operates on edge features and uses sparse linear layers to maintain the graph structure constraints.
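The stages listed above can be sketched in plain Python for a single sample. This is only an illustration of the data flow, not the GSNN implementation: the helper names (`sparse_linear`, `layer_norm`, `elu`, `res_block`) are hypothetical, the index format (a list of `(input, output)` pairs) is an assumption, and layer normalization over all channels stands in for whichever `norm` option is configured.

```python
import math

def sparse_linear(x, indices, weights, out_dim):
    """Sparse linear map: only the (input, output) pairs listed in `indices` carry a weight."""
    y = [0.0] * out_dim
    for (i, j), w in zip(indices, weights):
        y[j] += w * x[i]
    return y

def layer_norm(h, eps=1e-5):
    """Normalize a vector to zero mean and (approximately) unit variance."""
    mean = sum(h) / len(h)
    var = sum((v - mean) ** 2 for v in h) / len(h)
    return [(v - mean) / math.sqrt(var + eps) for v in h]

def elu(h, alpha=1.0):
    """Elementwise ELU nonlinearity."""
    return [v if v > 0 else alpha * math.expm1(v) for v in h]

def res_block(x, w_in, w_out, num_channels, residual=True):
    """Edge features -> channels -> edge features, with an optional skip connection."""
    in_idx, in_w = w_in
    out_idx, out_w = w_out
    h = sparse_linear(x, in_idx, in_w, num_channels)   # W_in: edges -> channels
    h = layer_norm(h)                                  # normalization
    h = elu(h)                                         # nonlinearity
    out = sparse_linear(h, out_idx, out_w, len(x))     # W_out: channels -> edges
    if residual:
        out = [o + xi for o, xi in zip(out, x)]        # residual connection
    return out

# Hypothetical toy graph: 2 edges, 3 hidden channels.
in_idx = [(0, 0), (0, 1), (1, 2)]    # (edge, channel) pairs
out_idx = [(0, 1), (1, 1), (2, 0)]   # (channel, edge) pairs
x = [0.5, -1.0]

out = res_block(x, (in_idx, [0.3, -0.2, 0.7]), (out_idx, [0.1, 0.4, -0.5]), num_channels=3)
# With all-zero output weights the block reduces to the identity via the residual path.
identity = res_block(x, (in_idx, [0.3, -0.2, 0.7]), (out_idx, [0.0, 0.0, 0.0]), num_channels=3)
```

Because every weight is tied to an explicit index pair, information can only flow along connections present in the index lists, which is how the sparse layers keep the graph structure constraints.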
- Parameters:
bias (bool) – If set to False, the layers will not learn an additive bias.
nonlin (torch.nn.Module) – Activation function class (e.g., torch.nn.ELU).
indices_params (tuple) – A tuple containing:
  - w_in_indices (Tensor): Indices for input sparse linear layer
  - w_out_indices (Tensor): Indices for output sparse linear layer
  - w_in_size (tuple): Size specification for input layer
  - w_out_size (tuple): Size specification for output layer
  - channel_groups (list): Mapping of channels to their respective nodes
dropout (float, optional) – Dropout probability. (default: 0.)
norm (str, optional) – Normalization type ('layer', 'batch', 'softmax', 'groupbatch', 'edgebatch', 'rms', 'ema', 'channelema' or 'none'). (default: 'layer')
init (str, optional) – Weight initialization strategy ('xavier' or 'kaiming'). (default: 'xavier')
lin_in (SparseLinear, optional) – Predefined input linear layer. If None, constructed from indices_params. (default: None)
lin_out (SparseLinear, optional) – Predefined output linear layer. If None, constructed from indices_params. (default: None)
residual (bool, optional) – If set to True, adds residual connections. (default: True)
norm_first (bool, optional) – If set to True, apply normalization before nonlinearity. (default: True)
learn_residual (bool, optional) – If set to True, learn the residual connection. (default: True)
affine (bool, optional) – If set to True, the normalization layers will learn an additive bias and scale. (default: True)
node_mlp (bool, optional) – If set to True, apply additional MLP processing per node to enhance representational capacity while maintaining graph structure constraints. (default: True)
node_mlp_hidden (int, optional) – Hidden dimension size for the node MLP when enabled. (default: 32)
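To make the role of channel_groups more concrete, the sketch below inverts the channel-to-node mapping and uses it to normalize each node's channels separately. How ResBlock consumes channel_groups internally is not stated here, so the per-node normalization is an assumption for illustration, and both helper names (`group_channels`, `per_node_layer_norm`) are hypothetical.

```python
import math
from collections import defaultdict

def group_channels(channel_groups):
    """Invert a channel -> node list into a node -> [channel indices] dict."""
    groups = defaultdict(list)
    for channel, node in enumerate(channel_groups):
        groups[node].append(channel)
    return dict(groups)

def per_node_layer_norm(h, channel_groups, eps=1e-5):
    """Normalize the channels of each node independently (assumed behavior)."""
    out = [0.0] * len(h)
    for channels in group_channels(channel_groups).values():
        vals = [h[c] for c in channels]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        for c, v in zip(channels, vals):
            out[c] = (v - mean) / math.sqrt(var + eps)
    return out

# Hypothetical layout: channels 0-2 belong to node 0, channels 3-4 to node 1.
channel_groups = [0, 0, 0, 1, 1]
groups = group_channels(channel_groups)
normed = per_node_layer_norm([1.0, 2.0, 3.0, 10.0, 20.0], channel_groups)
```

Grouping by node keeps statistics from one function node from leaking into another, even though all channels live in a single flat vector.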
Example
>>> import torch
>>> # Create indices for a simple graph with 2 edges and 1 function node with 3 channels
>>> w_in_indices = torch.tensor([[0, 1], [0, 1]])  # 2 edges, 2 channels
>>> w_out_indices = torch.tensor([[0, 1], [0, 1]])
>>> w_in_size = (2, 3)  # (num_edges, num_channels)
>>> w_out_size = (3, 2)  # (num_channels, num_edges)
>>> channel_groups = [0, 0, 0]  # All channels belong to node 0
>>> indices_params = (w_in_indices, w_out_indices, w_in_size, w_out_size, channel_groups)
>>> # Create ResBlock
>>> block = ResBlock(
...     bias=True,
...     nonlin=torch.nn.ELU,
...     indices_params=indices_params
... )
>>> # Forward pass
>>> x = torch.randn(32, 2)  # [batch_size, num_edges]
>>> batch_params = (None, None)  # Normally computed by GSNN
>>> out = block(x, batch_params)
>>> print(out.shape)  # [32, 2]
Methods
__init__(bias, nonlin, indices_params[, ...]) – A residual block for GSNN that applies sparse linear transformations with optional normalization.
forward(x, batch_params[, node_err, fn_activity]) – Implements the forward pass of the residual block.
set_node_mask(mask) – Set a mask to restrict which channels or nodes are active in the computation.
- forward(x, batch_params, node_err=None, fn_activity=None)[source]
Implements the forward pass of the residual block.
- The forward pass consists of:
1. Edge batch normalization (if configured)
2. Input sparse linear transformation
3. Optional node error addition
4. Normalization and nonlinearity
5. Node masking (if configured)
6. Output sparse linear transformation
7. Dropout
8. Residual connection (if enabled)
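Two of the forward-pass details above, the accepted input shapes and the node error addition, can be sketched with nested lists. Both helpers are hypothetical. The sketch assumes that a trailing singleton dimension on x is simply dropped, and that each node's error term is added to every channel belonging to that node (using a channel-to-node list like channel_groups). Neither assumption is confirmed by this page.

```python
def squeeze_edge_features(x):
    """Accept [batch][edges] or [batch][edges][1] and return [batch][edges]."""
    return [
        [e[0] if isinstance(e, (list, tuple)) else e for e in row]
        for row in x
    ]

def add_node_err(h, node_err, channel_groups):
    """Add node-level errors [batch][nodes] to channel activations [batch][channels].

    Assumption: channel c receives the error of node channel_groups[c].
    """
    return [
        [value + err_row[channel_groups[c]] for c, value in enumerate(h_row)]
        for h_row, err_row in zip(h, node_err)
    ]

# Hypothetical batch of one sample: 3 channels, of which the first two belong to node 0.
channel_groups = [0, 0, 1]
h = [[1.0, 2.0, 3.0]]          # activations after the input sparse linear layer
node_err = [[0.5, -1.0]]       # one error term per node
shifted = add_node_err(h, node_err, channel_groups)

flat = squeeze_edge_features([[[1.0], [2.0]]])
```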
- Parameters:
x (Tensor) – Edge features of shape [batch_size, num_edges] or [batch_size, num_edges, 1].
batch_params (tuple) – A tuple containing:
  - batched_indices_in (Tensor): Batched indices for input sparse linear layer
  - batched_indices_out (Tensor): Batched indices for output sparse linear layer
node_err (Tensor, optional) – Node-level error terms to be added after input transformation. Shape [batch_size, num_nodes]. (default: None)
fn_activity (Tensor, optional) – Function node activity of shape [batch_size, num_function_nodes]. Used for computing node activity. (default: None)
- Returns:
Transformed edge features of shape [batch_size, num_edges].
- Return type:
Tensor
Example
>>> # Using the block from the class example
>>> x = torch.randn(32, 2)  # [batch_size, num_edges]
>>> # Create batched indices (normally done by GSNN)
>>> batch_in = torch.tensor([[0, 0], [0, 1]])
>>> batch_out = torch.tensor([[0, 1], [0, 0]])
>>> batch_params = (batch_in, batch_out)
>>> # Forward pass
>>> out = block(x, batch_params)
>>> print(out.shape)  # [32, 2]
>>> # With node errors
>>> node_err = torch.randn(32, 1)  # [batch_size, num_nodes]
>>> out = block(x, batch_params, node_err=node_err)
>>> print(out.shape)  # [32, 2]
- set_node_mask(mask)[source]
Set a mask to restrict which channels or nodes are active in the computation.
- Parameters:
mask (torch.Tensor) – A boolean mask indicating which positions remain active.
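The masking idea can be illustrated with a small stand-in class. `MaskedChannels` is hypothetical and is not part of GSNN. It assumes the mask is indexed per channel and that inactive positions are zeroed, whereas the actual ResBlock may apply the mask at node granularity or in a different way.

```python
class MaskedChannels:
    """Toy stand-in showing how a stored boolean mask can gate activations."""

    def __init__(self):
        self.mask = None  # no mask: every position stays active

    def set_node_mask(self, mask):
        """Store a boolean mask; True marks positions that remain active."""
        self.mask = None if mask is None else [bool(m) for m in mask]

    def apply(self, h):
        """Zero out inactive positions in activations of shape [batch][channels]."""
        if self.mask is None:
            return [list(row) for row in h]
        if any(len(row) != len(self.mask) for row in h):
            raise ValueError("mask length must match the number of channels")
        return [
            [value if keep else 0.0 for value, keep in zip(row, self.mask)]
            for row in h
        ]

layer = MaskedChannels()
unmasked = layer.apply([[1.0, 2.0, 3.0]])

layer.set_node_mask([True, False, True])
masked = layer.apply([[1.0, 2.0, 3.0]])
```

Storing the mask on the module, rather than passing it on every call, matches the setter-style API above: the mask is set once and then affects subsequent forward passes until it is changed.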