gsnn.models.GroupBatchNorm

class gsnn.models.GroupBatchNorm(*args: Any, **kwargs: Any)

Bases: Module

A batch-norm style module that:
  • Partitions the C channels into groups, as specified by channel_groups.
  • Computes the mean and variance of each group across the entire batch dimension.
  • Maintains running statistics for inference (when track_running_stats=True).
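The grouping and normalization steps above can be illustrated with a minimal NumPy sketch. It is not the gsnn implementation, and it rests on assumptions the documentation does not confirm: channel_groups is taken to be a length-C sequence of integer group ids, and each group's statistics are pooled over the whole batch and every channel in that group (one mean and one variance per group). The function name group_batch_norm is hypothetical.

```python
import numpy as np


def group_batch_norm(x, channel_groups, eps=1e-5):
    """Normalize x of shape (B, C) with statistics shared within each channel group.

    Assumption: channel_groups[c] is the integer group id of channel c.
    Returns the normalized array plus the per-group batch mean and (biased) variance.
    """
    x = np.asarray(x, dtype=np.float64)
    groups = np.asarray(channel_groups)
    out = np.empty_like(x)
    means, variances = {}, {}
    for g in np.unique(groups):
        cols = groups == g              # boolean mask selecting the channels of group g
        vals = x[:, cols]               # (B, n_g): all samples, this group's channels
        mean = vals.mean()              # one scalar mean for the whole group
        var = vals.var()                # biased variance, as batch norm uses in training
        out[:, cols] = (vals - mean) / np.sqrt(var + eps)
        means[int(g)] = float(mean)
        variances[int(g)] = float(var)
    return out, means, variances


# Channels 0 and 1 share group 0; channel 2 is alone in group 1.
x = np.array([[1.0, 2.0, 10.0],
              [3.0, 4.0, 30.0]])
y, means, variances = group_batch_norm(x, channel_groups=[0, 0, 1])
# means == {0: 2.5, 1: 20.0}, variances == {0: 1.25, 1: 100.0}
```

Because channels 0 and 1 share one mean and one variance, they are normalized jointly rather than independently as in a standard per-channel batch norm.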

__init__(channel_groups, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
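One plausible way these constructor arguments could interact is shown in the hypothetical GroupBatchNormSketch class below. It is written in NumPy so it runs without PyTorch, and its conventions are assumptions borrowed from torch.nn.BatchNorm1d rather than confirmed behavior of GroupBatchNorm: channel_groups holds integer group ids 0..G-1, running mean and variance are stored per group, the running update is running = (1 - momentum) * running + momentum * batch_stat (with the unbiased batch variance), affine adds a per-channel scale and shift, and with track_running_stats=False batch statistics are used in every mode.

```python
import numpy as np


class GroupBatchNormSketch:
    """Hypothetical NumPy stand-in illustrating GroupBatchNorm's constructor arguments."""

    def __init__(self, channel_groups, eps=1e-5, momentum=0.1,
                 affine=False, track_running_stats=True):
        self.groups = np.asarray(channel_groups, dtype=np.int64)  # assumed: group id per channel
        self.num_groups = int(self.groups.max()) + 1
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.training = True
        num_channels = self.groups.size
        # Assumed per-channel affine parameters, created only when affine=True.
        self.weight = np.ones(num_channels) if affine else None
        self.bias = np.zeros(num_channels) if affine else None
        if track_running_stats:
            self.running_mean = np.zeros(self.num_groups)  # one entry per group
            self.running_var = np.ones(self.num_groups)
        else:
            self.running_mean = self.running_var = None

    def __call__(self, x):
        x = np.asarray(x, dtype=np.float64)  # (B, C)
        use_batch_stats = self.training or not self.track_running_stats
        out = np.empty_like(x)
        for g in range(self.num_groups):
            cols = self.groups == g
            if not cols.any():               # skip unused group ids
                continue
            vals = x[:, cols]
            if use_batch_stats:
                mean, var = vals.mean(), vals.var()
                if self.training and self.track_running_stats:
                    n = vals.size
                    unbiased = var * n / max(n - 1, 1)
                    m = self.momentum
                    self.running_mean[g] = (1 - m) * self.running_mean[g] + m * mean
                    self.running_var[g] = (1 - m) * self.running_var[g] + m * unbiased
            else:
                # Inference: reuse the accumulated running statistics.
                mean, var = self.running_mean[g], self.running_var[g]
            out[:, cols] = (vals - mean) / np.sqrt(var + self.eps)
        if self.weight is not None:
            out = out * self.weight + self.bias
        return out


bn = GroupBatchNormSketch([0, 0, 1], momentum=0.1)
x = np.array([[1.0, 2.0, 10.0],
              [3.0, 4.0, 30.0]])
y_train = bn(x)      # training mode: batch stats, running stats updated
bn.training = False
y_eval = bn(x)       # eval mode: running stats, no update
```

Keeping statistics per group rather than per channel means the running buffers have length G (the number of groups), not C.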

Methods

__init__(channel_groups[, eps, momentum, ...])

forward(x)    x: (B, C) or (B, C, 1).

forward(x)

x: input of shape (B, C) or (B, C, 1). A 3-D input first has its trailing singleton dimension squeezed out.
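The input-shape handling described for forward can be sketched in NumPy as below. The helper names to_batch_channels and restore_shape are hypothetical, and whether GroupBatchNorm actually re-adds the trailing dimension to its output is not stated in the documentation, so restore_shape is only an illustrative assumption.

```python
import numpy as np


def to_batch_channels(x):
    """Return x as a (B, C) array plus a flag recording whether a trailing dim was squeezed."""
    x = np.asarray(x)
    if x.ndim == 3 and x.shape[-1] == 1:
        return np.squeeze(x, axis=-1), True   # (B, C, 1) -> (B, C)
    if x.ndim == 2:
        return x, False                       # already (B, C)
    raise ValueError(f"expected shape (B, C) or (B, C, 1), got {x.shape}")


def restore_shape(y, squeezed):
    """Hypothetical: put the singleton last axis back if the input had one."""
    return y[..., np.newaxis] if squeezed else y


x2, squeezed = to_batch_channels(np.zeros((4, 5, 1)))
```

Rejecting a 3-D input whose last axis is not 1 prevents a silent reshape that would mix samples or channels.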