Source code for gsnn.optim.GradDiagnostics

import torch
import numpy as np
from typing import Dict, Any, Optional
import matplotlib.pyplot as plt

[docs]class GradDiagnostics: """ Gradient Diagnostics for Vanishing Gradient Analysis in PyTorch Models. Tracks and summarizes gradient vanishing effects over training steps. Useful for diagnosing vanishing gradient problems in deep or recurrent networks. Usage: diag = GradDiagnostics(window_size=100, verbose=True) # During training loop, after loss.backward(): diag.update(model, loss, step) # To get the latest summary: summary = diag.get_summary() # To reset history: diag.reset() """ def __init__(self, window_size: int = 100, verbose: bool = True): """ Args: window_size (int): Number of recent steps to keep in history. verbose (bool): If True, print warnings for high vanishing fraction. """ self.window_size = window_size self.verbose = verbose self.history = [] # List of per-step summaries self.step = 0
[docs] def analyze(self, model: torch.nn.Module, threshold: float = 1e-6) -> Dict[str, Any]: """ Analyze and summarize gradient vanishing effects in the model. For each parameter tensor in the model, computes the mean and maximum absolute gradient after a backward pass, and the fraction of elements with very small gradients (|grad| < threshold). If a large fraction of parameters have near-zero gradients, it may indicate vanishing gradients. Args: model (torch.nn.Module): Model to analyze. Should have gradients computed (after backward). threshold (float): The absolute gradient value below which a gradient is considered 'vanished'. Returns: dict: A dictionary containing per-parameter statistics and an overall summary, including: - 'per_layer': {layer_name: {'mean_abs_grad', 'max_abs_grad', 'vanishing_frac'}} - 'overall': {'avg_vanishing_frac', 'layers_with_high_vanishing', 'threshold'} """ grad_stats = {} vanishing_fracs = [] layers_with_high_vanishing = [] for name, param in model.named_parameters(): if param.grad is not None: grad = param.grad.detach().abs().cpu().view(-1) mean_abs_grad = grad.mean().item() max_abs_grad = grad.max().item() vanishing_frac = (grad < threshold).float().mean().item() grad_stats[name] = { 'mean_abs_grad': mean_abs_grad, 'max_abs_grad': max_abs_grad, 'vanishing_frac': vanishing_frac } vanishing_fracs.append(vanishing_frac) if vanishing_frac > 0.99: layers_with_high_vanishing.append(name) else: grad_stats[name] = { 'mean_abs_grad': None, 'max_abs_grad': None, 'vanishing_frac': None } avg_vanishing_frac = float(np.mean(vanishing_fracs)) if vanishing_fracs else None summary = { 'per_layer': grad_stats, 'overall': { 'avg_vanishing_frac': avg_vanishing_frac, 'layers_with_high_vanishing': layers_with_high_vanishing, 'threshold': threshold } } if self.verbose and layers_with_high_vanishing: print(f"[GradDiagnostics] Warning: Layers with >99% vanishing gradients: {layers_with_high_vanishing}") return summary
[docs] def update(self, model: torch.nn.Module, loss: float, step: Optional[int] = None, threshold: float = 1e-6): """ Update diagnostics with the current model gradients and loss. Stores the latest summary in the history (rolling window), along with step and loss. Args: model (torch.nn.Module): Model to analyze. Should have gradients computed (after backward). loss (float): Current loss value. step (int, optional): Step counter. If None, increments internal step. threshold (float): Threshold for vanishing gradient detection. """ if step is not None: self.step = step else: self.step += 1 summary = self.analyze(model, threshold=threshold) record = { 'step': self.step, 'loss': loss, 'grad_summary': summary } self.history.append(record) if len(self.history) > self.window_size: self.history.pop(0)
[docs] def get_summary(self) -> Dict[str, Any]: """ Get the most recent gradient vanishing summary, including step and loss. Returns: dict: The latest summary dictionary, or None if no history. """ if self.history: return self.history[-1] else: return None
[docs] def reset(self): """ Reset the diagnostic history and step counter. """ self.history.clear() self.step = 0
[docs] def plot_diagnostics(self, step: int = None): """ Plot the vanishing gradient fraction by layer number for different parameter types. X-axis: layer number (extracted from parameter names like 'ResBlocks.0.lin_in.weight') Y-axis: vanishing gradient fraction (|grad| < threshold) for that step. Different parameter types (lin_in, lin_out, norm, etc.) are shown as different colored lines. Args: step (int, optional): Index of the step in history to plot. If None, uses the most recent step. """ if not self.history: print("No diagnostic history to plot.") return # Determine which record to plot if step is None: record = self.history[-1] step_label = record['step'] else: if step < 0 or step >= len(self.history): print(f"Step {step} out of range, using most recent step.") record = self.history[-1] step_label = record['step'] else: record = self.history[step] step_label = record['step'] grad_summary = record['grad_summary'] # Parse parameter names to extract layer numbers and types import re layer_data = {} # {param_type: {layer_num: vanishing_frac}} for param_name, stats in grad_summary['per_layer'].items(): vanishing_frac = stats['vanishing_frac'] if vanishing_frac is None: continue # Extract layer number (e.g., "ResBlocks.0.lin_in.weight" -> layer 0) layer_match = re.search(r'ResBlocks\.(\d+)', param_name) if not layer_match: continue layer_num = int(layer_match.group(1)) # Extract parameter type (lin_in, lin_out, norm, etc.) if 'lin_in' in param_name: param_type = 'lin_in' elif 'lin_out' in param_name: param_type = 'lin_out' elif 'norm' in param_name: param_type = 'norm' elif 'nonlin' in param_name: param_type = 'nonlin' elif 'residual_weight' in param_name: param_type = 'residual_weight' else: param_type = 'other' if param_type not in layer_data: layer_data[param_type] = {} layer_data[param_type][layer_num] = vanishing_frac # Plot plt.figure(figsize=(10, 6)) colors = plt.cm.tab10(np.linspace(0, 1, len(layer_data))) for i, (param_type, layer_fracs) in enumerate(layer_data.items()): layers = sorted(layer_fracs.keys()) fracs = [layer_fracs[layer] for layer in layers] plt.plot(layers, fracs, marker='o', label=param_type, color=colors[i]) plt.xlabel('Layer') plt.ylabel('Vanishing gradient fraction (|grad| < threshold)') plt.title(f'Vanishing Gradient Fraction by Layer (Step {step_label})') plt.legend(title='Parameter Type', loc='upper right') plt.grid(True, alpha=0.3) plt.tight_layout() plt.show()
[docs] def plot_gradient_magnitude_by_layer(self, step: int = None): """ Plot the mean absolute gradient magnitude by layer number for different parameter types. X-axis: layer number, Y-axis: mean absolute gradient magnitude. Different parameter types shown as different colored lines. Args: step (int, optional): Index of the step in history to plot. If None, uses the most recent step. """ if not self.history: print("No diagnostic history to plot.") return # Determine which record to plot if step is None: record = self.history[-1] step_label = record['step'] else: if step < 0 or step >= len(self.history): print(f"Step {step} out of range, using most recent step.") record = self.history[-1] step_label = record['step'] else: record = self.history[step] step_label = record['step'] grad_summary = record['grad_summary'] # Parse parameter names to extract layer numbers and types import re layer_data = {} # {param_type: {layer_num: mean_abs_grad}} for param_name, stats in grad_summary['per_layer'].items(): mean_abs_grad = stats['mean_abs_grad'] if mean_abs_grad is None: continue # Extract layer number layer_match = re.search(r'ResBlocks\.(\d+)', param_name) if not layer_match: continue layer_num = int(layer_match.group(1)) # Extract parameter type if 'lin_in' in param_name: param_type = 'lin_in' elif 'lin_out' in param_name: param_type = 'lin_out' elif 'norm' in param_name: param_type = 'norm' elif 'nonlin' in param_name: param_type = 'nonlin' elif 'residual_weight' in param_name: param_type = 'residual_weight' else: param_type = 'other' if param_type not in layer_data: layer_data[param_type] = {} layer_data[param_type][layer_num] = mean_abs_grad # Plot plt.figure(figsize=(10, 6)) colors = plt.cm.tab10(np.linspace(0, 1, len(layer_data))) for i, (param_type, layer_grads) in enumerate(layer_data.items()): layers = sorted(layer_grads.keys()) grads = [layer_grads[layer] for layer in layers] plt.semilogy(layers, grads, marker='o', label=param_type, color=colors[i]) plt.xlabel('Layer') plt.ylabel('Mean Absolute Gradient Magnitude (log scale)') plt.title(f'Gradient Magnitude by Layer (Step {step_label})') plt.legend(title='Parameter Type', loc='upper right') plt.grid(True, alpha=0.3) plt.tight_layout() plt.show()
[docs] def plot_vanishing_over_time(self, layers_to_show: list = None): """ Plot vanishing gradient fraction over time for specific layers. X-axis: training step, Y-axis: vanishing gradient fraction. Different layers shown as different colored lines. Args: layers_to_show (list, optional): List of layer numbers to show. If None, shows first 5 layers. """ if not self.history: print("No diagnostic history to plot.") return if layers_to_show is None: layers_to_show = [0, 1, 2, 3, 4] # Default to first 5 layers # Collect data over time import re steps = [] layer_data = {layer: [] for layer in layers_to_show} for record in self.history: steps.append(record['step']) grad_summary = record['grad_summary'] # Initialize layer values for this step layer_values = {layer: None for layer in layers_to_show} for param_name, stats in grad_summary['per_layer'].items(): vanishing_frac = stats['vanishing_frac'] if vanishing_frac is None: continue # Extract layer number layer_match = re.search(r'ResBlocks\.(\d+)', param_name) if not layer_match: continue layer_num = int(layer_match.group(1)) if layer_num in layers_to_show and 'lin_in' in param_name: # Use lin_in as representative layer_values[layer_num] = vanishing_frac # Append values for each layer for layer in layers_to_show: layer_data[layer].append(layer_values[layer]) # Plot plt.figure(figsize=(12, 6)) colors = plt.cm.viridis(np.linspace(0, 1, len(layers_to_show))) for i, layer in enumerate(layers_to_show): valid_indices = [j for j, val in enumerate(layer_data[layer]) if val is not None] valid_steps = [steps[j] for j in valid_indices] valid_values = [layer_data[layer][j] for j in valid_indices] if valid_values: plt.plot(valid_steps, valid_values, marker='o', label=f'Layer {layer}', color=colors[i], alpha=0.8) plt.xlabel('Training Step') plt.ylabel('Vanishing Gradient Fraction') plt.title('Vanishing Gradients Over Time by Layer') plt.legend(title='Layer', loc='upper right') plt.grid(True, alpha=0.3) plt.tight_layout() plt.show()
[docs] def plot_gradient_ratio_heatmap(self, steps_to_show: int = 10): """ Plot a heatmap showing vanishing gradient fraction across layers and recent training steps. X-axis: training step, Y-axis: layer number, Color: vanishing fraction. Args: steps_to_show (int): Number of recent steps to include in heatmap. """ if not self.history: print("No diagnostic history to plot.") return # Get recent steps recent_history = self.history[-steps_to_show:] if len(recent_history) < 2: print("Not enough history for heatmap.") return # Collect layer numbers and steps import re all_layers = set() steps = [] for record in recent_history: steps.append(record['step']) grad_summary = record['grad_summary'] for param_name in grad_summary['per_layer'].keys(): layer_match = re.search(r'ResBlocks\.(\d+)', param_name) if layer_match and 'lin_in' in param_name: # Use lin_in as representative all_layers.add(int(layer_match.group(1))) layers = sorted(all_layers) # Build heatmap data heatmap_data = np.full((len(layers), len(steps)), np.nan) for j, record in enumerate(recent_history): grad_summary = record['grad_summary'] for param_name, stats in grad_summary['per_layer'].items(): vanishing_frac = stats['vanishing_frac'] if vanishing_frac is None: continue layer_match = re.search(r'ResBlocks\.(\d+)', param_name) if layer_match and 'lin_in' in param_name: layer_num = int(layer_match.group(1)) if layer_num in layers: layer_idx = layers.index(layer_num) heatmap_data[layer_idx, j] = vanishing_frac # Plot heatmap plt.figure(figsize=(12, 8)) im = plt.imshow(heatmap_data, cmap='Reds', aspect='auto', interpolation='nearest') plt.xlabel('Training Step') plt.ylabel('Layer Number') plt.title('Vanishing Gradient Fraction Heatmap') # Set ticks plt.xticks(range(len(steps)), steps, rotation=45) plt.yticks(range(len(layers)), layers) # Add colorbar cbar = plt.colorbar(im) cbar.set_label('Vanishing Gradient Fraction') plt.tight_layout() plt.show()
[docs] def plot_summary_statistics(self, step: int = None): """ Plot summary statistics: mean gradient magnitude, max gradient magnitude, and vanishing fraction across all parameters for a given step. Args: step (int, optional): Index of the step in history to plot. If None, uses the most recent step. """ if not self.history: print("No diagnostic history to plot.") return # Determine which record to plot if step is None: record = self.history[-1] step_label = record['step'] else: if step < 0 or step >= len(self.history): print(f"Step {step} out of range, using most recent step.") record = self.history[-1] step_label = record['step'] else: record = self.history[step] step_label = record['step'] grad_summary = record['grad_summary'] # Collect statistics by parameter type param_stats = {} for param_name, stats in grad_summary['per_layer'].items(): # Extract parameter type if 'lin_in' in param_name: param_type = 'lin_in' elif 'lin_out' in param_name: param_type = 'lin_out' elif 'norm' in param_name: param_type = 'norm' elif 'residual_weight' in param_name: param_type = 'residual_weight' else: param_type = 'other' if param_type not in param_stats: param_stats[param_type] = {'mean_grads': [], 'max_grads': [], 'vanishing_fracs': []} if stats['mean_abs_grad'] is not None: param_stats[param_type]['mean_grads'].append(stats['mean_abs_grad']) param_stats[param_type]['max_grads'].append(stats['max_abs_grad']) param_stats[param_type]['vanishing_fracs'].append(stats['vanishing_frac']) # Create subplot fig, axes = plt.subplots(1, 3, figsize=(15, 5)) param_types = list(param_stats.keys()) x = np.arange(len(param_types)) # Mean gradient magnitude mean_vals = [np.mean(param_stats[pt]['mean_grads']) if param_stats[pt]['mean_grads'] else 0 for pt in param_types] axes[0].bar(x, mean_vals) axes[0].set_xlabel('Parameter Type') axes[0].set_ylabel('Mean Abs Gradient') axes[0].set_title('Mean Gradient Magnitude') axes[0].set_xticks(x) axes[0].set_xticklabels(param_types, rotation=45) axes[0].set_yscale('log') # Max gradient magnitude max_vals = [np.mean(param_stats[pt]['max_grads']) if param_stats[pt]['max_grads'] else 0 for pt in param_types] axes[1].bar(x, max_vals) axes[1].set_xlabel('Parameter Type') axes[1].set_ylabel('Mean Max Gradient') axes[1].set_title('Max Gradient Magnitude') axes[1].set_xticks(x) axes[1].set_xticklabels(param_types, rotation=45) axes[1].set_yscale('log') # Vanishing fraction vanishing_vals = [np.mean(param_stats[pt]['vanishing_fracs']) if param_stats[pt]['vanishing_fracs'] else 0 for pt in param_types] axes[2].bar(x, vanishing_vals) axes[2].set_xlabel('Parameter Type') axes[2].set_ylabel('Vanishing Fraction') axes[2].set_title('Vanishing Gradient Fraction') axes[2].set_xticks(x) axes[2].set_xticklabels(param_types, rotation=45) plt.suptitle(f'Gradient Summary Statistics (Step {step_label})') plt.tight_layout() plt.show()