Source code for gsnn.models.SoftmaxGroupNorm

import torch
import torch_geometric as pyg

class SoftmaxGroupNorm(torch.nn.Module):
    """Channel-wise softmax normalized within each channel group.

    Every channel is assigned to a group; the softmax is taken over the
    channels that share a group, independently for each batch row. The
    per-group maximum is subtracted before exponentiation (stable softmax).
    """

    def __init__(self, channel_groups, eps=1e-8):
        """
        Args:
            channel_groups: Length-``C`` index assigning each channel to a
                group (a sequence of integers or a tensor).
            eps: Added to the denominator for numerical stability.
        """
        super().__init__()
        # ``torch.as_tensor`` handles sequences and tensors alike without the
        # copy-construct UserWarning that ``torch.tensor(tensor)`` emits; the
        # explicit clone keeps the buffer independent of the caller's data.
        groups = torch.as_tensor(channel_groups, dtype=torch.long).detach().clone()
        self.register_buffer('channel_groups', groups)
        # Channel count per group, ordered by sorted unique group index.
        _, counts = torch.unique(self.channel_groups, return_counts=True)
        self.register_buffer('n_channels', counts)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the group-wise softmax.

        Args:
            x: Activations of shape ``(B, C)`` or ``(B, C, 1)``.

        Returns:
            Group-softmax-normalized activations. A 2-D result always gets a
            trailing singleton axis appended, so both documented input shapes
            yield ``(B, C, 1)``. Because ``eps`` is added to the denominator,
            each group sums to slightly less than one.
        """
        # Drop a trailing singleton axis only when it is an extra axis beyond
        # (B, C). A 2-D input with a single channel also has size(-1) == 1 and
        # must keep its channel axis for the dim=1 scatters below.
        if x.dim() > 2 and x.size(-1) == 1:
            x = x.squeeze(-1)

        # Per-group maxima, broadcast back to channels, for a stable softmax.
        max_values = pyg.utils.scatter(x, self.channel_groups, dim=1, reduce='max')
        expanded_max_values = max_values.index_select(1, self.channel_groups)

        # Exponentiate the shifted values (largest exponent per group is 0).
        exp_x = torch.exp(x - expanded_max_values)

        # Per-group sum of exponentials, broadcast back to channels.
        sum_exp = pyg.utils.scatter(exp_x, self.channel_groups, dim=1, reduce='sum')
        expanded_sum = sum_exp.index_select(1, self.channel_groups) + self.eps

        x = exp_x / expanded_sum

        # Callers receive a trailing singleton axis on 2-D results.
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        return x