import copy
from typing import Union, List
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import r2_score
from gsnn.interpret._kwargs_utils import normalize_model_kwargs
class ContrastiveGSNNExplainer:
    r"""Edge/node mask optimiser for *contrastive* explanations.

    This explainer learns a binary mask *m∈{0,1}^{E|N}* that maximises fidelity
    between the **absolute prediction difference** on the masked graph and the
    absolute difference on the full graph, while simultaneously penalising mask
    size::

        |Δf|(m) = | f(x₁; m)[target_idx] − f(x₂; m)[target_idx] |   (multivariate)

        L = MSE(|Δf|(m), |Δf|(1))                # over all B×T elements
            + β max(0, ‖m‖₁ − free_elements)
            − λ H(m)                             (optional entropy term)

    Here *m* is obtained via a differentiable Gumbel-Softmax relaxation so the
    optimisation can be performed with vanilla back-prop.  After convergence
    the importance score is the softmax probability *p_i = P(m_i=1)*.

    Interpretation
    --------------
    * ``score_i → 1``  element i is essential for reproducing the magnitude of
      the prediction difference ``|Δf|``.
    * ``score_i → 0``  element i can be removed without affecting ``|Δf|``.

    Note
    ----
    The objective targets the *magnitude* ``|Δf|`` rather than the signed
    difference ``Δf``.  A mask that flips the sign of the prediction difference
    while preserving its magnitude is therefore considered faithful by this
    objective.  In practice this is rarely an issue because optimisation starts
    from the full mask (where the sign of ``Δf`` is correct) and progressively
    sparsifies, but it is worth being aware of for pathological models or
    aggressive ``beta``.

    Parameters
    ----------
    model : torch.nn.Module
        Trained GSNN model (its parameters are *frozen* during explanation).
        A deep copy is taken, so the caller's model is left untouched.
    data : torch_geometric.data.Data
        Graph data object (only stored; not read by this class).
    ignore_cuda : bool, optional (default=False)
        Force CPU even if CUDA is available.
    gumbel_softmax : bool, optional (default=True)
        Use the Gumbel-Softmax re-parameterisation; otherwise plain Softmax.
        ``tau0``, ``min_tau`` and ``hard`` only affect the Gumbel-Softmax path.
    hard : bool, optional (default=False)
        Use the straight-through estimator to obtain discrete masks at test
        time while keeping gradients continuous.
    tau0 : float, optional (default=3.0)
        Initial temperature for the (hard) Gumbel-Softmax.
    min_tau : float, optional (default=0.5)
        Minimum temperature reached after exponential decay.
    prior : float, optional (default=1.0)
        Initial bias added to the positive/negative logits.
    iters : int, optional (default=250)
        Number of optimisation steps.
    lr : float, optional (default=1e-2)
        Learning rate for the optimiser.
    weight_decay : float, optional (default=1e-5)
        Weight decay applied to the mask logits.
    free_edges : int, optional (default=0)
        Number of elements allowed before the sparsity penalty activates.
        Despite the name it is applied to node masks as well.
    beta : float, optional (default=1.0)
        Coefficient of the sparsity term.
    verbose : bool, optional (default=True)
        Print progress information during optimisation.
    optimizer : callable, optional (default=torch.optim.Adam)
        Optimiser class; called as ``optimizer([params], lr=..., weight_decay=...)``.
    entropy : float, optional (default=0.0)
        Strength of the entropy bonus (encourages exploration).
    scale_mse_by_variance : bool, optional (default=True)
        If True, normalise the MSE term by ``Var(target_diffs)`` so that the
        fidelity loss is scale-invariant across pairs (an ``1 - R²`` style
        objective).  This makes ``beta`` interpretable across batches whose
        ``Δf`` magnitudes differ.  Falls back to plain MSE when the target
        difference tensor has fewer than 2 elements.

    Example
    -------
    >>> explainer = ContrastiveGSNNExplainer(model, data, iters=400, beta=5)
    >>> # Edge-level attributions
    >>> edge_df = explainer.explain(x1, x2, target_idx=0, target='edge')
    >>> edge_df.sort_values('score', ascending=False).head()
    >>> # Node-level attributions
    >>> node_df = explainer.explain(x1, x2, target_idx=0, target='node')
    >>> node_df.sort_values('score', ascending=False).head()
    """

    def __init__(
        self,
        model,
        data,
        ignore_cuda: bool = False,
        gumbel_softmax: bool = True,
        hard: bool = False,
        tau0: float = 3.0,
        min_tau: float = 0.5,
        prior: float = 1.0,
        iters: int = 250,
        lr: float = 1e-2,
        weight_decay: float = 1e-5,
        free_edges: int = 0,
        beta: float = 1.0,
        verbose: bool = True,
        optimizer=torch.optim.Adam,
        entropy: float = 0.0,
        scale_mse_by_variance: bool = True,
    ) -> None:
        """
        Contrastive version of GSNNExplainer for explaining prediction differences.

        Adapted from the methods presented in `GNNExplainer` (https://arxiv.org/abs/1903.03894).
        """
        self.free_edges = free_edges
        self.iters = iters
        self.lr = lr
        self.weight_decay = weight_decay
        self.beta = beta
        self.verbose = verbose
        self.optimizer = optimizer
        self.gumbel_softmax = gumbel_softmax
        self.prior = prior
        self.hard = hard
        self.min_tau = min_tau
        self.tau0 = tau0
        self.data = data
        self.device = 'cuda' if (torch.cuda.is_available() and not ignore_cuda) else 'cpu'
        self.entropy = entropy
        self.scale_mse_by_variance = scale_mse_by_variance

        # Work on a private, frozen copy so explanation never mutates the
        # caller's model (mode, device or requires_grad flags).
        model = copy.deepcopy(model)
        model = model.eval()
        model = model.to(self.device)
        for p in model.parameters():
            p.requires_grad = False
        self.model = model

        self.E = model.edge_index.size(1)
        self.N = model.num_nodes

    def explain(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        *,
        return_weights: bool = False,
        target: str = 'edge',
        model_kwargs1=None,
        model_kwargs2=None,
    ) -> pd.DataFrame:
        """Compute attributions for *|f(x₁) − f(x₂)|*.

        Initializes and runs gradient descent to select a minimal subset of
        elements that preserve the prediction difference between x1 and x2.
        When given multiple pairs (batch), learns ONE mask that works well
        across ALL pairs by treating the differences as a multi-output objective.

        Parameters
        ----------
        x1, x2 : torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in] for batch)
            Two input feature tensors.  They must have identical batch size.
            When B > 1, learns a single mask that preserves the prediction
            difference across all pairs simultaneously.
        target_idx : int or list[int]
            Output dimension(s) to explain.  If a list is provided every
            selected output contributes its own ``|Δf|`` term to the MSE
            (a multivariate objective, not a sum of outputs).  ``None`` or an
            empty list selects all model outputs.
        return_weights : bool, optional (default=False)
            Whether to return raw weights along with the DataFrame.
        target : str, optional (default='edge')
            Whether to return 'edge' or 'node' level attributions.
        model_kwargs1, model_kwargs2 : dict, optional (default=None)
            Per-side keyword arguments forwarded to every ``self.model(x1, ...)``
            / ``self.model(x2, ...)`` call (e.g. ``{'x_fn': x_fn_1}`` for
            models trained with ``node_activity=True``).  Tensor values must
            have leading dim equal to ``x1.shape[0]`` / ``x2.shape[0]``.
            ``edge_mask`` / ``node_mask`` are reserved.

        Returns
        -------
        pd.DataFrame
            If target='edge': columns ['source', 'target', 'score'] for edge attributions.
            If target='node': columns ['node', 'score'] for node attributions.
            When ``return_weights`` is True a ``(DataFrame, ndarray)`` tuple is
            returned instead.

        Raises
        ------
        ValueError
            If ``target`` is not 'edge' or 'node', or the batch sizes differ.
        """
        if target not in ('edge', 'node'):
            raise ValueError(f"target must be 'edge' or 'node', got '{target}'")
        return self._explain(target, x1, x2, target_idx, return_weights,
                             model_kwargs1, model_kwargs2)

    def _explain_edges(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        return_weights: bool = False,
        model_kwargs1=None,
        model_kwargs2=None,
    ) -> pd.DataFrame:
        """Compute edge-level attributions; thin wrapper around ``_explain``."""
        return self._explain('edge', x1, x2, target_idx, return_weights,
                             model_kwargs1, model_kwargs2)

    def _explain_nodes(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        return_weights: bool = False,
        model_kwargs1=None,
        model_kwargs2=None,
    ) -> pd.DataFrame:
        """Compute node-level attributions; thin wrapper around ``_explain``."""
        return self._explain('node', x1, x2, target_idx, return_weights,
                             model_kwargs1, model_kwargs2)

    def _explain(self, target, x1, x2, target_idx, return_weights,
                 model_kwargs1, model_kwargs2):
        """Fit one mask for ``target`` ('edge' or 'node') and package the scores."""
        mk1 = normalize_model_kwargs(model_kwargs1)
        mk2 = normalize_model_kwargs(model_kwargs2)
        x1, x2, target_idx = self._prepare_inputs(x1, x2, target_idx)

        params, target_diffs = self._fit_mask(
            x1, x2, target_idx, target, self.beta, mk1, mk2, self.verbose)

        if self.verbose:
            # Terminate the '\r'-overwritten progress line before the report.
            print()
            subset_diffs, subset_mask = self._subset_diffs(
                params, x1, x2, target_idx, target, mk1, mk2)
            subset_mse = torch.nn.functional.mse_loss(subset_diffs, target_diffs).item()
            subset_r2 = self._r2(target_diffs, subset_diffs)
            num_selected = (subset_mask > 0.5).sum().item()
            total = self.E if target == 'edge' else self.N
            noun = f"{target}s"
            print("=" * 50)
            print(f"POST-TRAINING EVALUATION ({noun} > 0.5)")
            print("=" * 50)
            print(f"Selected {noun}: {num_selected} / {total} ({100 * num_selected / total:.1f}%)")
            print(f"Target |Δf| mean: {target_diffs.mean().item():.6f}")
            print(f"Subset |Δf| mean: {subset_diffs.mean().item():.6f}")
            print(f"MSE: {subset_mse:.6f}")
            print(f"R² (across {x1.size(0)}x{target_diffs.shape[1]} elements): {subset_r2:.4f}")
            print("=" * 50)

        # Row 0 of the 2-way softmax is P(m_i = 1), i.e. the importance score.
        scores = torch.nn.functional.softmax(params.data, dim=0)[0].detach().cpu().numpy()
        result_df = self._scores_frame(target, scores)
        if return_weights:
            return result_df, scores
        return result_df

    def _prepare_inputs(self, x1, x2, target_idx):
        """Move inputs to the device, add a batch dim and normalise ``target_idx``."""
        x1, x2 = x1.to(self.device), x2.to(self.device)
        if x1.dim() == 1:
            x1 = x1.unsqueeze(0)
        if x2.dim() == 1:
            x2 = x2.unsqueeze(0)
        if x1.size(0) != x2.size(0):
            raise ValueError(
                f"x1 and x2 must have identical batch size, got {x1.size(0)} and {x2.size(0)}")

        if target_idx is not None:
            if isinstance(target_idx, (int, np.integer)):
                # Wrap scalars (including NumPy ints) so predictions keep a (B, T) shape.
                target_idx = [int(target_idx)]
            elif len(target_idx) == 0:
                # An empty selection means "all outputs".
                target_idx = None
        return x1, x2, target_idx

    def _abs_diff(self, x1, x2, target_idx, target, mask, mk1, mk2):
        """Return ``|f(x1) − f(x2)|`` on the selected outputs, optionally masked.

        ``mask`` is forwarded as ``edge_mask`` or ``node_mask`` depending on
        ``target``; ``None`` evaluates the unmasked model.
        """
        mask_kw = {} if mask is None else {f"{target}_mask": mask}
        out1 = self.model(x1, **mask_kw, **mk1)
        out2 = self.model(x2, **mask_kw, **mk2)
        if target_idx is not None:
            out1, out2 = out1[:, target_idx], out2[:, target_idx]
        return (out1 - out2).abs()

    def _fit_mask(self, x1, x2, target_idx, target, beta, mk1, mk2, verbose):
        """Optimise one shared mask; return ``(logit parameters, target |Δf|)``.

        The parameters have shape (2, n): row 0 holds the "keep" logits and
        row 1 the "drop" logits of each edge/node.
        """
        num_elements = self.E if target == 'edge' else self.N
        batch = x1.size(0)

        ones = torch.ones(num_elements, dtype=torch.float32, device=self.device)
        params = torch.nn.Parameter(torch.stack((self.prior * ones, -self.prior * ones), dim=0))
        optim = self.optimizer([params], lr=self.lr, weight_decay=self.weight_decay)

        # Exponential decay from tau0 to min_tau over the run; max() guards iters == 0.
        tau_decay_rate = (self.min_tau / self.tau0) ** (1 / max(self.iters, 1))

        with torch.no_grad():
            target_diffs = self._abs_diff(x1, x2, target_idx, target, None, mk1, mk2)
        target_var = None
        if self.scale_mse_by_variance and target_diffs.numel() > 1:
            target_var = target_diffs.var()

        if verbose:
            print(f"Batch size: {batch}, Target dims: {target_diffs.shape[1]}")
            print(f"Target |Δf| mean: {target_diffs.mean().item():.6f}, std: {target_diffs.std().item():.6f}")

        total = num_elements
        noun = f"{target}s"
        zero = torch.zeros(1, device=self.device)  # hoisted hinge floor
        for step in range(self.iters):
            optim.zero_grad()
            tau = max(self.tau0 * tau_decay_rate ** step, self.min_tau)
            if self.gumbel_softmax:
                weight = torch.nn.functional.gumbel_softmax(
                    params, dim=0, hard=self.hard, tau=tau)[0]
            else:
                # Deterministic relaxation; temperature and `hard` do not apply.
                weight = torch.nn.functional.softmax(params, dim=0)[0]

            # One mask shared by every pair in the batch.
            mask = weight.view(1, -1).expand(batch, -1)
            masked_diffs = self._abs_diff(x1, x2, target_idx, target, mask, mk1, mk2)

            mse = torch.nn.functional.mse_loss(masked_diffs, target_diffs)
            if target_var is not None:
                mse = mse / (target_var + 1e-8)

            probs = torch.nn.functional.softmax(params, dim=0)[0]
            ent = torch.distributions.Bernoulli(probs=probs).entropy().mean()

            loss = mse \
                + beta * torch.maximum(zero, weight.sum() - self.free_edges) \
                - self.entropy * ent
            loss.backward()
            optim.step()

            if verbose:
                r2 = self._r2(target_diffs, masked_diffs)
                print(f'iter: {step} | loss: {loss.item():.4f} | mse: {mse.item():.4f} | r2: {r2:.3f} | active {noun}: {(weight > 0.5).sum().item()} / {total} | entropy: {ent.item():.4f}', end='\r')

        return params, target_diffs

    def _subset_diffs(self, params, x1, x2, target_idx, target, mk1, mk2):
        """Evaluate ``|Δf|`` with the hard mask ``P(m_i = 1) > 0.5``."""
        with torch.no_grad():
            probs = torch.nn.functional.softmax(params.data, dim=0)[0]
            subset_mask = (probs > 0.5).float()
            mask = subset_mask.view(1, -1).expand(x1.size(0), -1)
            subset_diffs = self._abs_diff(x1, x2, target_idx, target, mask, mk1, mk2)
        return subset_diffs, subset_mask

    @staticmethod
    def _r2(target_diffs, pred_diffs):
        """R² over all flattened elements; NaN when fewer than two elements exist."""
        if target_diffs.numel() < 2:
            return float('nan')
        return r2_score(
            target_diffs.detach().cpu().numpy().ravel(),
            pred_diffs.detach().cpu().numpy().ravel(),
        )

    def _scores_frame(self, target, scores):
        """Build the attribution DataFrame for edge or node scores."""
        names = np.array(self.model.homo_names)
        if target == 'edge':
            src, dst = names[self.model.edge_index.detach().cpu().numpy()]
            return pd.DataFrame({"source": src, "target": dst, "score": scores})
        return pd.DataFrame({"node": names, "score": scores})

    def _evaluate_beta(self, beta_val, x1, x2, target_idx, target, mk1, mk2):
        """Fit a mask at ``beta_val``; return ``(fidelity, n_selected, params)``."""
        params, target_diffs = self._fit_mask(
            x1, x2, target_idx, target, beta_val, mk1, mk2, verbose=False)
        subset_diffs, subset_mask = self._subset_diffs(
            params, x1, x2, target_idx, target, mk1, mk2)

        mse_val = torch.nn.functional.mse_loss(subset_diffs, target_diffs).item()
        if target_diffs.numel() > 1:
            denom = target_diffs.var().item()
        else:
            # The variance of a single element is NaN, which would make every
            # fidelity comparison fail; use relative squared error instead.
            denom = target_diffs.pow(2).mean().item()
        fidelity = 1.0 - mse_val / (denom + 1e-8)
        num_selected = (subset_mask > 0.5).sum().item()
        return fidelity, num_selected, params

    def tune(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int], None] = None,
        min_fidelity: float = 0.9,
        beta_step: float = 1.5,
        max_trials: int = 20,
        verbose: bool = True,
        target: str = 'edge',
        model_kwargs1=None,
        model_kwargs2=None,
        **explain_kwargs,
    ):
        """
        Tune beta parameter to find maximum sparsity while maintaining fidelity.

        For contrastive explanations, fidelity is measured as how well the subset
        preserves the prediction difference |f(x1) - f(x2)| across all pairs.
        On return ``self.beta`` is set to the selected beta; every other
        attribute overridden through ``explain_kwargs`` is restored.

        Parameters
        ----------
        x1, x2 : torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in])
            Input data pairs for explanation.  When B > 1, learns a single mask
            that works well across all pairs simultaneously.
        target_idx : int or list[int], optional
            Target output indices to explain.  ``None`` or an empty list
            selects all model outputs.
        min_fidelity : float, optional (default=0.9)
            Minimum fidelity to maintain.  Fidelity is R²-like:
            ``1 - MSE / Var(|Δf|)``; when there is a single pair and a single
            output it is ``1 - MSE / |Δf|²`` (variance is undefined there).
        beta_step : float, optional (default=1.5)
            Multiplicative step size for beta adjustment.
        max_trials : int, optional (default=20)
            Maximum number of beta adjustments to try.
        verbose : bool, optional (default=True)
            Whether to print search progress.
        target : str, optional (default='edge')
            Whether to tune for 'edge' or 'node' level attributions.
        model_kwargs1, model_kwargs2 : dict, optional (default=None)
            Per-side keyword arguments forwarded to every model call, with the
            same meaning as in :meth:`explain`.
        **explain_kwargs : dict, optional
            Temporarily override explainer attributes (e.g. ``iters=100``)
            during tuning.  Names that are not attributes of the explainer are
            ignored with a warning.

        Returns
        -------
        dict
            Results containing optimal beta, achieved fidelity, number of elements,
            and final DataFrame (``None`` if no beta could be evaluated).

        Raises
        ------
        ValueError
            If ``target`` is not 'edge' or 'node', or the batch sizes differ.
        """
        import warnings

        if target not in ('edge', 'node'):
            raise ValueError(f"target must be 'edge' or 'node', got '{target}'")

        mk1 = {} if model_kwargs1 is None else normalize_model_kwargs(model_kwargs1)
        mk2 = {} if model_kwargs2 is None else normalize_model_kwargs(model_kwargs2)
        x1, x2, target_idx = self._prepare_inputs(x1, x2, target_idx)
        B = x1.size(0)

        total_elements = self.E if target == 'edge' else self.N
        element_type = "Edges" if target == 'edge' else "Nodes"

        def report(message):
            # Failures are tolerated (best-effort search) but never silent.
            if verbose:
                print(message)
            else:
                warnings.warn(message.strip(), RuntimeWarning, stacklevel=3)

        # Snapshot exactly the attributes being overridden so all of them are
        # restored afterwards, not just a fixed subset.
        saved = {}
        ignored = []
        for param, value in explain_kwargs.items():
            if param in self.__dict__:
                saved[param] = self.__dict__[param]
                setattr(self, param, value)
            else:
                ignored.append(param)
        if ignored:
            warnings.warn(
                f"tune() ignored unknown explainer parameters: {sorted(ignored)}",
                UserWarning, stacklevel=2)

        start_beta = self.beta
        best_beta = start_beta
        best_fidelity = 0.0
        best_elements = total_elements
        best_params = None

        try:
            if verbose:
                print("=" * 60)
                print("BETA TUNING - Contrastive Explainer")
                print("=" * 60)
                print(f"Target: Find max beta with fidelity >= {min_fidelity:.3f}")
                print(f"Explanation target: {target}")
                print(f"Batch size: {B}")
                print(f"Starting beta: {start_beta:.4f}")
                print(f"Step size: {beta_step:.2f}x")
                print("=" * 60)
                print(f"\nStep 1: Evaluating starting beta = {start_beta:.4f}")

            try:
                initial = self._evaluate_beta(start_beta, x1, x2, target_idx, target, mk1, mk2)
            except Exception as e:
                report(f"Error with initial beta: {e}")
                initial = None

            if initial is not None:
                best_fidelity, best_elements, best_params = initial
                if verbose:
                    print(f" → Fidelity = {best_fidelity:.4f}, {element_type} = {best_elements}")

                # Good fidelity: push beta up for more sparsity.  Poor
                # fidelity: relax beta until fidelity recovers.
                searching_up = best_fidelity >= min_fidelity
                if verbose:
                    if searching_up:
                        print(" ✓ Good fidelity! Searching upward for more sparsity...")
                    else:
                        print(" ✗ Poor fidelity! Searching downward...")

                current_beta = start_beta
                for trial in range(max_trials):
                    test_beta = current_beta * beta_step if searching_up else current_beta / beta_step
                    if verbose:
                        print(f"\nTrial {trial + 1}: Testing beta = {test_beta:.4f}")
                    try:
                        test_fidelity, test_elements, test_params = self._evaluate_beta(
                            test_beta, x1, x2, target_idx, target, mk1, mk2)
                    except Exception as e:
                        report(f" Error: {e}")
                        break
                    if verbose:
                        print(f" → Fidelity = {test_fidelity:.4f}, {element_type} = {test_elements}")

                    passed = test_fidelity >= min_fidelity
                    if passed:
                        best_beta = test_beta
                        best_fidelity = test_fidelity
                        best_elements = test_elements
                        best_params = test_params

                    if searching_up:
                        if not passed:
                            if verbose:
                                print(" ✗ Fidelity dropped, boundary found!")
                            break
                        current_beta = test_beta
                        if verbose:
                            print(f" ✓ Still good! New best: β={best_beta:.4f}")
                    else:
                        if passed:
                            if verbose:
                                print(f" ✓ Fidelity recovered! Optimal: β={best_beta:.4f}")
                            break
                        current_beta = test_beta
                        if verbose:
                            print(" ✗ Still poor, continuing...")

                    if test_beta > 100 or test_beta < 0.001:
                        if verbose:
                            print(" ⚠ Beta limit reached, stopping")
                        break
        finally:
            # Undo every temporary override, even on interruption.
            for param, value in saved.items():
                setattr(self, param, value)

        self.beta = best_beta

        final_df = None
        if best_params is not None:
            scores = torch.nn.functional.softmax(best_params.data, dim=0)[0].detach().cpu().numpy()
            final_df = self._scores_frame(target, scores)

        if verbose:
            print("\n" + "=" * 60)
            print("TUNING COMPLETE")
            print("=" * 60)
            print(f"Starting beta: {start_beta:.4f}")
            print(f"Optimal beta: {best_beta:.4f}")
            print(f"Final fidelity (across {B} pairs): {best_fidelity:.4f}")
            element_type_lower = "edges" if target == 'edge' else "nodes"
            print(f"Selected {element_type_lower}: {best_elements} / {total_elements} ({100 * best_elements / total_elements:.1f}%)")
            print("=" * 60)

        return {
            'starting_beta': start_beta,
            'optimal_beta': best_beta,
            # Undefined when the search started from beta == 0.
            'beta_change_factor': best_beta / start_beta if start_beta else float('nan'),
            'achieved_fidelity': best_fidelity,
            'num_elements': best_elements,
            'total_elements': total_elements,
            'sparsity_ratio': best_elements / total_elements,
            'result_df': final_df,
            'target': target,
            'batch_size': B,
        }