import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, deque
from typing import Dict, List, Optional, Tuple, Any
import warnings
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import copy
class TrainingDiagnostics:
    """
    Comprehensive training diagnostics for monitoring model optimization.

    Tracks gradient flow, activation patterns, weight spectral properties,
    and (placeholder) curvature metrics to help diagnose training issues and
    guide optimization decisions.

    Only ``losses`` is a rolling window (``window_size`` entries). Every other
    metric history is a plain list that grows by one entry per tracked step
    until :meth:`reset` is called.

    Usage:
        diagnostics = TrainingDiagnostics(model, track_every=10)

        # During training loop:
        loss.backward()
        diagnostics.update(model, loss.item(), batch_idx)
        optimizer.step()

        # Generate reports:
        summary = diagnostics.get_summary()
        diagnostics.plot_diagnostics()

        # When finished, detach the forward hooks from the model:
        diagnostics.remove_hooks()
    """

    # Number of most recent tracked points drawn in the per-layer plots.
    _PLOT_WINDOW = 50

    def __init__(self, model: nn.Module, track_every: int = 10,
                 window_size: int = 100, track_activations: bool = True,
                 track_weights: bool = True, track_gradients: bool = True,
                 track_curvature: bool = False, verbose: bool = True):
        """
        Initialize training diagnostics tracker.

        Args:
            model: PyTorch model to monitor
            track_every: Update diagnostics every N steps (must be >= 1)
            window_size: Size of rolling window for the loss history
            track_activations: Whether to track activation statistics
            track_weights: Whether to track weight spectral properties
            track_gradients: Whether to track gradient flow metrics
            track_curvature: Whether to track curvature (placeholder only)
            verbose: Whether to emit diagnostic warnings

        Raises:
            ValueError: If ``track_every`` is smaller than 1.
        """
        # update() computes `step % track_every`; reject 0 here instead of
        # failing with ZeroDivisionError in the middle of a training loop.
        if track_every < 1:
            raise ValueError(f"track_every must be >= 1, got {track_every}")

        self.model = model
        self.track_every = track_every
        self.window_size = window_size
        self.track_activations = track_activations
        self.track_weights = track_weights
        self.track_gradients = track_gradients
        self.track_curvature = track_curvature
        self.verbose = verbose

        # Storage for metrics
        self.step = 0
        self.losses = deque(maxlen=window_size)

        # Gradient flow diagnostics (keyed by parameter name)
        self.grad_norms = defaultdict(list)
        self.weight_norms = defaultdict(list)
        self.grad_weight_ratios = defaultdict(list)
        self.grad_noise_scale = []
        self.grad_cosine_sim = []

        # Activation diagnostics (keyed by module name)
        self.dead_neuron_rates = defaultdict(list)
        self.activation_saturations = defaultdict(list)
        self.feature_variance = defaultdict(list)

        # Weight spectral diagnostics (keyed by parameter name)
        self.weight_singular_values = defaultdict(list)
        self.condition_numbers = defaultdict(list)

        # Curvature diagnostics
        self.hessian_traces = []
        self.sharpness_estimates = []

        # Flattened gradient from the previous tracked step, used for the
        # noise and cosine-similarity metrics.
        self.prev_gradients = None

        # Activation hooks: handles for removal, plus the latest captured
        # output of each hooked module.
        self.activation_hooks = []
        self.activations = {}
        if self.track_activations:
            self._register_activation_hooks()

    def _register_activation_hooks(self) -> None:
        """Register forward hooks on every leaf module to capture its output."""

        def make_hook(name):
            def hook(module, inputs, output):
                # Modules may return a tensor or a sequence whose first item
                # is the main output; anything else is skipped so that the
                # diagnostics can never break the forward pass.
                if isinstance(output, (list, tuple)):
                    output = output[0] if len(output) > 0 else None
                if isinstance(output, torch.Tensor):
                    self.activations[name] = output.detach()
            return hook

        for name, module in self.model.named_modules():
            if next(module.children(), None) is None:  # Leaf modules only
                handle = module.register_forward_hook(make_hook(name))
                self.activation_hooks.append(handle)

    def remove_hooks(self) -> None:
        """
        Detach all forward hooks from the model and drop cached activations.

        Safe to call more than once. Prefer calling this explicitly over
        relying on garbage collection: the hooks hold a reference back to
        this object, so it is only finalized once the cycle is collected.
        """
        hooks = getattr(self, 'activation_hooks', None)
        if hooks:
            for handle in hooks:
                handle.remove()
            hooks.clear()
        cached = getattr(self, 'activations', None)
        if cached:
            cached.clear()

    def update(self, model: nn.Module, loss: float, step: Optional[int] = None):
        """
        Update diagnostics with current training state.

        Call after ``loss.backward()`` so that parameter gradients are
        populated. The loss is recorded on every call; the remaining
        diagnostics only run when ``step`` is a multiple of ``track_every``.

        Args:
            model: Current model state
            loss: Current loss value
            step: Optional step counter (uses internal counter if None)
        """
        if step is not None:
            self.step = step
        else:
            self.step += 1
        self.losses.append(loss)

        # Only compute expensive diagnostics every N steps
        if self.step % self.track_every != 0:
            return

        with torch.no_grad():
            if self.track_gradients:
                self._update_gradient_diagnostics(model)
            if self.track_activations:
                self._update_activation_diagnostics()
            if self.track_weights:
                self._update_weight_diagnostics(model)
            if self.track_curvature:
                self._update_curvature_diagnostics(model)

    def _update_gradient_diagnostics(self, model: nn.Module):
        """Record per-parameter gradient/weight norms and global gradient stats."""
        flat_grads = []
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad.detach()
            weight = param.detach()

            # Gradient and weight norms
            grad_norm = torch.norm(grad).item()
            weight_norm = torch.norm(weight).item()
            self.grad_norms[name].append(grad_norm)
            self.weight_norms[name].append(weight_norm)

            # Gradient/weight ratio (epsilon avoids division by zero)
            ratio = grad_norm / (weight_norm + 1e-8)
            self.grad_weight_ratios[name].append(ratio)

            # Collect gradients for noise analysis
            flat_grads.append(grad.flatten())

            # Check for problematic gradients
            if self.verbose:
                if grad_norm > 10.0:
                    warnings.warn(f"Large gradient norm in {name}: {grad_norm:.4f}")
                if ratio > 1.0:
                    warnings.warn(f"Large grad/weight ratio in {name}: {ratio:.4f}")

        if not flat_grads:
            return

        grad_vec = torch.cat(flat_grads)
        prev = self.prev_gradients
        # The previous vector is only comparable when it covers the same
        # parameters on the same device; this can change when layers are
        # frozen/unfrozen or the model is moved between tracked steps.
        if (prev is not None and prev.shape == grad_vec.shape
                and prev.device == grad_vec.device):
            # Gradient noise as variance of the step-to-step change over the
            # squared mean gradient entry
            grad_diff = grad_vec - prev
            noise = torch.var(grad_diff) / (torch.mean(grad_vec) ** 2 + 1e-8)
            self.grad_noise_scale.append(noise.item())

            # Cosine similarity between consecutive tracked gradients
            cos_sim = torch.cosine_similarity(grad_vec.unsqueeze(0),
                                              prev.unsqueeze(0))
            self.grad_cosine_sim.append(cos_sim.item())
        self.prev_gradients = grad_vec.clone()

    def _update_activation_diagnostics(self):
        """Record statistics of the activations captured by the forward hooks."""
        for name, activation in self.activations.items():
            if activation.numel() == 0:
                continue

            # Dead neuron rate (for ReLU-like activations): fraction of
            # non-positive outputs
            dead_rate = (activation <= 0).float().mean().item()
            self.dead_neuron_rates[name].append(dead_rate)

            # Activation saturation, guessed from the observed value range
            low = activation.min().item()
            high = activation.max().item()
            if low >= 0 and high <= 1:  # Likely sigmoid
                saturated = (activation < 0.01) | (activation > 0.99)
                self.activation_saturations[name].append(
                    saturated.float().mean().item())
            elif low >= -1 and high <= 1:  # Likely tanh
                saturated = (activation < -0.99) | (activation > 0.99)
                self.activation_saturations[name].append(
                    saturated.float().mean().item())

            # Feature variance across the leading dimension (detect feature
            # collapse). Needs at least two samples and a floating dtype,
            # otherwise the variance is NaN or not computable.
            if (activation.dim() >= 2 and activation.shape[0] > 1
                    and activation.is_floating_point()):
                feature_var = activation.var(dim=0).mean().item()
                self.feature_variance[name].append(feature_var)

        # Clear activations for next iteration
        self.activations.clear()

    def _update_weight_diagnostics(self, model: nn.Module):
        """Record singular-value statistics and condition numbers of weights."""
        for name, param in model.named_parameters():
            if param.dim() < 2:  # Only for weight matrices
                continue

            # Collapse trailing dimensions (e.g. conv kernels) into a matrix
            matrix = param.detach().flatten(1)
            if matrix.dtype in (torch.float16, torch.bfloat16):
                # Half-precision inputs are upcast before the decomposition.
                matrix = matrix.float()

            try:
                # Only singular values are needed, so skip computing U and V.
                singular_vals = torch.linalg.svdvals(matrix).cpu().numpy()
            except RuntimeError as e:
                if self.verbose:
                    warnings.warn(f"SVD failed for {name}: {e}")
                continue
            if singular_vals.size == 0:
                continue

            max_sv = float(singular_vals.max())
            # Smallest singular value that is not numerically zero; a matrix
            # with none (e.g. all zeros) falls back to the 1e-8 floor. The
            # same value feeds both records so they always stay in step.
            positive = singular_vals[singular_vals > 1e-8]
            min_sv = float(positive.min()) if positive.size > 0 else 1e-8

            self.weight_singular_values[name].append({
                'max': max_sv,
                'min': min_sv,
                'mean': float(singular_vals.mean()),
                'std': float(singular_vals.std())
            })

            # Condition number
            cond_num = max_sv / (min_sv + 1e-8)
            self.condition_numbers[name].append(cond_num)

            # Check for problematic spectra
            if self.verbose:
                if cond_num > 1000:
                    warnings.warn(f"High condition number in {name}: {cond_num:.2f}")
                if max_sv > 10:
                    warnings.warn(f"Large singular values in {name}: max={max_sv:.2f}")

    def _update_curvature_diagnostics(self, model: nn.Module):
        """
        Placeholder for curvature diagnostics.

        A real estimate (e.g. a Hutchinson Hessian-trace estimator) needs the
        loss function and a differentiable graph, neither of which is
        available here (update() also runs this under ``torch.no_grad()``).
        Until that is wired in, a constant 0.0 is recorded per tracked step;
        the values in ``hessian_traces`` carry no information.
        """
        self.hessian_traces.append(0.0)

    def get_summary(self) -> Dict[str, Any]:
        """
        Generate a comprehensive diagnostic summary.

        Returns:
            Dictionary with the keys 'step', 'current_loss',
            'gradient_health', 'activation_health', 'weight_health' and
            'recommendations'. The three health entries are empty dicts when
            the corresponding metrics have not been recorded yet.
        """
        summary = {
            'step': self.step,
            'current_loss': self.losses[-1] if self.losses else None,
            'gradient_health': {},
            'activation_health': {},
            'weight_health': {},
            'recommendations': []
        }
        recommendations = summary['recommendations']

        # Gradient health summary. Histories are defaultdicts, so a key can
        # exist with no values; the `if vals` filter skips those.
        latest_ratios = {name: vals[-1]
                         for name, vals in self.grad_weight_ratios.items() if vals}
        if latest_ratios:
            avg_ratio = float(np.mean(list(latest_ratios.values())))
            latest_noise = self.grad_noise_scale[-1] if self.grad_noise_scale else None
            latest_cos = self.grad_cosine_sim[-1] if self.grad_cosine_sim else None
            summary['gradient_health'] = {
                'avg_grad_weight_ratio': avg_ratio,
                'problematic_layers': [name for name, ratio in latest_ratios.items()
                                       if ratio > 1.0],
                'gradient_noise': latest_noise,
                'gradient_alignment': latest_cos
            }
            # Generate recommendations
            if avg_ratio > 1.0:
                recommendations.append("High gradient/weight ratios detected. Consider gradient clipping or learning rate reduction.")
            if latest_noise is not None and latest_noise > 1.0:
                recommendations.append("High gradient noise. Consider increasing batch size.")
            if latest_cos is not None and latest_cos < 0.1:
                recommendations.append("Low gradient alignment. Training may be unstable.")

        # Activation health summary
        latest_dead_rates = {name: vals[-1]
                             for name, vals in self.dead_neuron_rates.items() if vals}
        if latest_dead_rates:
            avg_dead_rate = float(np.mean(list(latest_dead_rates.values())))
            summary['activation_health'] = {
                'avg_dead_neuron_rate': avg_dead_rate,
                'layers_with_dead_neurons': [name for name, rate in latest_dead_rates.items()
                                             if rate > 0.5]
            }
            if avg_dead_rate > 0.5:
                recommendations.append("High dead neuron rate. Consider LeakyReLU, better initialization, or lower learning rates.")

        # Weight health summary
        latest_cond_nums = {name: vals[-1]
                            for name, vals in self.condition_numbers.items() if vals}
        if latest_cond_nums:
            avg_cond_num = float(np.mean(list(latest_cond_nums.values())))
            summary['weight_health'] = {
                'avg_condition_number': avg_cond_num,
                'ill_conditioned_layers': [name for name, cond in latest_cond_nums.items()
                                           if cond > 1000]
            }
            if avg_cond_num > 1000:
                recommendations.append("High condition numbers detected. Consider spectral normalization or better initialization.")

        return summary

    @classmethod
    def _plot_per_layer(cls, ax, series_by_name, title, ylabel, log_scale=False):
        """Draw the most recent points of each named series on one axis."""
        draw = ax.semilogy if log_scale else ax.plot
        for name, values in series_by_name.items():
            draw(values[-cls._PLOT_WINDOW:], label=name[:15], alpha=0.7)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    def plot_diagnostics(self, save_path: Optional[str] = None,
                         figsize: Tuple[int, int] = (15, 12)):
        """
        Generate comprehensive diagnostic plots.

        Builds a 3x3 grid; panels whose metric has no data are left empty.

        Args:
            save_path: Optional path to save the plot
            figsize: Figure size tuple

        Returns:
            The matplotlib figure. It is not closed here; the caller owns it.
        """
        fig, axes = plt.subplots(3, 3, figsize=figsize)
        fig.suptitle(f'Training Diagnostics (Step {self.step})', fontsize=16)

        # Loss curve
        if self.losses:
            axes[0, 0].plot(list(self.losses))
            axes[0, 0].set_title('Loss Curve')
            axes[0, 0].set_xlabel('Recent Steps')
            axes[0, 0].set_ylabel('Loss')

        # Gradient/weight ratios
        if self.grad_weight_ratios:
            self._plot_per_layer(axes[0, 1], self.grad_weight_ratios,
                                 'Gradient/Weight Ratios', 'Ratio')
            axes[0, 1].axhline(y=1.0, color='r', linestyle='--', alpha=0.5)

        # Gradient noise scale
        if self.grad_noise_scale:
            axes[0, 2].plot(self.grad_noise_scale)
            axes[0, 2].set_title('Gradient Noise Scale')
            axes[0, 2].set_ylabel('Noise Scale')

        # Dead neuron rates
        if self.dead_neuron_rates:
            self._plot_per_layer(axes[1, 0], self.dead_neuron_rates,
                                 'Dead Neuron Rates', 'Dead Rate')

        # Condition numbers
        if self.condition_numbers:
            self._plot_per_layer(axes[1, 1], self.condition_numbers,
                                 'Condition Numbers', 'Condition Number',
                                 log_scale=True)

        # Feature variance (collapse detection)
        if self.feature_variance:
            self._plot_per_layer(axes[1, 2], self.feature_variance,
                                 'Feature Variance', 'Variance')

        # Gradient cosine similarity
        if self.grad_cosine_sim:
            axes[2, 0].plot(self.grad_cosine_sim)
            axes[2, 0].set_title('Gradient Cosine Similarity')
            axes[2, 0].set_ylabel('Cosine Similarity')
            axes[2, 0].axhline(y=0.1, color='r', linestyle='--', alpha=0.5)

        # Weight spectrum overview (latest)
        if self.weight_singular_values:
            latest = [(name[:10], sv_list[-1])
                      for name, sv_list in self.weight_singular_values.items()
                      if sv_list]
            if latest:
                layer_names = [name for name, _ in latest]
                x_pos = np.arange(len(layer_names))
                axes[2, 1].bar(x_pos, [sv['max'] for _, sv in latest],
                               alpha=0.7, label='Max SV')
                axes[2, 1].bar(x_pos, [sv['min'] for _, sv in latest],
                               alpha=0.7, label='Min SV')
                axes[2, 1].set_title('Weight Singular Values')
                axes[2, 1].set_xticks(x_pos)
                axes[2, 1].set_xticklabels(layer_names, rotation=45)
                axes[2, 1].legend()

        # Activation saturation
        if self.activation_saturations:
            self._plot_per_layer(axes[2, 2], self.activation_saturations,
                                 'Activation Saturation', 'Saturation Rate')

        plt.tight_layout()
        if save_path:
            # Save this figure explicitly rather than whichever is "current".
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        return fig

    def get_gradient_histogram_data(self) -> Dict[str, np.ndarray]:
        """
        Get flattened gradient values for the current step.

        Returns:
            Dictionary mapping parameter names to 1-D arrays of the raw
            gradient values (not binned), for parameters that currently have
            a gradient.
        """
        return {
            name: param.grad.detach().flatten().cpu().numpy()
            for name, param in self.model.named_parameters()
            if param.grad is not None
        }

    def reset(self):
        """Reset all diagnostic tracking (forward hooks stay registered)."""
        self.step = 0
        self.losses.clear()
        for attr_name in ['grad_norms', 'weight_norms', 'grad_weight_ratios',
                          'dead_neuron_rates', 'activation_saturations', 'feature_variance',
                          'weight_singular_values', 'condition_numbers']:
            getattr(self, attr_name).clear()
        for attr_name in ['grad_noise_scale', 'grad_cosine_sim', 'hessian_traces', 'sharpness_estimates']:
            getattr(self, attr_name).clear()
        # Drop carried-over state too, so the first update after a reset is
        # not compared against gradients/activations from before it.
        self.prev_gradients = None
        self.activations.clear()

    def __del__(self):
        """Cleanup activation hooks."""
        # remove_hooks() tolerates a partially constructed instance (e.g.
        # when __init__ raised before the hook list existed).
        self.remove_hooks()