# Source code for gsnn.interpret.ContrastiveGSNNExplainer

import copy
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score


class ContrastiveGSNNExplainer:
    r"""Edge/node mask optimiser for *contrastive* explanations.

    This explainer learns a binary mask *m ∈ {0,1}^{E|N}* that maximises
    fidelity between the **prediction difference** on the masked graph and the
    difference on the full graph, while simultaneously penalising mask size::

        Δf(m) = f(x₁; m)[target_idx] − f(x₂; m)[target_idx]       (multivariate)

        L = MSE(Δf(m), Δf(1))                   # over all B×T elements
            + β · max(0, ‖m‖₁ − free_edges)
            − λ · H(m)                           (optional entropy term)

    Here *m* is obtained via a differentiable Gumbel-Softmax relaxation (or a
    plain softmax when ``gumbel_softmax=False``) so the optimisation can be
    performed with vanilla back-prop. After convergence the importance score
    is the softmax probability *p_i = P(m_i = 1)*.

    Interpretation
    --------------
    * ``score_i → 1``  element i is essential for reproducing the prediction difference.
    * ``score_i → 0``  element i can be removed without affecting the difference.

    Parameters
    ----------
    model : torch.nn.Module
        Trained GSNN model. A deep copy is taken, switched to eval mode and
        its parameters are frozen, so the caller's model is never modified.
        The model must expose ``edge_index``, ``num_nodes`` and
        ``homo_names`` and accept ``edge_mask=`` / ``node_mask=`` keywords.
    data : torch_geometric.data.Data
        Graph data object. Stored as ``self.data``; not read by this class.
    ignore_cuda : bool, optional (default=False)
        Force CPU even if CUDA is available.
    gumbel_softmax : bool, optional (default=True)
        Sample the mask with the Gumbel-Softmax re-parameterisation;
        otherwise use the deterministic softmax of the mask logits.
    hard : bool, optional (default=False)
        Use the straight-through estimator (Gumbel-Softmax only) to obtain
        discrete masks in the forward pass while keeping gradients continuous.
    tau0 : float, optional (default=3.0)
        Initial temperature for the (hard) Gumbel-Softmax.
    min_tau : float, optional (default=0.5)
        Minimum temperature reached after exponential decay.
    prior : float, optional (default=1.0)
        Initial bias added to the positive/negative logits (a positive value
        starts optimisation with every element more likely kept than dropped).
    iters : int, optional (default=250)
        Number of optimisation steps.
    lr : float, optional (default=1e-2)
        Learning rate for the optimiser.
    weight_decay : float, optional (default=1e-5)
        Weight decay applied to the mask logits.
    free_edges : int, optional (default=0)
        Number of (soft) active elements - edges or nodes, depending on the
        explanation target - allowed before the sparsity penalty activates.
    beta : float, optional (default=1.0)
        Coefficient of the sparsity term.
    verbose : bool, optional (default=True)
        Print progress information during optimisation.
    optimizer : type, optional (default=torch.optim.Adam)
        Optimiser class, instantiated as
        ``optimizer([logits], lr=lr, weight_decay=weight_decay)``.
    entropy : float, optional (default=0.0)
        Strength of the entropy bonus (encourages exploration).

    Example
    -------
    >>> explainer = ContrastiveGSNNExplainer(model, data, iters=400, beta=5)
    >>> # Edge-level attributions
    >>> edge_df = explainer.explain(x1, x2, target_idx=0, target='edge')
    >>> edge_df.sort_values('score', ascending=False).head()
    >>> # Node-level attributions
    >>> node_df = explainer.explain(x1, x2, target_idx=0, target='node')
    >>> node_df.sort_values('score', ascending=False).head()
    """

    # Attributes tune() snapshots and restores afterwards (beta is then set
    # to the tuned value). Any attribute overridden via **explain_kwargs is
    # restored as well.
    _TUNE_RESTORE_ATTRS = (
        'beta', 'iters', 'lr', 'weight_decay', 'free_edges', 'prior',
        'tau0', 'min_tau', 'hard', 'entropy', 'verbose',
    )

    def __init__(
        self,
        model,
        data,
        ignore_cuda: bool = False,
        gumbel_softmax: bool = True,
        hard: bool = False,
        tau0: float = 3.0,
        min_tau: float = 0.5,
        prior: float = 1.0,
        iters: int = 250,
        lr: float = 1e-2,
        weight_decay: float = 1e-5,
        free_edges: int = 0,
        beta: float = 1.0,
        verbose: bool = True,
        optimizer=torch.optim.Adam,
        entropy: float = 0.0,
    ) -> None:
        """Contrastive version of GSNNExplainer for explaining prediction differences.

        Adapted from the methods presented in `GNNExplainer`
        (https://arxiv.org/abs/1903.03894).
        """
        self.free_edges = free_edges
        self.iters = iters
        self.lr = lr
        self.weight_decay = weight_decay
        self.beta = beta
        self.verbose = verbose
        self.optimizer = optimizer
        self.gumbel_softmax = gumbel_softmax
        self.prior = prior
        self.hard = hard
        self.min_tau = min_tau
        self.tau0 = tau0
        self.data = data
        self.device = 'cuda' if (torch.cuda.is_available() and not ignore_cuda) else 'cpu'
        self.entropy = entropy

        # Work on a private, frozen copy so explaining never touches the caller's model.
        model = copy.deepcopy(model).eval().to(self.device)
        for p in model.parameters():
            p.requires_grad = False
        self.model = model

        self.E = model.edge_index.size(1)
        self.N = model.num_nodes

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def explain(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        *,
        return_weights: bool = False,
        target: str = 'edge',
    ) -> Union[pd.DataFrame, Tuple[pd.DataFrame, np.ndarray]]:
        """Compute attributions for *f(x₁) − f(x₂)*.

        Runs gradient descent on a relaxed mask to select a minimal subset of
        elements that preserves the prediction difference between x1 and x2.
        With a batch of pairs, ONE mask is learned for ALL pairs by treating
        the differences as a multi-output objective. This is much faster than
        per-sample optimisation.

        Parameters
        ----------
        x1, x2 : torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in] for batch)
            Two input feature tensors with identical batch size.
        target_idx : int or list[int]
            Output dimension(s) to explain. With several indices the mask is
            fitted to preserve every selected output jointly (MSE over all
            B×T difference elements), not their sum.
        return_weights : bool, optional (default=False)
            Also return the raw score array.
        target : str, optional (default='edge')
            Whether to return 'edge' or 'node' level attributions.

        Returns
        -------
        pd.DataFrame, or (pd.DataFrame, np.ndarray) if ``return_weights``
            If target='edge': columns ['source', 'target', 'score'].
            If target='node': columns ['node', 'score'].

        Raises
        ------
        ValueError
            If ``target`` is not 'edge'/'node' or the batch sizes differ.
        """
        if target not in ('edge', 'node'):
            raise ValueError(f"target must be 'edge' or 'node', got '{target}'")
        if target == 'edge':
            return self._explain_edges(x1, x2, target_idx, return_weights)
        return self._explain_nodes(x1, x2, target_idx, return_weights)

    def _explain_edges(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        return_weights: bool = False,
    ) -> Union[pd.DataFrame, Tuple[pd.DataFrame, np.ndarray]]:
        """Compute edge-level attributions for *f(x₁) − f(x₂)* (one mask for all pairs)."""
        return self._explain(x1, x2, target_idx, 'edge', return_weights)

    def _explain_nodes(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        return_weights: bool = False,
    ) -> Union[pd.DataFrame, Tuple[pd.DataFrame, np.ndarray]]:
        """Compute node-level attributions for *f(x₁) − f(x₂)* (one mask for all pairs)."""
        return self._explain(x1, x2, target_idx, 'node', return_weights)

    def tune(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Optional[Union[int, List[int]]] = None,
        min_fidelity: float = 0.9,
        beta_step: float = 1.5,
        max_trials: int = 20,
        verbose: bool = True,
        target: str = 'edge',
        **explain_kwargs,
    ) -> dict:
        """Tune ``beta`` to find maximum sparsity while maintaining fidelity.

        Fidelity is the R² between the prediction differences obtained with
        the thresholded (P(keep) > 0.5) subset and the full-graph differences,
        computed over all B×T elements: ``1 - MSE / Var(Δf)``. When that
        variance is undefined or zero (e.g. a single pair and a single
        target) the mean square of Δf is used as the normaliser instead.

        Parameters
        ----------
        x1, x2 : torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in])
            Input data pairs; one mask is learned for all pairs.
        target_idx : int or list[int], optional
            Target output indices to explain; ``None`` uses every output.
        min_fidelity : float, optional (default=0.9)
            Minimum fidelity (R²) to maintain.
        beta_step : float, optional (default=1.5)
            Multiplicative step size for beta adjustment.
        max_trials : int, optional (default=20)
            Maximum number of beta adjustments to try.
        verbose : bool, optional (default=True)
            Print search progress. When False, evaluation errors are reported
            as ``RuntimeWarning`` instead of being printed.
        target : str, optional (default='edge')
            Whether to tune for 'edge' or 'node' level attributions.
        **explain_kwargs : dict, optional
            Temporarily override explainer attributes (e.g. ``iters``) for
            the duration of tuning. Unknown names trigger a warning.

        Returns
        -------
        dict
            Keys: 'starting_beta' (beta the search started from, after
            overrides), 'optimal_beta', 'beta_change_factor',
            'achieved_fidelity', 'num_elements', 'total_elements',
            'sparsity_ratio', 'result_df' (None if nothing could be
            evaluated), 'target', 'batch_size'.

        Side effects
        ------------
        All overridden attributes are restored (even on error), then
        ``self.beta`` is set to the optimal beta found.
        """
        if target not in ('edge', 'node'):
            raise ValueError(f"target must be 'edge' or 'node', got '{target}'")

        x1, x2 = self._prepare_inputs(x1, x2)
        target_idx = self._normalize_target_idx(target_idx)
        B = x1.size(0)
        total_elements = self._num_elements(target)
        element_type = "Edges" if target == 'edge' else "Nodes"

        def evaluate_beta(beta_val):
            """Fit a fresh mask with ``beta_val``; return (fidelity, n_selected, logits)."""
            params, target_diffs = self._optimize_mask(
                x1, x2, target_idx, target, beta_val, verbose=False
            )
            subset_mask, subset_diffs = self._evaluate_subset(x1, x2, target_idx, target, params)
            fidelity = self._fidelity(subset_diffs, target_diffs)
            return fidelity, int((subset_mask > 0.5).sum().item()), params

        # Snapshot every attribute we may change so it can be restored.
        saved = {name: getattr(self, name) for name in self._TUNE_RESTORE_ATTRS}
        for name in explain_kwargs:
            if hasattr(self, name):
                saved.setdefault(name, getattr(self, name))

        try:
            for name, value in explain_kwargs.items():
                if hasattr(self, name):
                    setattr(self, name, value)
                else:
                    self._report_problem(f"Ignoring unknown explainer parameter '{name}'", verbose)
            start_beta = self.beta

            if verbose:
                print("=" * 60)
                print("BETA TUNING - Contrastive Explainer")
                print("=" * 60)
                print(f"Target: Find max beta with fidelity >= {min_fidelity:.3f}")
                print(f"Explanation target: {target}")
                print(f"Batch size: {B}")
                print(f"Starting beta: {start_beta:.4f}")
                print(f"Step size: {beta_step:.2f}x")
                print("=" * 60)

            best_beta, best_fidelity, best_elements, best_params = self._search_beta(
                evaluate_beta, start_beta, total_elements, min_fidelity,
                beta_step, max_trials, verbose, element_type,
            )
        finally:
            for name, value in saved.items():
                setattr(self, name, value)
        self.beta = best_beta

        final_df = None
        if best_params is not None:
            keep_probs, _ = torch.nn.functional.softmax(best_params.detach(), dim=0)
            final_df = self._scores_to_frame(keep_probs.cpu().numpy(), target)

        if verbose:
            print("\n" + "=" * 60)
            print("TUNING COMPLETE")
            print("=" * 60)
            print(f"Starting beta: {start_beta:.4f}")
            print(f"Optimal beta: {best_beta:.4f}")
            print(f"Final fidelity (across {B} pairs): {best_fidelity:.4f}")
            print(f"Selected {element_type.lower()}: {best_elements} / {total_elements} "
                  f"({100 * best_elements / total_elements:.1f}%)")
            print("=" * 60)

        return {
            'starting_beta': start_beta,
            'optimal_beta': best_beta,
            'beta_change_factor': best_beta / start_beta if start_beta else float('nan'),
            'achieved_fidelity': best_fidelity,
            'num_elements': best_elements,
            'total_elements': total_elements,
            'sparsity_ratio': best_elements / total_elements,
            'result_df': final_df,
            'target': target,
            'batch_size': B,
        }

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    def _explain(self, x1, x2, target_idx, target, return_weights):
        """Shared implementation of edge- and node-level explanations."""
        x1, x2 = self._prepare_inputs(x1, x2)
        target_idx = self._normalize_target_idx(target_idx)

        params, target_diffs = self._optimize_mask(
            x1, x2, target_idx, target, self.beta, self.verbose
        )
        if self.verbose:
            self._print_subset_report(x1, x2, target_idx, target, params, target_diffs)

        # Row 0 of the softmax over the (2, K) logits is P(keep) = the score.
        keep_probs, _ = torch.nn.functional.softmax(params.detach(), dim=0)
        scores = keep_probs.cpu().numpy()
        result_df = self._scores_to_frame(scores, target)
        if return_weights:
            return result_df, scores
        return result_df

    def _prepare_inputs(self, x1, x2):
        """Move inputs to the explainer device, add a batch dim, and check batch sizes match."""
        x1, x2 = x1.to(self.device), x2.to(self.device)
        if x1.dim() == 1:
            x1 = x1.unsqueeze(0)
        if x2.dim() == 1:
            x2 = x2.unsqueeze(0)
        if x1.size(0) != x2.size(0):
            raise ValueError(
                f"x1 and x2 must have the same batch size, got {x1.size(0)} and {x2.size(0)}"
            )
        return x1, x2

    @staticmethod
    def _normalize_target_idx(target_idx):
        """Return a list of output indices, or None meaning "all outputs"."""
        if target_idx is None:
            return None
        if isinstance(target_idx, (int, np.integer)):
            return [int(target_idx)]
        if len(target_idx) == 0:
            return None
        return target_idx

    def _num_elements(self, target):
        """Number of maskable elements for the given target ('edge' -> E, 'node' -> N)."""
        return self.E if target == 'edge' else self.N

    def _predict(self, x, target_idx, target='edge', mask=None):
        """Run the model (optionally with an edge/node mask) and keep the ``target_idx`` outputs."""
        if mask is None:
            out = self.model(x)
        elif target == 'edge':
            out = self.model(x, edge_mask=mask)
        else:
            out = self.model(x, node_mask=mask)
        return out if target_idx is None else out[:, target_idx]

    def _init_params(self, num_elements):
        """Create (2, K) mask logits: row 0 = "keep" (+prior), row 1 = "drop" (-prior)."""
        ones = torch.ones(num_elements, dtype=torch.float32, device=self.device)
        return torch.nn.Parameter(torch.stack((self.prior * ones, -self.prior * ones), dim=0))

    def _relaxed_mask(self, params, tau):
        """Return the continuous "keep" row of the relaxed mask for the (2, K) logits."""
        if self.gumbel_softmax:
            keep, _ = torch.nn.functional.gumbel_softmax(params, dim=0, hard=self.hard, tau=tau)
        else:
            keep, _ = torch.nn.functional.softmax(params, dim=0)
        return keep

    def _optimize_mask(self, x1, x2, target_idx, target, beta, verbose):
        """Fit one mask shared by all (x1[b], x2[b]) pairs; return (logits, full-graph Δf)."""
        num_elements = self._num_elements(target)
        params = self._init_params(num_elements)
        crit = torch.nn.MSELoss()
        optim = self.optimizer([params], lr=self.lr, weight_decay=self.weight_decay)
        # Exponential decay tau0 -> min_tau over `iters` steps (guarded for iters == 0).
        tau_decay_rate = (self.min_tau / self.tau0) ** (1 / max(self.iters, 1))

        with torch.no_grad():
            target_diffs = self._predict(x1, target_idx) - self._predict(x2, target_idx)

        batch = x1.size(0)
        if verbose:
            print(f"Batch size: {batch}, Target dims: {target_diffs.size(-1)}")
            print(f"Target Δf mean: {target_diffs.mean().item():.6f}, std: {target_diffs.std().item():.6f}")

        for step in range(self.iters):
            optim.zero_grad()
            tau = max(self.tau0 * tau_decay_rate ** step, self.min_tau)
            weight = self._relaxed_mask(params, tau)
            # The same mask is broadcast to every sample in the batch.
            mask_batch = weight.view(1, -1).expand(batch, -1)

            masked_diffs = (self._predict(x1, target_idx, target, mask_batch)
                            - self._predict(x2, target_idx, target, mask_batch))
            mse = crit(masked_diffs, target_diffs)

            probs, _ = torch.nn.functional.softmax(params, dim=0)
            ent = torch.distributions.Bernoulli(probs=probs).entropy().mean()
            # Hinge: only penalise soft mask mass beyond the free budget.
            sparsity = torch.clamp(weight.sum() - self.free_edges, min=0.0)
            loss = mse + beta * sparsity - self.entropy * ent
            loss.backward()
            optim.step()

            if verbose:
                r2 = self._r2(target_diffs, masked_diffs)
                print(f'iter: {step} | loss: {loss.item():.4f} | mse: {mse.item():.4f} | r2: {r2:.3f} | '
                      f'active {target}s: {(weight > 0.5).sum().item()} / {num_elements} | '
                      f'entropy: {ent.item():.4f}', end='\r')
        if verbose:
            print()
        return params, target_diffs

    def _evaluate_subset(self, x1, x2, target_idx, target, params):
        """Threshold P(keep) at 0.5; return (binary mask, Δf of the discrete subset)."""
        with torch.no_grad():
            keep_probs, _ = torch.nn.functional.softmax(params.detach(), dim=0)
            subset_mask = (keep_probs > 0.5).float()
            mask_batch = subset_mask.view(1, -1).expand(x1.size(0), -1)
            subset_diffs = (self._predict(x1, target_idx, target, mask_batch)
                            - self._predict(x2, target_idx, target, mask_batch))
        return subset_mask, subset_diffs

    @staticmethod
    def _r2(target_diffs, pred_diffs):
        """sklearn R² over all flattened elements, or NaN when there is a single element."""
        if target_diffs.numel() < 2:
            return float('nan')
        return r2_score(
            target_diffs.detach().cpu().numpy().ravel(),
            pred_diffs.detach().cpu().numpy().ravel(),
        )

    @staticmethod
    def _fidelity(subset_diffs, target_diffs):
        """R² fidelity ``1 - MSE / Var(target)`` using the population variance.

        Falls back to the mean square of ``target_diffs`` as normaliser when
        the variance is undefined (single element) or numerically zero.
        """
        mse_val = torch.nn.functional.mse_loss(subset_diffs, target_diffs).item()
        centred = target_diffs - target_diffs.mean()
        denom = centred.pow(2).mean().item()
        if target_diffs.numel() < 2 or denom < 1e-12:
            denom = target_diffs.pow(2).mean().item()
        return 1.0 - mse_val / (denom + 1e-8)

    def _print_subset_report(self, x1, x2, target_idx, target, params, target_diffs):
        """Print how well the thresholded subset reproduces the full-graph differences."""
        subset_mask, subset_diffs = self._evaluate_subset(x1, x2, target_idx, target, params)
        subset_mse = torch.nn.functional.mse_loss(subset_diffs, target_diffs).item()
        subset_r2 = self._r2(target_diffs, subset_diffs)
        num_selected = (subset_mask > 0.5).sum().item()
        total = self._num_elements(target)

        print("=" * 50)
        print(f"POST-TRAINING EVALUATION ({target}s > 0.5)")
        print("=" * 50)
        print(f"Selected {target}s: {num_selected} / {total} ({100 * num_selected / total:.1f}%)")
        print(f"Target Δf mean: {target_diffs.mean().item():.6f}")
        print(f"Subset Δf mean: {subset_diffs.mean().item():.6f}")
        print(f"MSE: {subset_mse:.6f}")
        print(f"R² (across {x1.size(0)}x{target_diffs.size(-1)} elements): {subset_r2:.4f}")
        print("=" * 50)

    def _scores_to_frame(self, scores, target):
        """Package per-element scores into the edge or node attribution DataFrame."""
        names = np.array(self.model.homo_names)
        if target == 'edge':
            src, dst = names[self.model.edge_index.detach().cpu().numpy()]
            return pd.DataFrame({"source": src, "target": dst, "score": scores})
        return pd.DataFrame({"node": names, "score": scores})

    @staticmethod
    def _report_problem(message, verbose):
        """Print ``message`` when verbose, otherwise emit a RuntimeWarning (never silent)."""
        if verbose:
            print(message)
        else:
            warnings.warn(message, RuntimeWarning, stacklevel=3)

    def _search_beta(self, evaluate_beta, start_beta, total_elements, min_fidelity,
                     beta_step, max_trials, verbose, element_type):
        """Adaptive multiplicative search over beta.

        If ``start_beta`` meets ``min_fidelity``, beta is multiplied by
        ``beta_step`` until fidelity drops (keeping the last good value);
        otherwise it is divided by ``beta_step`` until fidelity recovers. The
        search also stops after ``max_trials`` trials, when beta leaves
        [0.001, 100], or on an evaluation error.

        Returns (best_beta, best_fidelity, best_elements, best_params).
        """
        best_beta, best_fidelity, best_elements, best_params = start_beta, 0.0, total_elements, None

        if verbose:
            print(f"\nStep 1: Evaluating starting beta = {start_beta:.4f}")
        try:
            fidelity, n_elements, params = evaluate_beta(start_beta)
        except Exception as exc:  # best-effort: report and return the defaults
            self._report_problem(f"Error with initial beta: {exc}", verbose)
            return best_beta, best_fidelity, best_elements, best_params

        if verbose:
            print(f" → Fidelity = {fidelity:.4f}, {element_type} = {n_elements}")
        best_fidelity, best_elements, best_params = fidelity, n_elements, params

        search_up = fidelity >= min_fidelity
        if verbose:
            if search_up:
                print(" ✓ Good fidelity! Searching upward for more sparsity...")
            else:
                print(" ✗ Poor fidelity! Searching downward...")

        current_beta = start_beta
        for trial in range(max_trials):
            test_beta = current_beta * beta_step if search_up else current_beta / beta_step
            if verbose:
                print(f"\nTrial {trial + 1}: Testing beta = {test_beta:.4f}")
            try:
                fidelity, n_elements, params = evaluate_beta(test_beta)
            except Exception as exc:  # best-effort: keep the best result so far
                self._report_problem(f" Error: {exc}", verbose)
                break
            if verbose:
                print(f" → Fidelity = {fidelity:.4f}, {element_type} = {n_elements}")

            good = fidelity >= min_fidelity
            if good:
                best_beta, best_fidelity, best_elements, best_params = test_beta, fidelity, n_elements, params

            if search_up:
                if not good:
                    if verbose:
                        print(" ✗ Fidelity dropped, boundary found!")
                    break
                current_beta = test_beta
                if verbose:
                    print(f" ✓ Still good! New best: β={best_beta:.4f}")
            else:
                if good:
                    if verbose:
                        print(f" ✓ Fidelity recovered! Optimal: β={best_beta:.4f}")
                    break
                current_beta = test_beta
                if verbose:
                    print(" ✗ Still poor, continuing...")

            if test_beta > 100 or test_beta < 0.001:
                if verbose:
                    print(" ⚠ Beta limit reached, stopping")
                break

        return best_beta, best_fidelity, best_elements, best_params