gsnn.gsnn.optim.GradDiagnostics
Classes
- GradDiagnostics: Gradient Diagnostics for Vanishing Gradient Analysis in PyTorch Models.
- class gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics(window_size: int = 100, verbose: bool = True)
Bases: object
Gradient Diagnostics for Vanishing Gradient Analysis in PyTorch Models.
Tracks and summarizes vanishing-gradient statistics over training steps, which helps diagnose vanishing gradients in deep or recurrent networks.
- Usage:
    diag = GradDiagnostics(window_size=100, verbose=True)
    # During the training loop, after loss.backward():
    diag.update(model, loss.item(), step)
    # To get the latest summary:
    summary = diag.get_summary()
    # To reset the history:
    diag.reset()
- analyze(model: torch.nn.Module, threshold: float = 1e-06) → Dict[str, Any]
Analyze and summarize gradient vanishing effects in the model.
For each parameter tensor in the model, computes the mean and maximum absolute gradient after a backward pass, and the fraction of elements with very small gradients (|grad| < threshold). If a large fraction of parameters have near-zero gradients, it may indicate vanishing gradients.
- Parameters:
model (torch.nn.Module) – Model to analyze. Its gradients must already be computed, i.e. call this after loss.backward().
threshold (float) – The absolute gradient value below which a gradient is considered ‘vanished’.
- Returns:
A dictionary containing per-parameter statistics and an overall summary, including:
‘per_layer’: {layer_name: {‘mean_abs_grad’, ‘max_abs_grad’, ‘vanishing_frac’}}
‘overall’: {‘avg_vanishing_frac’, ‘layers_with_high_vanishing’, ‘threshold’}
- Return type:
Dict[str, Any]
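The statistics that analyze() reports can be illustrated with a small pure-Python sketch. It is not the gsnn implementation: it assumes the gradients have already been flattened into a dict mapping parameter names to lists of floats, and it assumes, hypothetically, that ‘layers_with_high_vanishing’ lists the parameters whose vanishing fraction exceeds a cutoff called high_frac (a name introduced here for illustration).

```python
from typing import Any, Dict, List


def analyze_grads(
    grads: Dict[str, List[float]],
    threshold: float = 1e-6,
    high_frac: float = 0.5,  # hypothetical cutoff, not a documented gsnn parameter
) -> Dict[str, Any]:
    """Sketch of per-parameter vanishing-gradient statistics (not the gsnn code)."""
    per_layer: Dict[str, Dict[str, float]] = {}
    for name, values in grads.items():
        if not values:
            continue  # skip parameters without gradient values
        abs_vals = [abs(v) for v in values]
        n = len(abs_vals)
        per_layer[name] = {
            "mean_abs_grad": sum(abs_vals) / n,
            "max_abs_grad": max(abs_vals),
            # fraction of elements whose |grad| falls below the threshold
            "vanishing_frac": sum(a < threshold for a in abs_vals) / n,
        }
    fracs = [s["vanishing_frac"] for s in per_layer.values()]
    overall = {
        "avg_vanishing_frac": sum(fracs) / len(fracs) if fracs else 0.0,
        # assumption: names of parameters whose vanishing fraction exceeds high_frac
        "layers_with_high_vanishing": [
            name for name, s in per_layer.items() if s["vanishing_frac"] > high_frac
        ],
        "threshold": threshold,
    }
    return {"per_layer": per_layer, "overall": overall}
```

With a real PyTorch model, the input could be built as {name: p.grad.flatten().tolist() for name, p in model.named_parameters() if p.grad is not None}.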
- get_summary() → Dict[str, Any]
Get the most recent gradient vanishing summary, including step and loss.
- Returns:
The latest summary dictionary, or None if no history.
- Return type:
Dict[str, Any] (None when the history is empty)
- plot_diagnostics(step: int = None)
Plot the vanishing-gradient fraction by layer number for each parameter type.
X-axis: layer number (extracted from parameter names such as ‘ResBlocks.0.lin_in.weight’).
Y-axis: vanishing-gradient fraction (|grad| < threshold) at that step.
Different parameter types (lin_in, lin_out, norm, etc.) are drawn as differently colored lines.
- Parameters:
step (int, optional) – Index of the step in history to plot. If None, uses the most recent step.
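Plotting by layer requires a layer number and a parameter type for every parameter name. One way to recover them, sketched here as an assumption rather than the gsnn logic, is to treat the first purely numeric component of the dotted name as the layer index and everything after it as the type. Keeping the trailing attribute (weight or bias) in the type is a design choice that stops a module's weight and bias from sharing one line; parse_param_name is a hypothetical helper.

```python
from typing import Optional, Tuple


def parse_param_name(name: str) -> Tuple[Optional[int], Optional[str]]:
    """Hypothetical helper: 'ResBlocks.0.lin_in.weight' -> (0, 'lin_in.weight')."""
    parts = name.split(".")
    for i, part in enumerate(parts):
        if part.isdigit():
            # the first numeric component is taken as the layer index
            suffix = ".".join(parts[i + 1:]) or "param"
            return int(part), suffix
    # names without a numeric component (e.g. 'head.weight') have no layer number
    return None, None
```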
- plot_gradient_magnitude_by_layer(step: int = None)
Plot the mean absolute gradient magnitude by layer number for each parameter type.
X-axis: layer number. Y-axis: mean absolute gradient magnitude.
Different parameter types are drawn as differently colored lines.
- Parameters:
step (int, optional) – Index of the step in history to plot. If None, uses the most recent step.
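Turning one step's per-parameter statistics into one line per parameter type amounts to a group-by followed by a sort on layer number. The sketch below assumes the ‘per_layer’ layout documented for analyze() and the same hypothetical naming rule as above (first numeric component = layer, remainder = type); layer_and_type and series_by_type are illustrative names, not gsnn functions.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple


def layer_and_type(name: str) -> Tuple[Optional[int], Optional[str]]:
    # hypothetical parsing rule: first numeric component = layer, rest = type
    parts = name.split(".")
    for i, part in enumerate(parts):
        if part.isdigit():
            return int(part), ".".join(parts[i + 1:]) or "param"
    return None, None


def series_by_type(
    per_layer: Dict[str, Dict[str, float]], key: str = "mean_abs_grad"
) -> Dict[str, List[Tuple[int, float]]]:
    """Group one step's per-parameter stats into (layer, value) points per type."""
    series: Dict[str, List[Tuple[int, float]]] = defaultdict(list)
    for name, stats in per_layer.items():
        layer, ptype = layer_and_type(name)
        if layer is None:
            continue  # no x-coordinate for parameters outside numbered blocks
        series[ptype].append((layer, stats[key]))
    # sort each line by layer number so it can be drawn left to right
    return {ptype: sorted(points) for ptype, points in series.items()}
```

Each series can then be drawn with matplotlib, for example plt.plot(*zip(*points), label=ptype); a log-scaled y-axis (plt.yscale("log")) is a common choice for gradient magnitudes. Passing key="vanishing_frac" gives the data for plot_diagnostics instead.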
- plot_gradient_ratio_heatmap(steps_to_show: int = 10)
Plot a heatmap of the vanishing-gradient fraction across layers and recent training steps.
X-axis: training step. Y-axis: layer number. Color: vanishing fraction.
- Parameters:
steps_to_show (int) – Number of recent steps to include in heatmap.
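A heatmap needs a layers-by-steps matrix built from the most recent history entries. This sketch assumes each history entry holds a ‘step’ and the ‘per_layer’ dictionary from analyze(), and it averages the vanishing fractions of all parameter tensors in a layer without weighting; both are assumptions, and vanishing_matrix and layer_number are hypothetical names.

```python
import math
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple


def layer_number(name: str) -> Optional[int]:
    # hypothetical rule: the first numeric component of the name is the layer index
    for part in name.split("."):
        if part.isdigit():
            return int(part)
    return None


def vanishing_matrix(
    history: List[Dict[str, Any]], steps_to_show: int = 10
) -> Tuple[List[int], List[int], List[List[float]]]:
    """Return (layer_ids, steps, matrix) with matrix[row=layer][col=step]."""
    # guard: history[-0:] would return the whole list
    recent = history[-steps_to_show:] if steps_to_show > 0 else []
    columns: List[Dict[int, float]] = []
    for entry in recent:
        sums: Dict[int, float] = defaultdict(float)
        counts: Dict[int, int] = defaultdict(int)
        for name, stats in entry["per_layer"].items():
            layer = layer_number(name)
            if layer is None:
                continue
            sums[layer] += stats["vanishing_frac"]
            counts[layer] += 1
        # unweighted mean over the parameter tensors of each layer (an assumption)
        columns.append({layer: sums[layer] / counts[layer] for layer in sums})
    layer_ids = sorted({layer for col in columns for layer in col})
    # NaN marks a layer that has no entry at a given step
    matrix = [[col.get(layer, math.nan) for col in columns] for layer in layer_ids]
    steps = [entry["step"] for entry in recent]
    return layer_ids, steps, matrix
```

The matrix can be passed to matplotlib's plt.imshow(matrix, aspect="auto"), labelling the x ticks with steps and the y ticks with layer_ids.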
- plot_summary_statistics(step: int = None)
Plot summary statistics: mean gradient magnitude, max gradient magnitude, and vanishing fraction across all parameters for a given step.
- Parameters:
step (int, optional) – Index of the step in history to plot. If None, uses the most recent step.
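The docstring does not say how values are pooled "across all parameters". The sketch below assumes element-level pooling, where every gradient entry counts once; this differs from averaging per-tensor statistics, because a large weight matrix and a short bias vector then carry very different weight. As in the analyze() sketch, gradients are assumed to arrive as a dict of flattened float lists, and pooled_summary is a hypothetical name.

```python
from typing import Dict, List


def pooled_summary(grads: Dict[str, List[float]], threshold: float = 1e-6) -> Dict[str, float]:
    """Element-weighted mean/max |grad| and vanishing fraction over all parameters."""
    # flatten every parameter's gradient values into one list of magnitudes
    abs_vals = [abs(v) for values in grads.values() for v in values]
    n = len(abs_vals)
    if n == 0:
        raise ValueError("no gradient values to summarize")
    return {
        "mean_abs_grad": sum(abs_vals) / n,
        "max_abs_grad": max(abs_vals),
        "vanishing_frac": sum(a < threshold for a in abs_vals) / n,
    }
```

For example, pooled_summary({"a": [0.0, 0.0, 0.0], "b": [1.0]}) gives a vanishing fraction of 0.75, whereas averaging the two per-tensor fractions (1.0 and 0.0) would give 0.5.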
- plot_vanishing_over_time(layers_to_show: list = None)
Plot the vanishing-gradient fraction over time for specific layers.
X-axis: training step. Y-axis: vanishing-gradient fraction.
Different layers are drawn as differently colored lines.
- Parameters:
layers_to_show (list, optional) – List of layer numbers to show. If None, the first 5 layers are shown.
- update(model: torch.nn.Module, loss: float, step: Optional[int] = None, threshold: float = 1e-06)
Update diagnostics with the current model gradients and loss. Stores the latest summary in the history (rolling window), along with step and loss.
- Parameters:
model (torch.nn.Module) – Model to analyze. Its gradients must already be computed, i.e. call this after loss.backward().
loss (float) – Current loss value.
step (int, optional) – Step counter. If None, the internal step counter is incremented.
threshold (float) – Threshold for vanishing gradient detection.
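The rolling-window bookkeeping described for update(), get_summary() and reset() can be sketched with collections.deque, whose maxlen argument discards the oldest entry automatically. This assumes, without confirmation from the source, that window_size is the maximum number of stored summaries and that an explicit step also advances the internal counter; RollingHistory is a hypothetical class, not part of gsnn.

```python
from collections import deque
from typing import Any, Deque, Dict, Optional


class RollingHistory:
    """Hypothetical rolling store of per-step summaries (not the gsnn implementation)."""

    def __init__(self, window_size: int = 100):
        # once window_size entries are stored, appending drops the oldest one
        self.history: Deque[Dict[str, Any]] = deque(maxlen=window_size)
        self._step = 0

    def update(self, summary: Dict[str, Any], loss: float, step: Optional[int] = None) -> None:
        if step is None:
            step = self._step  # fall back to the internal counter
        self._step = step + 1  # assumption: explicit steps also advance the counter
        self.history.append({**summary, "step": step, "loss": float(loss)})

    def get_summary(self) -> Optional[Dict[str, Any]]:
        # latest entry, or None when nothing has been recorded yet
        return self.history[-1] if self.history else None

    def reset(self) -> None:
        self.history.clear()
        self._step = 0
```

With window_size=3 and five updates, only the summaries for steps 2, 3 and 4 remain in the history.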