gsnn.gsnn.optim
Classes
- Comprehensive training diagnostics for monitoring model optimization.
- class gsnn.gsnn.optim.TrainingDiagnostics(model: torch.nn.Module, track_every: int = 10, window_size: int = 100, track_activations: bool = True, track_weights: bool = True, track_gradients: bool = True, track_curvature: bool = False, verbose: bool = True)
Bases: object
Comprehensive training diagnostics for monitoring model optimization.
Tracks gradient flow, activation patterns, weight spectral properties, and basic curvature metrics to help diagnose training issues and guide optimization decisions.
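One common way to quantify "weight spectral properties" is through the singular values of each weight matrix. The sketch below assumes that interpretation. `weight_spectral_stats` is a hypothetical helper, not part of gsnn, and it expects a NumPy array (for a PyTorch layer, `layer.weight.detach().cpu().numpy()` would supply one).

```python
import numpy as np


def weight_spectral_stats(weight) -> dict:
    """Summarize the singular-value spectrum of a weight tensor.

    Hypothetical helper: spectral norm, condition number and stable rank are
    common choices, not necessarily the metrics TrainingDiagnostics reports.
    """
    w = np.asarray(weight, dtype=np.float64)
    # Flatten trailing dims (e.g. conv kernels) into a 2-D (out, in) matrix.
    w = w.reshape(w.shape[0], -1)
    # Singular values are returned in descending order.
    s = np.linalg.svd(w, compute_uv=False)
    largest, smallest = float(s[0]), float(s[-1])
    return {
        # Largest singular value: the maximum gain of the linear map.
        "spectral_norm": largest,
        # Largest over smallest singular value (inf if the smallest is exactly 0).
        "condition_number": largest / smallest if smallest > 0 else float("inf"),
        # Squared Frobenius norm over squared spectral norm.
        "stable_rank": float(np.sum(s ** 2)) / largest ** 2 if largest > 0 else 0.0,
    }


# The singular values of a diagonal matrix are its absolute diagonal entries.
stats = weight_spectral_stats(np.diag([3.0, 2.0, 1.0]))
print(stats)
```

A rapidly growing spectral norm or condition number across training steps is one signal that a layer's weights are becoming poorly conditioned.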
- Usage:
from gsnn.gsnn.optim import TrainingDiagnostics

diagnostics = TrainingDiagnostics(model, track_every=10)

# Inside the training loop, after loss.backward() and before optimizer.step():
loss.backward()
diagnostics.update(model, loss.item(), batch_idx)
optimizer.step()

# Generate reports:
summary = diagnostics.get_summary()
diagnostics.plot_diagnostics()
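The usage above calls `update` on every batch while the constructor receives `track_every=10`. The following is a minimal sketch of how a sampling interval and a `window_size`-bounded history could interact, assuming metrics are recorded on every `track_every`-th call and only the most recent `window_size` records are kept. `RollingTracker` is a hypothetical stand-in that tracks only the loss and uses its own call counter rather than `batch_idx`.

```python
from collections import deque
from statistics import mean


class RollingTracker:
    """Hypothetical stand-in illustrating track_every / window_size semantics."""

    def __init__(self, track_every: int = 10, window_size: int = 100):
        self.track_every = track_every
        # deque(maxlen=...) silently drops the oldest entry once it is full.
        self.losses = deque(maxlen=window_size)
        self.step = 0

    def update(self, loss: float) -> bool:
        """Record the loss on every `track_every`-th call; return True if recorded."""
        recorded = self.step % self.track_every == 0
        if recorded:
            self.losses.append(loss)
        self.step += 1
        return recorded

    def rolling_mean(self) -> float:
        """Mean of the losses currently held in the window (nan if empty)."""
        return mean(self.losses) if self.losses else float("nan")


tracker = RollingTracker(track_every=10, window_size=3)
for batch_idx in range(50):
    tracker.update(loss=float(batch_idx))
print(list(tracker.losses), tracker.rolling_mean())
```

Sampling every few steps keeps the per-step overhead low, and a bounded deque keeps memory use constant regardless of how long training runs.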
- get_gradient_histogram_data() → Dict[str, numpy.ndarray]
Get gradient histograms for the current step.
- Returns:
Dictionary mapping layer names to gradient histograms.
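The following sketches the kind of data `get_gradient_histogram_data` describes. It assumes each layer's gradients have already been copied to NumPy arrays (for example via `param.grad.detach().cpu().numpy()` in PyTorch). `gradient_histograms` and its `bins` parameter are hypothetical. Only the bin counts are returned, to match the `Dict[str, numpy.ndarray]` shape.

```python
import numpy as np


def gradient_histograms(grads: dict, bins: int = 20) -> dict:
    """Map each layer name to the bin counts of its gradient values.

    Hypothetical helper; `grads` maps layer names to gradient arrays.
    """
    histograms = {}
    for name, grad in grads.items():
        values = np.asarray(grad, dtype=np.float64).ravel()
        # Skip layers with no gradient entries.
        if values.size == 0:
            continue
        # Bin edges span [min, max] of this layer's gradients; edges are dropped.
        counts, _edges = np.histogram(values, bins=bins)
        histograms[name] = counts
    return histograms


rng = np.random.default_rng(0)
hists = gradient_histograms(
    {"fc1.weight": rng.normal(size=(4, 8)), "fc2.weight": np.zeros(3)}, bins=5
)
print({name: counts.tolist() for name, counts in hists.items()})
```

Because each layer gets its own bin range, the counts show the shape of each distribution but are not directly comparable across layers without also keeping the edges.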
- get_summary() → Dict[str, Any]
Generate a comprehensive diagnostic summary.
- Returns:
Dictionary containing the current diagnostic state and recommendations.
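The rule-based recommendations a summary like this might contain can be sketched from per-layer gradient norms. `summarize_gradient_flow`, its dictionary keys, and both thresholds are illustrative assumptions, not the keys or rules that `get_summary` actually uses.

```python
import math


def summarize_gradient_flow(
    grad_norms: dict, vanish_tol: float = 1e-7, explode_tol: float = 1e3
) -> dict:
    """Build a summary dict with simple, rule-based recommendations.

    Hypothetical helper: thresholds are illustrative, not values used by gsnn.
    """
    # Separate NaN/inf norms first, since comparisons with NaN are always False.
    non_finite = sorted(n for n, v in grad_norms.items() if not math.isfinite(v))
    finite = {n: v for n, v in grad_norms.items() if math.isfinite(v)}
    vanishing = sorted(n for n, v in finite.items() if v < vanish_tol)
    exploding = sorted(n for n, v in finite.items() if v > explode_tol)

    recommendations = []
    if non_finite:
        recommendations.append("Non-finite gradients: inspect the loss and lower the learning rate.")
    if exploding:
        recommendations.append("Large gradients: consider gradient clipping or a smaller learning rate.")
    if vanishing:
        recommendations.append("Near-zero gradients: check initialization, activations, or skip connections.")

    return {
        "num_layers": len(grad_norms),
        "vanishing_layers": vanishing,
        "exploding_layers": exploding,
        "non_finite_layers": non_finite,
        "recommendations": recommendations,
    }


summary = summarize_gradient_flow({"embed": 1e-9, "hidden": 0.5, "out": 5e3})
for tip in summary["recommendations"]:
    print(tip)
```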
Modules
- Lightweight optimizer to infer output edges from intermediate GSNN node activations.
- Comprehensive training diagnostics for monitoring model optimization.