gsnn.gsnn.optim.TrainingDiagnostics
- class gsnn.gsnn.optim.TrainingDiagnostics(model: torch.nn.Module, track_every: int = 10, window_size: int = 100, track_activations: bool = True, track_weights: bool = True, track_gradients: bool = True, track_curvature: bool = False, verbose: bool = True)
Bases: object
Comprehensive training diagnostics for monitoring model optimization.
Tracks gradient flow, activation patterns, weight spectral properties, and basic curvature metrics to help diagnose training issues and guide optimization decisions.
- Usage:
    diagnostics = TrainingDiagnostics(model, track_every=10)

    # During training loop:
    loss.backward()
    diagnostics.update(model, loss.item(), batch_idx)
    optimizer.step()

    # Generate reports:
    summary = diagnostics.get_summary()
    diagnostics.plot_diagnostics()
- __init__(model: torch.nn.Module, track_every: int = 10, window_size: int = 100, track_activations: bool = True, track_weights: bool = True, track_gradients: bool = True, track_curvature: bool = False, verbose: bool = True)
Initialize training diagnostics tracker.
- Parameters:
model – PyTorch model to monitor
track_every – Update diagnostics every N steps
window_size – Size of rolling window for statistics
track_activations – Whether to track activation statistics
track_weights – Whether to track weight spectral properties
track_gradients – Whether to track gradient flow metrics
track_curvature – Whether to track curvature (expensive!)
verbose – Whether to print diagnostic warnings
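The interaction between track_every and window_size can be illustrated with a minimal sketch. RollingTracker below is a hypothetical stand-in written for this example, not part of gsnn, and it assumes that a value is recorded only on steps divisible by track_every and that the rolling window keeps the most recent window_size recorded values. The actual class may schedule and aggregate differently.

```python
from collections import deque
from statistics import fmean


class RollingTracker:
    """Hypothetical stand-in (not gsnn code): records a value every
    `track_every` steps and keeps only the last `window_size` records."""

    def __init__(self, track_every: int = 10, window_size: int = 100):
        self.track_every = track_every
        # deque(maxlen=...) drops the oldest entry once the window is full
        self.window = deque(maxlen=window_size)

    def update(self, value: float, step: int) -> bool:
        """Record `value` on tracking steps; return whether it was recorded."""
        if step % self.track_every != 0:
            return False
        self.window.append(value)
        return True

    def mean(self) -> float:
        """Mean over the current window (NaN if nothing has been recorded)."""
        return fmean(self.window) if self.window else float("nan")


tracker = RollingTracker(track_every=10, window_size=3)
for step in range(50):
    tracker.update(float(step), step)  # pretend the loss equals the step index

recorded = list(tracker.window)
```

Under these assumptions, a larger track_every lowers overhead at the cost of coarser statistics, while window_size controls how quickly old measurements stop influencing the rolling averages.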
Methods
__init__(model[, track_every, window_size, ...]) – Initialize training diagnostics tracker.
get_gradient_histogram_data() – Get gradient histograms for current step.
get_summary() – Generate a comprehensive diagnostic summary.
plot_diagnostics([save_path, figsize]) – Generate comprehensive diagnostic plots.
reset() – Reset all diagnostic tracking.
update(model, loss[, step]) – Update diagnostics with current training state.
- get_gradient_histogram_data() -> Dict[str, numpy.ndarray]
Get gradient histograms for current step.
- Returns:
Dictionary mapping layer names to gradient histograms
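As a rough illustration of the kind of data this method returns, the sketch below builds a per-layer histogram dictionary with NumPy. The helper gradient_histograms, the bins parameter, and the layer names are assumptions made for this example. Whether the real method returns bin counts, bin edges, or both is not stated in this reference.

```python
from typing import Dict

import numpy as np


def gradient_histograms(grads: Dict[str, np.ndarray], bins: int = 10) -> Dict[str, np.ndarray]:
    """Map each layer name to the bin counts of its flattened gradient values.

    Hypothetical helper for illustration; not the gsnn implementation.
    """
    out = {}
    for name, g in grads.items():
        # np.histogram returns (counts, bin_edges); only the counts are kept here
        counts, _edges = np.histogram(np.asarray(g).ravel(), bins=bins)
        out[name] = counts
    return out


# Toy gradients standing in for parameter gradients (layer names are made up)
rng = np.random.default_rng(0)
grads = {
    "layer1.weight": rng.normal(size=(4, 3)),
    "layer1.bias": np.zeros(4),
}
hists = gradient_histograms(grads, bins=5)
```

With a PyTorch model, the input dictionary could be assembled from model.named_parameters() by converting each non-None p.grad with p.grad.detach().cpu().numpy().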
- get_summary() -> Dict[str, Any]
Generate a comprehensive diagnostic summary.
- Returns:
Dictionary containing current diagnostic state and recommendations