Source code for gsnn.models.ChannelEMANorm

"""
Channel-wise Exponential Moving Average Normalization.

ChannelEMANorm maintains running statistics using exponential moving averages for each 
individual channel independently, making it very robust for small and variable batch sizes.
"""

import torch
import torch.nn as nn


class ChannelEMANorm(nn.Module):
    """Normalize each channel with exponential-moving-average statistics.

    A running mean and variance are kept per channel. Unlike batch
    normalization, the forward pass always normalizes with the *running*
    statistics (when they are tracked), never directly with the statistics
    of the current batch. In training mode the running statistics are first
    updated with the current batch (without tracking gradients), so the
    current batch contributes with weight ``momentum``.

    Args:
        num_channels (int): Number of channels to normalize.
        eps (float): Small value added to the variance to avoid division by
            zero. Default: 1e-5
        momentum (float): Weight of the current batch when updating the
            running statistics. Default: 0.1
        affine (bool): If True, applies a learnable scale (``gamma``) and
            bias (``beta``) per channel. Default: True
        track_running_stats (bool): If True, maintains running statistics.
            If False, the statistics of the current batch are used instead.
            Default: True
    """

    def __init__(self, num_channels, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats

        # Learnable per-channel scale and shift.
        if affine:
            self.gamma = nn.Parameter(torch.ones(num_channels))
            self.beta = nn.Parameter(torch.zeros(num_channels))
        else:
            self.register_parameter('gamma', None)
            self.register_parameter('beta', None)

        # Running statistics are state, not learnable weights, so they are
        # registered as buffers in both branches (as ``None`` when disabled).
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_channels))
            self.register_buffer('running_var', torch.ones(num_channels))
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long))
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)

    def forward(self, x):
        """Normalize ``x`` channel-wise.

        Args:
            x (torch.Tensor): Input of shape ``(B, C)`` or ``(B, C, 1)``.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as the input.

        Raises:
            ValueError: If the input is not ``(B, C)`` / ``(B, C, 1)`` or if
                ``C`` differs from ``num_channels``.
        """
        # Accept a trailing singleton dimension and remember to restore it.
        had_trailing_dim = x.dim() == 3 and x.size(-1) == 1
        if had_trailing_dim:
            x = x.squeeze(-1)

        if x.dim() != 2:
            raise ValueError(
                f"Expected input of shape (B, C) or (B, C, 1), "
                f"got {tuple(x.shape)}"
            )
        num_channels = x.size(1)
        if num_channels != self.num_channels:
            raise ValueError(
                f"Expected {self.num_channels} channels, got {num_channels}"
            )

        # Update the running statistics during training. An empty batch has
        # no statistics (mean/var would be NaN), so it is skipped rather than
        # allowed to poison the buffers.
        if self.training and self.track_running_stats and x.size(0) > 0:
            with torch.no_grad():
                batch_mean = x.mean(dim=0)  # Shape: (C,)
                batch_var = x.var(dim=0, unbiased=False)  # Shape: (C,)

                self.num_batches_tracked.add_(1)
                momentum = self.momentum
                # In-place EMA update: keeps the registered buffer objects
                # (and their dtype/device) stable instead of rebinding them
                # to freshly promoted tensors on every step.
                self.running_mean.mul_(1 - momentum).add_(
                    batch_mean.to(self.running_mean.dtype), alpha=momentum
                )
                self.running_var.mul_(1 - momentum).add_(
                    batch_var.to(self.running_var.dtype), alpha=momentum
                )

        # Always normalize with the running statistics when they exist
        # (key difference from batch norm).
        if self.track_running_stats:
            mean = self.running_mean
            var = self.running_var
        else:
            # Fallback: statistics of the current batch.
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)

        x_normed = (x - mean) / torch.sqrt(var + self.eps)

        # Optional learnable affine transformation.
        if self.gamma is not None and self.beta is not None:
            x_normed = x_normed * self.gamma + self.beta

        # Restore the original (B, C, 1) shape if it was squeezed above.
        if had_trailing_dim:
            x_normed = x_normed.unsqueeze(-1)

        return x_normed