Source code for gsnn.models.GroupRMSNorm

"""
Group-wise Root Mean Square Layer Normalization.

RMSNorm is a simpler alternative to layer normalization that only uses the RMS 
for normalization without mean centering. It's particularly stable for small 
batch sizes and computationally more efficient than layer norm.
"""

import torch
import torch_geometric as pyg


[docs]class GroupRMSNorm(torch.nn.Module): """ Applies Root Mean Square normalization within each channel group. RMSNorm normalizes using only the RMS (root mean square) without mean centering, making it simpler and more stable than layer normalization, especially for small batch sizes. Args: channel_groups (list or tensor): Specifies which group each channel belongs to. For example, [0,0,1,1,2,2] specifies 3 groups with 2 channels each. eps (float): Small value to avoid division by zero. Default: 1e-6 affine (bool): If True, applies learnable scale parameter. Default: True """
[docs] def __init__(self, channel_groups, eps=1e-6, affine=True): super().__init__() self.register_buffer('channel_groups', torch.tensor(channel_groups, dtype=torch.long)) self.register_buffer('n_channels', torch.unique(self.channel_groups, return_counts=True)[1]) self.eps = eps if affine: # Only scale parameter (gamma), no bias since we don't center N = torch.max(self.channel_groups).item() + 1 self.gamma = torch.nn.Parameter(torch.ones(N)) else: self.register_parameter('gamma', None)
[docs] def forward(self, x): """Normalize by group RMS; accepts ``(B, C)`` or ``(B, C, 1)``.""" # Handle input shape (B, C) or (B, C, 1) original_shape = x.shape x = x.squeeze(-1) # Ensure (B, C) shape # Compute RMS per group # RMS = sqrt(mean(x^2)) x_squared = x ** 2 rms_squared = pyg.utils.scatter(x_squared, self.channel_groups, dim=1, reduce='mean') rms = torch.sqrt(rms_squared + self.eps) # Expand RMS to match input dimensions expanded_rms = rms.index_select(1, self.channel_groups) # Normalize by RMS x_normed = x / expanded_rms # Apply learnable scale if enabled if self.gamma is not None: x_normed = x_normed * self.gamma[self.channel_groups].unsqueeze(0) # Restore original shape if len(original_shape) == 3: x_normed = x_normed.unsqueeze(-1) return x_normed