gsnn.gsnn.optim.GradDiagnostics

Classes

GradDiagnostics([window_size, verbose])

Gradient Diagnostics for Vanishing Gradient Analysis in PyTorch Models.

class gsnn.gsnn.optim.GradDiagnostics.GradDiagnostics(window_size: int = 100, verbose: bool = True)

Bases: object

Gradient Diagnostics for Vanishing Gradient Analysis in PyTorch Models.

Tracks and summarizes gradient vanishing effects over training steps. Useful for diagnosing vanishing gradient problems in deep or recurrent networks.

Usage:

    diag = GradDiagnostics(window_size=100, verbose=True)

    # During the training loop, after loss.backward():
    diag.update(model, loss.item(), step)

    # To get the latest summary:
    summary = diag.get_summary()

    # To reset the history:
    diag.reset()

analyze(model: torch.nn.Module, threshold: float = 1e-06) → Dict[str, Any]

Analyze and summarize gradient vanishing effects in the model.

For each parameter tensor in the model, computes the mean and maximum absolute gradient after a backward pass, and the fraction of elements with very small gradients (|grad| < threshold). If a large fraction of parameters have near-zero gradients, it may indicate vanishing gradients.

Parameters:
  • model (torch.nn.Module) – Model to analyze. Should have gradients computed (after backward).

  • threshold (float) – The absolute gradient value below which a gradient is considered ‘vanished’.

Returns:

A dictionary containing per-parameter statistics and an overall summary, including:
  • ‘per_layer’: {layer_name: {‘mean_abs_grad’, ‘max_abs_grad’, ‘vanishing_frac’}}

  • ‘overall’: {‘avg_vanishing_frac’, ‘layers_with_high_vanishing’, ‘threshold’}

Return type:

dict
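As a rough illustration of the statistics described above, the following plain-Python sketch computes them from raw gradient values. It is not gsnn's implementation: gradients are passed as a hypothetical dict mapping parameter names to flat lists of values (with PyTorch they would come from each parameter's .grad after loss.backward()), and the high_vanishing_cutoff of 0.5 used to fill layers_with_high_vanishing is an assumption, since the actual criterion is not documented here.

```python
from typing import Any, Dict, List


def analyze_grads(
    grads: Dict[str, List[float]],
    threshold: float = 1e-6,
    high_vanishing_cutoff: float = 0.5,  # hypothetical cutoff, not documented by gsnn
) -> Dict[str, Any]:
    """Sketch of the per-parameter and overall statistics described for analyze()."""
    per_layer: Dict[str, Dict[str, float]] = {}
    for name, values in grads.items():
        abs_vals = [abs(v) for v in values]
        if not abs_vals:
            continue  # skip parameters that have no gradient values
        per_layer[name] = {
            "mean_abs_grad": sum(abs_vals) / len(abs_vals),
            "max_abs_grad": max(abs_vals),
            # fraction of elements with |grad| < threshold
            "vanishing_frac": sum(a < threshold for a in abs_vals) / len(abs_vals),
        }
    fracs = [s["vanishing_frac"] for s in per_layer.values()]
    overall = {
        # unweighted mean of the per-parameter fractions
        "avg_vanishing_frac": sum(fracs) / len(fracs) if fracs else 0.0,
        "layers_with_high_vanishing": [
            name for name, s in per_layer.items()
            if s["vanishing_frac"] > high_vanishing_cutoff
        ],
        "threshold": threshold,
    }
    return {"per_layer": per_layer, "overall": overall}


grads = {
    "ResBlocks.0.lin_in.weight": [0.2, -0.1, 0.05, 0.3],
    "ResBlocks.1.lin_in.weight": [1e-8, -2e-7, 0.01, 0.0],
}
summary = analyze_grads(grads)
print(summary["per_layer"]["ResBlocks.1.lin_in.weight"]["vanishing_frac"])  # 0.75
print(summary["overall"]["layers_with_high_vanishing"])  # ['ResBlocks.1.lin_in.weight']
```

Three of the four gradient entries in ResBlocks.1.lin_in.weight fall below 1e-6, so that parameter is flagged while ResBlocks.0.lin_in.weight is not.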

get_summary() → Dict[str, Any]

Return the most recent vanishing-gradient summary, including the step and loss it was recorded with.

Returns:

The latest summary dictionary, or None if the history is empty.

Return type:

dict or None

plot_diagnostics(step: int = None)

Plot the vanishing-gradient fraction by layer number for each parameter type.

  • X-axis: layer number, extracted from parameter names such as ‘ResBlocks.0.lin_in.weight’.

  • Y-axis: vanishing-gradient fraction (|grad| < threshold) at that step.

  • Lines: one colored line per parameter type (lin_in, lin_out, norm, etc.).

Parameters:

step (int, optional) – Index of the step in history to plot. If None, uses the most recent step.
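Grouping parameters by layer number and type, as this plot does, depends on parsing names such as ‘ResBlocks.0.lin_in.weight’. The sketch below assumes a <prefix>.<layer>.<type>.<suffix> naming pattern; the regex, the helper names parse_layer_name and series_by_type, and the input format (parameter name → vanishing fraction) are illustrative assumptions, not gsnn's own code.

```python
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

# Hypothetical pattern: "<prefix>.<layer number>.<param type>.<suffix>"
_NAME_RE = re.compile(r"^[^.]+\.(\d+)\.([^.]+)\.")


def parse_layer_name(name: str) -> Optional[Tuple[int, str]]:
    """Return (layer_number, param_type), or None if the name does not match."""
    m = _NAME_RE.match(name)
    return (int(m.group(1)), m.group(2)) if m else None


def series_by_type(fracs: Dict[str, float]) -> Dict[str, List[Tuple[int, float]]]:
    """Group per-parameter values into one (layer, value) series per parameter type.

    Note: in this sketch a type's weight and bias land in the same series.
    """
    series: Dict[str, List[Tuple[int, float]]] = defaultdict(list)
    for name, frac in fracs.items():
        parsed = parse_layer_name(name)
        if parsed is None:
            continue  # e.g. a head or embedding outside the numbered blocks
        layer, ptype = parsed
        series[ptype].append((layer, frac))
    # Sort each series by layer number so it can be drawn as a line
    return {ptype: sorted(points) for ptype, points in series.items()}


fracs = {
    "ResBlocks.1.lin_in.weight": 0.4,
    "ResBlocks.0.lin_in.weight": 0.1,
    "ResBlocks.0.norm.weight": 0.0,
    "head.weight": 0.9,
}
print(series_by_type(fracs))
# {'lin_in': [(0, 0.1), (1, 0.4)], 'norm': [(0, 0.0)]}
```

Names that do not follow the assumed pattern (here ‘head.weight’) are simply skipped, since they have no layer number to place on the x-axis.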

plot_gradient_magnitude_by_layer(step: int = None)

Plot the mean absolute gradient magnitude by layer number for each parameter type.

  • X-axis: layer number.

  • Y-axis: mean absolute gradient magnitude.

  • Lines: one colored line per parameter type.

Parameters:

step (int, optional) – Index of the step in history to plot. If None, uses the most recent step.

plot_gradient_ratio_heatmap(steps_to_show: int = 10)

Plot a heatmap of the vanishing-gradient fraction across layers and recent training steps.

  • X-axis: training step.

  • Y-axis: layer number.

  • Color: vanishing-gradient fraction.

Parameters:

steps_to_show (int) – Number of recent steps to include in heatmap.
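The heatmap's data can be assembled from the most recent history entries as a layer × step grid. In this sketch each history entry is assumed to be a dict with a ‘step’ and a hypothetical ‘layer_fracs’ mapping (layer number → vanishing fraction); heatmap_grid is an illustrative helper, not part of gsnn. Layers missing from an entry are filled with NaN so a plotting routine can leave those cells blank. Rows correspond to layers, which is the orientation matplotlib's imshow uses for the vertical axis.

```python
import math
from typing import Dict, List, Tuple


def heatmap_grid(
    history: List[Dict], steps_to_show: int = 10
) -> Tuple[List[int], List[int], List[List[float]]]:
    """Return (layers, steps, grid) with grid[i][j] = fraction for layers[i] at steps[j]."""
    # Guard against history[-0:], which would return the whole list
    recent = history[-steps_to_show:] if steps_to_show > 0 else []
    steps = [entry["step"] for entry in recent]
    layers = sorted({layer for entry in recent for layer in entry["layer_fracs"]})
    grid = [
        [entry["layer_fracs"].get(layer, math.nan) for entry in recent]
        for layer in layers
    ]
    return layers, steps, grid


history = [
    {"step": 1, "layer_fracs": {0: 0.0, 1: 0.2}},
    {"step": 2, "layer_fracs": {0: 0.1, 1: 0.3, 2: 0.6}},
    {"step": 3, "layer_fracs": {0: 0.1, 1: 0.4, 2: 0.7}},
]
layers, steps, grid = heatmap_grid(history, steps_to_show=2)
print(layers, steps)  # [0, 1, 2] [2, 3]
print(grid)           # [[0.1, 0.1], [0.3, 0.4], [0.6, 0.7]]
```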

plot_summary_statistics(step: int = None)

Plot summary statistics: mean gradient magnitude, max gradient magnitude, and vanishing fraction across all parameters for a given step.

Parameters:

step (int, optional) – Index of the step in history to plot. If None, uses the most recent step.
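"Across all parameters" admits two readings: pooling every gradient element, or averaging the per-parameter fractions. Which one gsnn uses is not stated here. The sketch below shows the pooled version (pooled_stats is a hypothetical helper taking parameter name → flat list of gradient values) and how it differs from a per-parameter average when tensors have different sizes.

```python
from typing import Dict, List


def pooled_stats(grads: Dict[str, List[float]], threshold: float = 1e-6) -> Dict[str, float]:
    """Mean |grad|, max |grad| and vanishing fraction over all elements of all parameters."""
    abs_vals = [abs(v) for values in grads.values() for v in values]
    if not abs_vals:
        return {"mean_abs_grad": 0.0, "max_abs_grad": 0.0, "vanishing_frac": 0.0}
    return {
        "mean_abs_grad": sum(abs_vals) / len(abs_vals),
        "max_abs_grad": max(abs_vals),
        "vanishing_frac": sum(a < threshold for a in abs_vals) / len(abs_vals),
    }


grads = {
    "big.weight": [0.0] * 6 + [0.5, 0.5],  # 8 elements, 6 vanished
    "small.bias": [0.5, 0.5],              # 2 elements, none vanished
}
stats = pooled_stats(grads)
print(stats["vanishing_frac"])  # 0.6 (pooled: 6 of 10 elements)
# Averaging the per-parameter fractions instead gives (0.75 + 0.0) / 2 = 0.375
```

Pooling weights large tensors more heavily, so a few big layers with vanished gradients can dominate the summary even if most parameter tensors are healthy.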

plot_vanishing_over_time(layers_to_show: list = None)

Plot the vanishing-gradient fraction over time for selected layers.

  • X-axis: training step.

  • Y-axis: vanishing-gradient fraction.

  • Lines: one colored line per layer.

Parameters:

layers_to_show (list, optional) – List of layer numbers to show. If None, shows the first 5 layers.
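Selecting the series for this plot, including the default of showing the first 5 layers, might look like the following. The history format (a list of dicts with a ‘step’ and a hypothetical ‘layer_fracs’ mapping of layer number → vanishing fraction), the helper name vanishing_over_time, and the reading of "first" as the 5 lowest-numbered layers are all assumptions.

```python
from typing import Dict, List, Optional, Tuple


def vanishing_over_time(
    history: List[Dict], layers_to_show: Optional[List[int]] = None
) -> Dict[int, List[Tuple[int, float]]]:
    """Map each selected layer to its (step, vanishing_frac) points over time."""
    if layers_to_show is None:
        all_layers = sorted({layer for entry in history for layer in entry["layer_fracs"]})
        layers_to_show = all_layers[:5]  # default: the 5 lowest-numbered layers
    return {
        layer: [
            (entry["step"], entry["layer_fracs"][layer])
            for entry in history
            if layer in entry["layer_fracs"]  # skip steps where the layer is absent
        ]
        for layer in layers_to_show
    }


history = [
    {"step": 1, "layer_fracs": {i: i / 10 for i in range(7)}},
    {"step": 2, "layer_fracs": {i: i / 20 for i in range(7)}},
]
print(list(vanishing_over_time(history)))                # [0, 1, 2, 3, 4]
print(vanishing_over_time(history, layers_to_show=[6]))  # {6: [(1, 0.6), (2, 0.3)]}
```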

reset()

Reset the diagnostic history and step counter.

update(model: torch.nn.Module, loss: float, step: Optional[int] = None, threshold: float = 1e-06)

Update the diagnostics with the current model gradients and loss. The resulting summary is stored, together with the step and loss, in the rolling history of recent summaries.

Parameters:
  • model (torch.nn.Module) – Model to analyze. Should have gradients computed (after backward).

  • loss (float) – Current loss value.

  • step (int, optional) – Step counter. If None, the internal step counter is incremented instead.

  • threshold (float) – Threshold for vanishing gradient detection.
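The rolling history behind update(), get_summary() and reset() can be sketched with collections.deque(maxlen=...), which discards the oldest entry once the window is full. This is an assumption about the design, not gsnn's code: the class RollingHistory is hypothetical, it stores whatever summary dict it is given instead of computing one from a model, and whether the real class increments its counter before or after recording is not documented.

```python
from collections import deque
from typing import Any, Deque, Dict, Optional


class RollingHistory:
    """Hypothetical sketch of a rolling window of gradient summaries."""

    def __init__(self, window_size: int = 100) -> None:
        # Once window_size entries are stored, appending drops the oldest one
        self.history: Deque[Dict[str, Any]] = deque(maxlen=window_size)
        self.step = 0

    def update(self, summary: Dict[str, Any], loss: float, step: Optional[int] = None) -> None:
        if step is None:
            self.step += 1  # no explicit step: advance the internal counter
        else:
            self.step = step  # explicit step: use it as the current counter
        self.history.append({"step": self.step, "loss": float(loss), **summary})

    def get_summary(self) -> Optional[Dict[str, Any]]:
        # Most recent entry, or None before the first update or after reset()
        return self.history[-1] if self.history else None

    def reset(self) -> None:
        self.history.clear()
        self.step = 0


hist = RollingHistory(window_size=3)
for loss in [1.0, 0.8, 0.6, 0.5]:
    hist.update({"overall": {"avg_vanishing_frac": 0.1}}, loss)
print(len(hist.history), hist.get_summary()["step"])  # 3 4
hist.reset()
print(hist.get_summary())  # None
```

After four updates with a window of 3, only steps 2 to 4 remain, which keeps memory bounded during long training runs.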