"""gsnn.models.GroupBatchNorm: batch normalization with statistics pooled over channel groups."""

import torch
import torch.nn as nn
import torch_geometric as pyg

class GroupBatchNorm(nn.Module):
    """Batch normalization whose statistics are pooled over groups of channels.

    The ``C`` input channels are partitioned into groups by ``channel_groups``
    (``channel_groups[c]`` is the group index of channel ``c``). In training
    mode, or whenever ``track_running_stats`` is False, one mean and one
    (biased) variance are computed per group. Each is taken over every sample
    in the batch and every channel in the group. Each channel is then
    normalized with its group's statistics.

    When ``track_running_stats`` is True, running estimates are updated during
    training and used in eval mode. An exponential moving average is kept by
    default; a cumulative average is kept if ``momentum is None``.

    Args:
        channel_groups: Sequence or 1-D integer tensor of length ``C`` mapping
            each channel to a non-negative group index. The number of groups is
            ``max(channel_groups) + 1``.
        eps: Value added to the variance for numerical stability.
        momentum: Weight of the current batch in the running-stat update.
            ``None`` selects a cumulative moving average, as in
            ``nn.BatchNorm1d``.
        affine: If True, learn a per-group scale ``gamma`` (init 1) and shift
            ``beta`` (init 0).
        track_running_stats: If True, keep ``running_mean``/``running_var``
            buffers and use them in eval mode; otherwise batch statistics are
            always used.

    Raises:
        ValueError: If ``channel_groups`` is empty, not 1-D, or contains
            negative indices.
    """

    def __init__(self, channel_groups, eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=True):
        super().__init__()
        # as_tensor(...).clone() accepts lists, arrays and tensors alike. It
        # avoids the copy-construct warning that torch.tensor() emits for
        # tensor input, and gives this module its own copy of the mapping.
        groups = torch.as_tensor(channel_groups, dtype=torch.long).clone()
        if groups.dim() != 1 or groups.numel() == 0:
            raise ValueError(
                "channel_groups must be a non-empty 1-D sequence of group indices"
            )
        if groups.min().item() < 0:
            raise ValueError("channel_groups must contain non-negative group indices")
        self.register_buffer('channel_groups', groups)

        num_groups = int(groups.max().item()) + 1
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats

        # Optional learnable affine parameters, one scalar per group.
        if affine:
            self.gamma = nn.Parameter(torch.ones(num_groups))
            self.beta = nn.Parameter(torch.zeros(num_groups))
        else:
            self.register_parameter('gamma', None)
            self.register_parameter('beta', None)

        if track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_groups))
            self.register_buffer('running_var', torch.ones(num_groups))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)

    def _batch_group_stats(self, x):
        """Return per-group ``(mean, biased var)`` of 2-D ``x``, pooled over batch and group channels."""
        groups = self.channel_groups
        # Elements per group = batch size * channels in that group. The channel
        # count is clamped to 1 so an unused group id yields 0 (not 0/0 = nan).
        # This matches a scatter-mean over an empty group.
        per_group = torch.bincount(groups).clamp(min=1).to(x.dtype) * x.size(0)
        zeros = x.new_zeros(per_group.numel())
        # Reduce over the batch first, then sum channels into their groups.
        mean = zeros.index_add(0, groups, x.sum(dim=0)) / per_group
        centered = x - mean.index_select(0, groups).unsqueeze(0)
        var = zeros.index_add(0, groups, centered.pow(2).sum(dim=0)) / per_group
        return mean, var

    @torch.no_grad()
    def _update_running_stats(self, mean, var):
        """Blend batch group statistics into the running buffers in place."""
        self.num_batches_tracked += 1
        if self.momentum is None:
            # Cumulative moving average over all batches seen so far.
            factor = 1.0 / float(self.num_batches_tracked)
        else:
            factor = self.momentum
        # lerp_: running <- (1 - factor) * running + factor * batch.
        # NOTE(review): running_var stores the *biased* batch variance, unlike
        # nn.BatchNorm, which stores the unbiased estimate. Kept as-is for
        # compatibility with existing checkpoints.
        self.running_mean.lerp_(mean.to(self.running_mean.dtype), factor)
        self.running_var.lerp_(var.to(self.running_var.dtype), factor)

    def forward(self, x):
        """Normalize ``x`` with group-pooled statistics.

        Args:
            x: Tensor of shape ``(B, C)`` or ``(B, C, 1)``, where ``C`` equals
                ``len(channel_groups)``.

        Returns:
            The normalized tensor with shape ``(B, C, 1)``. NOTE: the trailing
            singleton dim is always appended, so ``(B, C)`` input also yields
            ``(B, C, 1)``.

        Raises:
            ValueError: If ``x`` is not ``(B, C)`` / ``(B, C, 1)``, or if its
                channel count does not match ``channel_groups``.
        """
        if x.dim() == 3 and x.size(-1) == 1:
            x = x.squeeze(-1)
        if x.dim() != 2:
            raise ValueError(
                f"expected input of shape (B, C) or (B, C, 1), got {tuple(x.shape)}"
            )
        groups = self.channel_groups
        if x.size(1) != groups.numel():
            raise ValueError(
                f"input has {x.size(1)} channels but channel_groups maps {groups.numel()}"
            )

        if self.training or not self.track_running_stats:
            mean, var = self._batch_group_stats(x)
            # Reaching here with track_running_stats=True implies training mode.
            if self.track_running_stats:
                self._update_running_stats(mean, var)
        else:
            mean, var = self.running_mean, self.running_var

        # Broadcast per-group statistics back to per-channel, shape (1, C).
        mean_c = mean.index_select(0, groups).unsqueeze(0)
        var_c = var.index_select(0, groups).unsqueeze(0)
        out = (x - mean_c) / torch.sqrt(var_c + self.eps)

        if self.gamma is not None and self.beta is not None:
            out = (out * self.gamma.index_select(0, groups).unsqueeze(0)
                   + self.beta.index_select(0, groups).unsqueeze(0))
        return out.unsqueeze(-1)