gsnn.optim.TrainingDiagnostics

class gsnn.optim.TrainingDiagnostics(model: torch.nn.Module, track_every: int = 10, window_size: int = 100, track_activations: bool = True, track_weights: bool = True, track_gradients: bool = True, track_curvature: bool = False, verbose: bool = True)

Bases: object

Comprehensive training diagnostics for monitoring model optimization.

Tracks gradient flow, activation patterns, weight spectral properties, and basic curvature metrics to help diagnose training issues and guide optimization decisions.

Usage:

    diagnostics = TrainingDiagnostics(model, track_every=10)

    # During training loop:
    loss.backward()
    diagnostics.update(model, loss.item(), batch_idx)
    optimizer.step()

    # Generate reports:
    summary = diagnostics.get_summary()
    diagnostics.plot_diagnostics()

__init__(model: torch.nn.Module, track_every: int = 10, window_size: int = 100, track_activations: bool = True, track_weights: bool = True, track_gradients: bool = True, track_curvature: bool = False, verbose: bool = True)

Initialize training diagnostics tracker.

Parameters:
  • model – PyTorch model to monitor

  • track_every – Update diagnostics every N steps

  • window_size – Size of rolling window for statistics

  • track_activations – Whether to track activation statistics

  • track_weights – Whether to track weight spectral properties

  • track_gradients – Whether to track gradient flow metrics

  • track_curvature – Whether to track curvature metrics (computationally expensive; disabled by default)

  • verbose – Whether to print diagnostic warnings
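To make the interaction between track_every and window_size concrete, the sketch below records a value only on every N-th step and keeps the most recent window_size records in a bounded deque. It is a standard-library illustration of the documented semantics, not gsnn's implementation; RollingTracker and its methods are hypothetical names, and it assumes the window is counted in recorded samples rather than raw steps.

```python
from collections import deque
from statistics import mean


class RollingTracker:
    """Hypothetical stand-in: records every `track_every` steps, keeps `window_size` values."""

    def __init__(self, track_every: int = 10, window_size: int = 100):
        self.track_every = track_every
        # A deque with maxlen drops the oldest entry once the window is full.
        self.window = deque(maxlen=window_size)

    def update(self, value: float, step: int) -> bool:
        """Record `value` if `step` falls on the tracking interval; return whether it was recorded."""
        if step % self.track_every != 0:
            return False
        self.window.append(value)
        return True

    def rolling_mean(self) -> float:
        """Mean over the values currently in the window (0.0 if the window is empty)."""
        return mean(self.window) if self.window else 0.0


tracker = RollingTracker(track_every=10, window_size=3)
for step in range(50):
    tracker.update(float(step), step)

# Steps 0, 10, 20, 30, 40 were recorded; only the last three remain in the window.
recorded = list(tracker.window)
```

Under this reading, the defaults (track_every=10, window_size=100) would cover roughly the last 1,000 training steps. Whether gsnn measures the window in recorded samples or in raw steps is not stated here.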

Methods

  • __init__(model[, track_every, window_size, ...]) – Initialize training diagnostics tracker.
  • get_gradient_histogram_data() – Get gradient histograms for current step.
  • get_summary() – Generate a comprehensive diagnostic summary.
  • plot_diagnostics([save_path, figsize]) – Generate comprehensive diagnostic plots.
  • reset() – Reset all diagnostic tracking.
  • update(model, loss[, step]) – Update diagnostics with current training state.

get_gradient_histogram_data() → Dict[str, numpy.ndarray]

Get gradient histograms for current step.

Returns:

Dictionary mapping layer names to gradient histograms

get_summary() → Dict[str, Any]

Generate a comprehensive diagnostic summary.

Returns:

Dictionary containing current diagnostic state and recommendations
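Because the return type is Dict[str, Any], a summary may hold values that the json module cannot encode directly. The sketch below shows one way to log such a dictionary safely. The keys in example_summary are invented for illustration and are not gsnn's documented schema; summary_to_json is a hypothetical helper.

```python
import json
from typing import Any, Dict


def summary_to_json(summary: Dict[str, Any]) -> str:
    """Serialize a summary dict, falling back to str() for values JSON cannot encode."""
    return json.dumps(summary, indent=2, sort_keys=True, default=str)


# Stand-in for diagnostics.get_summary(); these keys are assumptions, not the real schema.
example_summary = {
    "step": 120,
    "loss": 0.4312,
    "recommendations": ["example recommendation text"],
    "extra": {1.5},  # a set is not JSON-serializable, so default=str converts it
}

text = summary_to_json(example_summary)
parsed = json.loads(text)
```

With default=str, array-like values are stored as their string form. If numeric values are needed downstream, converting NumPy arrays with .tolist() before serializing is a common alternative.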

plot_diagnostics(save_path: Optional[str] = None, figsize: Tuple[int, int] = (15, 12))

Generate comprehensive diagnostic plots.

Parameters:
  • save_path – Optional path to save the plot

  • figsize – Figure size tuple

reset()

Reset all diagnostic tracking.

update(model: torch.nn.Module, loss: float, step: Optional[int] = None)

Update diagnostics with current training state.

Parameters:
  • model – Current model state

  • loss – Current loss value

  • step – Optional step index; if None, the internal step counter is used
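The fallback for step can be pictured with the sketch below. It assumes the internal counter advances once per call regardless of whether an explicit step was passed; the documentation states only that an internal counter is used when step is None, so the exact behavior in gsnn may differ. StepCounter and resolve are hypothetical names.

```python
from typing import List, Optional


class StepCounter:
    """Hypothetical illustration of an optional-step fallback."""

    def __init__(self) -> None:
        self._internal_step = 0

    def resolve(self, step: Optional[int] = None) -> int:
        """Return the explicit step if given, else the internal counter; then advance the counter."""
        current = self._internal_step if step is None else step
        self._internal_step += 1
        return current


counter = StepCounter()
resolved: List[int] = [
    counter.resolve(),          # no step given: internal counter
    counter.resolve(),          # no step given: internal counter
    counter.resolve(step=500),  # explicit step overrides the counter
    counter.resolve(),          # back to the internal counter
]
```

Mixing explicit and implicit steps produces a non-monotonic step axis in this sketch, so it is usually clearer to pick one convention for a whole training run.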