gsnn.models.GroupBatchNorm
class gsnn.models.GroupBatchNorm(*args: Any, **kwargs: Any)
Bases: Module

A batch-norm style module that:
- Partitions the C channels into groups via 'channel_groups'.
- Computes the mean and variance of each group across the entire batch dimension.
- Maintains running statistics for inference (if track_running_stats=True).
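The three steps above can be illustrated with a minimal, dependency-free Python sketch. It is not the gsnn implementation and rests on several assumptions: channel_groups is taken to be a length-C list giving each channel's group index; the eps and momentum defaults mirror PyTorch's BatchNorm; the running statistics use the PyTorch-style update running = (1 - momentum) * running + momentum * batch; the biased batch variance is used for both normalization and the running update (PyTorch's own BatchNorm stores an unbiased estimate in its running variance); and no learnable scale or shift is applied. The class name GroupBatchNormSketch is hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


class GroupBatchNormSketch:
    """Hypothetical sketch of group-wise batch norm on (B, C) nested lists."""

    def __init__(self, channel_groups: Sequence[int], eps: float = 1e-5,
                 momentum: float = 0.1, track_running_stats: bool = True):
        # Assumption: channel_groups[c] is the group index (0..G-1) of channel c.
        # eps/momentum defaults mirror PyTorch's BatchNorm; gsnn's may differ.
        self.channel_groups = list(channel_groups)
        self.num_groups = max(self.channel_groups) + 1
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.training = True
        # One running mean/variance per group, not per channel.
        self.running_mean = [0.0] * self.num_groups
        self.running_var = [1.0] * self.num_groups

    def _batch_stats(self, x: Matrix) -> Tuple[List[float], List[float]]:
        # Pool every value x[b][c] whose channel c is in group g, over all samples b.
        g_of = self.channel_groups
        sums = [0.0] * self.num_groups
        counts = [0] * self.num_groups
        for row in x:
            for c, v in enumerate(row):
                sums[g_of[c]] += v
                counts[g_of[c]] += 1
        means = [s / n if n else 0.0 for s, n in zip(sums, counts)]
        # Biased (population) variance around each group mean.
        sq = [0.0] * self.num_groups
        for row in x:
            for c, v in enumerate(row):
                sq[g_of[c]] += (v - means[g_of[c]]) ** 2
        variances = [s / n if n else 0.0 for s, n in zip(sq, counts)]
        return means, variances

    def forward(self, x: Matrix) -> Matrix:
        if self.training or not self.track_running_stats:
            mean, var = self._batch_stats(x)
            if self.training and self.track_running_stats:
                m = self.momentum
                # Exponential moving average (PyTorch momentum convention).
                self.running_mean = [(1 - m) * r + m * b
                                     for r, b in zip(self.running_mean, mean)]
                self.running_var = [(1 - m) * r + m * b
                                    for r, b in zip(self.running_var, var)]
        else:
            # Inference: normalize with the stored running statistics.
            mean, var = self.running_mean, self.running_var
        g_of = self.channel_groups
        return [[(v - mean[g_of[c]]) / math.sqrt(var[g_of[c]] + self.eps)
                 for c, v in enumerate(row)] for row in x]


# Channels 0 and 1 share group 0; channel 2 forms group 1.
bn = GroupBatchNormSketch(channel_groups=[0, 0, 1])
y = bn.forward([[1.0, 2.0, 10.0], [3.0, 4.0, 20.0]])
```

Because statistics are pooled per group rather than per channel, channels 0 and 1 above are normalized with one shared mean (2.5) and variance (1.25). Setting bn.training = False switches forward to the running statistics, as in standard batch norm.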
Methods
- __init__(channel_groups[, eps, momentum, ...])
- forward(x): x has shape (B, C) or (B, C, 1).
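Since forward accepts either a (B, C) or a (B, C, 1) input, one common way to honor that contract is to drop the trailing singleton dimension, normalize in (B, C) form, and restore the dimension afterwards. The sketch below shows that idea over nested Python lists; the helpers to_2d and from_2d are hypothetical, and whether gsnn returns the input's original shape is an assumption. With PyTorch tensors the same step would typically be x.squeeze(-1) before normalizing and y.unsqueeze(-1) after.

```python
from typing import List, Tuple, Union

Matrix = List[List[float]]
Tensor3 = List[List[List[float]]]


def to_2d(x: Union[Matrix, Tensor3]) -> Tuple[Matrix, bool]:
    """Hypothetical helper: accept (B, C) or (B, C, 1) data, return (B, C) data."""
    if x and x[0] and isinstance(x[0][0], list):
        # (B, C, 1): every innermost list must hold exactly one value.
        if any(len(cell) != 1 for row in x for cell in row):
            raise ValueError("expected shape (B, C) or (B, C, 1)")
        return [[cell[0] for cell in row] for row in x], True
    # Already (B, C): copy the rows so the caller's input is not mutated later.
    return [list(row) for row in x], False


def from_2d(y: Matrix, had_trailing_dim: bool) -> Union[Matrix, Tensor3]:
    """Hypothetical helper: restore the trailing singleton dimension if needed."""
    if had_trailing_dim:
        return [[[v] for v in row] for row in y]
    return y


x3 = [[[1.0], [2.0]], [[3.0], [4.0]]]  # shape (B=2, C=2, 1)
x2, squeezed = to_2d(x3)               # shape (2, 2), squeezed=True
restored = from_2d(x2, squeezed)       # back to shape (2, 2, 1)
```

Keeping the normalization itself strictly two-dimensional means the group statistics only ever see (B, C) data, regardless of which input shape the caller used.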