gsnn.models.GroupBatchNorm

Classes

GroupBatchNorm(*args, **kwargs)

A batch-norm style module that normalizes channels in groups.

class gsnn.models.GroupBatchNorm.GroupBatchNorm(*args: Any, **kwargs: Any)

Bases: Module

A batch-norm style module that:
  • Partitions the C channels into groups via ‘channel_groups’.
  • Computes the mean and variance of each group across the entire batch dimension.
  • Maintains running statistics for inference (if track_running_stats=True).
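
The three steps above can be illustrated with a minimal NumPy sketch. It is not the module's actual implementation, and it rests on several assumptions that the documentation does not confirm: that channel_groups is a length-C sequence of integer group ids, that statistics are pooled over the batch and over all channels in a group, and that running statistics follow the usual batch-norm exponential moving average with a momentum value. The names group_batch_norm, update_running, eps and momentum are hypothetical.

```python
import numpy as np


def group_batch_norm(x, channel_groups, eps=1e-5):
    """Normalize x of shape (B, C) with one mean/variance per channel group.

    Assumptions (not confirmed by the docs): channel_groups is a length-C
    sequence of integer group ids, and each group's statistics are pooled
    over the whole batch and every channel in that group.
    """
    x = np.asarray(x, dtype=np.float64)
    groups = np.asarray(channel_groups)
    out = np.empty_like(x)
    stats = {}
    for g in np.unique(groups):
        mask = groups == g              # channels that belong to group g
        vals = x[:, mask]               # shape (B, channels_in_group)
        mean = vals.mean()
        var = vals.var()                # biased variance over batch and group
        out[:, mask] = (vals - mean) / np.sqrt(var + eps)
        stats[int(g)] = (float(mean), float(var))
    return out, stats


def update_running(running, batch_stat, momentum=0.1):
    """Exponential moving average, as used by standard batch norm (assumed here)."""
    return (1.0 - momentum) * running + momentum * batch_stat


x = np.array([[1.0, 2.0, 10.0],
              [3.0, 4.0, 30.0]])
out, stats = group_batch_norm(x, channel_groups=[0, 0, 1])
running_mean_g0 = update_running(0.0, stats[0][0])
```

In this sketch, channels 0 and 1 share one mean and variance while channel 2 gets its own. At inference time, a module that tracks running statistics would normalize with the running values instead of the batch values.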

forward(x)

x: input of shape (B, C) or (B, C, 1). A trailing singleton dimension is squeezed first, if present.
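
The shape handling described here might look like the following NumPy sketch. The helper name to_batch_channels is hypothetical, and raising an error for other shapes is an assumption; the documentation does not say how the module treats them.

```python
import numpy as np


def to_batch_channels(x):
    """Return x as a (B, C) array, dropping a trailing singleton dimension."""
    x = np.asarray(x)
    if x.ndim == 3 and x.shape[-1] == 1:
        return x[..., 0]                # (B, C, 1) -> (B, C)
    if x.ndim == 2:
        return x                        # already (B, C)
    # Assumption: any other shape is rejected; the real module may differ.
    raise ValueError(f"expected (B, C) or (B, C, 1), got {x.shape}")


a = to_batch_channels(np.zeros((4, 3, 1)))
b = to_batch_channels(np.zeros((4, 3)))
```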