Source code for gsnn.models.NodeAttention

import torch
import torch.nn as nn
from gsnn.models.SignedMessagePassing import SignedMessagePassing

class NodeAttention(torch.nn.Module):
    r"""Node-wise channel attention.

    The layer learns a single scalar gate logit :math:`\alpha_{b,n}` per
    **node** *n* for every sample *b* in the batch. The hidden channels that
    belong to a node are fed through a small MLP shared by all nodes, which
    produces one logit per node. Each logit is turned into a gate with an
    independent sigmoid (there is no cross-node normalisation), and the gate
    is multiplied onto every channel of that node.

    The resulting coefficients can be:

    1. **Interpreted** - :math:`\sigma(\alpha_{b,n})` tells how strongly node
       *n* was passed through in the current forward pass.
    2. **Applied** - the gates are broadcast back to the individual channels
       that originated from the node and multiplied with the original
       activations, producing an attention-modulated output.

    Parameters
    ----------
    channel_groups : Sequence[int] or Tensor
        A 1-D list/array mapping *global channel index* -> *node index*.
        Length equals the total number of hidden channels across all nodes.
        Every node index in ``[0, max(channel_groups)]`` must own the same
        number of channels; the channels of a node do not have to be
        contiguous.
    dropout : float, optional (default=0.0)
        Dropout probability applied to the node-level logits (before the
        optional message passing and the sigmoid).
    temperature : float, optional (default=1.0)
        Stored on the instance as ``self.temperature``. NOTE: the current
        ``forward`` implementation does not apply it.
    channels : int, optional (default=16)
        Hidden width of the shared gate MLP.
    edge_index, edge_weight : optional
        When both are given, the logits are passed through
        ``SignedMessagePassing(edge_weight, edge_index)`` before the sigmoid;
        otherwise that step is the identity.

    Raises
    ------
    ValueError
        If ``channel_groups`` is empty, contains negative indices, or does not
        assign the same number of channels to every node.

    Examples
    --------
    >>> # Suppose we have 2 nodes with 3 channels each (total 6 channels)
    >>> ch_groups = [0, 0, 0, 1, 1, 1]
    >>> attn = NodeAttention(ch_groups, dropout=0.1)
    >>> x = torch.randn(8, 6)  # (batch=8, channels=6)
    >>> out, alpha = attn(x, return_alpha=True)
    >>> out.shape  # same shape as input
    torch.Size([8, 6])
    >>> alpha.shape  # one scalar per node
    torch.Size([8, 2])
    """

    def __init__(self, channel_groups, dropout: float = 0.0, temperature: float = 1.0,
                 channels=16, edge_index=None, edge_weight=None):
        super().__init__()

        # Convert and store channel -> node mapping
        self.register_buffer('channel_groups',
                             torch.as_tensor(channel_groups, dtype=torch.long))
        if self.channel_groups.numel() == 0:
            raise ValueError("channel_groups must contain at least one channel.")
        if int(self.channel_groups.min().item()) < 0:
            raise ValueError("channel_groups must contain non-negative node indices.")

        self.n_nodes: int = int(self.channel_groups.max().item() + 1)
        self.dropout: float = float(dropout)
        self.temperature: float = float(temperature)  # kept for API compatibility; unused in forward

        if (edge_weight is not None) and (edge_index is not None):
            self.signed_message_passing = SignedMessagePassing(edge_weight, edge_index)
        else:
            self.signed_message_passing = torch.nn.Identity()

        # Number of channels owned by each node (validated to be uniform below)
        self.channels_per_node: int = int(self.channel_groups.numel() // self.n_nodes)
        self._build_channel_order()

        # Shared two-layer MLP that maps a vector of node-channels -> scalar gate logit
        self.mlp = nn.Sequential(
            nn.Linear(self.channels_per_node, channels),
            nn.ELU(),
            nn.LayerNorm(channels),
            nn.Linear(channels, 1),
        )

        # will hold the last computed node logits for inspection
        self._last_alpha = None

    def _build_channel_order(self) -> None:
        """Validate ``channel_groups`` and cache the node-major channel permutation."""
        flat = self.channel_groups.reshape(-1)

        counts = torch.bincount(flat, minlength=self.n_nodes)
        if not bool((counts == self.channels_per_node).all().item()):
            raise ValueError(
                "channel_groups must assign the same number of channels to every "
                f"node index in [0, {self.n_nodes - 1}]; got counts {counts.tolist()}."
            )

        # A stable sort keeps the original relative order of the channels inside
        # each node, and yields the identity for an already node-major mapping.
        order = torch.sort(flat, stable=True).indices
        identity = torch.arange(flat.numel(), device=flat.device)
        self._is_grouped: bool = bool(torch.equal(order, identity))

        # Non-persistent so the state_dict keys stay exactly as before.
        self.register_buffer('_gather_index', order, persistent=False)
        # argsort of a permutation is its inverse permutation.
        self.register_buffer('_scatter_index', torch.argsort(order), persistent=False)

    def forward(self, x: torch.Tensor, *, return_alpha: bool = False):
        """Apply node attention.

        Parameters
        ----------
        x : Tensor of shape (B, C) or (B, C, 1)
            Input activations. Channel ``i`` is treated as belonging to node
            ``channel_groups[i]``.
        return_alpha : bool, optional (default=False)
            If *True*, the method returns a tuple ``(out, alpha)`` where
            ``alpha`` has shape *(B, n_nodes)* and holds the node logits
            (after dropout and the optional message passing, **before** the
            sigmoid). Use ``alpha.sigmoid()`` to obtain the applied gates.

        Returns
        -------
        Tensor or Tuple[Tensor, Tensor]
            The attention-modulated activations of shape *(B, C)* (and,
            optionally, the node logits).

        Raises
        ------
        ValueError
            If ``x`` is not 2-D after dropping a trailing singleton axis, or
            if its channel count differs from ``len(channel_groups)``.
        """
        # Only drop a trailing singleton axis from (B, C, 1)-style inputs so a
        # legitimate (B, 1) input is not collapsed to (B,).
        if x.dim() > 2:
            x = x.squeeze(-1)
        if x.dim() != 2:
            raise ValueError(
                f"Expected input of shape (B, C) or (B, C, 1), got {tuple(x.shape)}.")

        B, C = x.shape
        if C != self.channel_groups.numel():
            raise ValueError(
                f"Expected input with {self.channel_groups.numel()} channels, got {C}.")

        # Instances unpickled from an older layout lack the cached permutation.
        if not hasattr(self, '_is_grouped'):
            self._build_channel_order()

        # Bring the channels into node-major order so the reshape below really
        # groups the channels of one node together.
        grouped = x if self._is_grouped else x[:, self._gather_index]
        x_nodes = grouped.reshape(B, self.n_nodes, self.channels_per_node)  # (B, N, C_pn)

        alpha = self.mlp(x_nodes).squeeze(-1)  # (B, N); logits
        alpha = torch.nn.functional.dropout(alpha, p=self.dropout, training=self.training)

        # signed message passing over the node logits (identity when no graph was given)
        alpha = self.signed_message_passing(alpha)

        gate = alpha.sigmoid().unsqueeze(-1)  # (B, N, 1), broadcast over C_pn
        out = (x_nodes * gate).reshape(B, C)  # (B, C), node-major order
        if not self._is_grouped:
            out = out[:, self._scatter_index]  # back to the caller's channel order

        self._last_alpha = alpha
        return (out, alpha) if return_alpha else out