import numpy as np
import torch
import pandas as pd
from sklearn.metrics import r2_score
import copy
class GSNNExplainer:
    r"""Edge/node mask optimiser that produces *sparse* explanations.

    The explainer learns a mask *m* over the edges (or nodes) of the graph so
    that the model's prediction on the **masked** graph stays close to its
    prediction on the *full* graph, while the size of the mask is penalised::

        L = MSE(f(x; m), f(x; 1))
            + beta * max(0, ||m||_1 - free_edges)
            - entropy * H(m)

    Every element owns a pair of logits (keep / drop).  During optimisation
    *m* is drawn from a Gumbel-Softmax relaxation of those logits (or a plain
    tempered softmax when ``gumbel_softmax=False``), so ordinary back-prop can
    be used.  The reported importance score is the softmax probability of the
    "keep" logit, *p_i = P(m_i = 1)*.

    Interpretation
    --------------
    * ``score_i -> 1``: element *i* is needed to reproduce the prediction.
    * ``score_i -> 0``: element *i* can be removed with little impact.

    Parameters
    ----------
    model : torch.nn.Module
        Trained GSNN model.  It is deep-copied, put in eval mode, moved to the
        selected device and its parameters are frozen.  The code reads
        ``model.edge_index``, ``model.num_nodes`` and ``model.homo_names`` and
        calls ``model(x, edge_mask=...)`` / ``model(x, node_mask=...)``.
    data : object
        Graph data object.  Stored as ``self.data``; not otherwise used by the
        code in this class.
    ignore_cuda : bool, optional (default=False)
        Force CPU even if CUDA is available.
    gumbel_softmax : bool, optional (default=True)
        Sample the mask with Gumbel-Softmax; otherwise use a deterministic
        temperature-scaled softmax.
    hard : bool, optional (default=False)
        Use the straight-through estimator (discrete forward pass, continuous
        gradients).
    tau0 : float, optional (default=3)
        Initial temperature.
    min_tau : float, optional (default=0.5)
        Temperature reached after exponential decay over ``iters`` steps.
    prior : float, optional (default=1)
        Initial logits are ``(+prior, -prior)``; 0 gives every element a 0.5
        probability of being kept, larger values favour keeping.
    iters : int, optional (default=250)
        Number of optimisation steps.
    lr : float, optional (default=1e-2)
        Learning rate.
    weight_decay : float, optional (default=1e-5)
        Weight decay applied to the mask logits.
    free_edges : int, optional (default=0)
        Number of elements (edges *or* nodes, depending on the explanation
        target) allowed before the sparsity penalty activates.
    grad_norm_clip : float, optional (default=0)
        Maximum gradient norm of the mask logits; values <= 0 disable clipping.
    beta : float, optional (default=1)
        Coefficient of the sparsity term.
    verbose : bool, optional (default=True)
        Print optimisation progress and the post-training summary.
    optimizer : callable, optional (default=torch.optim.Adam)
        Optimiser class; called as ``optimizer([params], lr=..., weight_decay=...)``.
    entropy : float, optional (default=0)
        Strength of the entropy bonus (encourages exploration).

    Example
    -------
    >>> explainer = GSNNExplainer(model, data, iters=400, beta=5)
    >>> # Edge-level attributions
    >>> edge_df = explainer.explain(x, target_idx=[0], target='edge')
    >>> edge_df.sort_values('score', ascending=False).head()
    >>> # Node-level attributions
    >>> node_df = explainer.explain(x, target_idx=[0], target='node')
    >>> node_df.sort_values('score', ascending=False).head()
    """

    def __init__(self, model, data, ignore_cuda=False, gumbel_softmax=True, hard=False, tau0=3, min_tau=0.5,
                 prior=1, iters=250, lr=1e-2, weight_decay=1e-5, free_edges=0, grad_norm_clip=0,
                 beta=1, verbose=True, optimizer=torch.optim.Adam, entropy=0):
        """Store the hyper-parameters and prepare a frozen copy of ``model``.

        Adapted from the methods presented in `GNNExplainer`
        (https://arxiv.org/abs/1903.03894).  See the class docstring for the
        meaning of every argument.
        """
        self.free_edges = free_edges
        self.iters = iters
        self.lr = lr
        self.weight_decay = weight_decay
        self.beta = beta
        self.verbose = verbose
        self.optimizer = optimizer
        self.gumbel_softmax = gumbel_softmax
        self.prior = prior
        self.hard = hard
        self.min_tau = min_tau
        self.tau0 = tau0
        self.grad_norm_clip = grad_norm_clip
        self.data = data
        self.device = 'cuda' if (torch.cuda.is_available() and not ignore_cuda) else 'cpu'
        self.entropy = entropy

        # Work on a private copy so freezing / moving never touches the caller's model.
        model = copy.deepcopy(model)
        model = model.eval()
        model = model.to(self.device)
        for p in model.parameters():
            p.requires_grad = False
        self.model = model

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def explain(self, x, target_idx=None, return_weights=False, target='edge'):
        """Optimise a sparse edge or node mask that reproduces the full-graph prediction.

        Parameters
        ----------
        x : torch.Tensor
            Input features to explain.  Passed to the model as-is after being
            moved to ``self.device``.
        target_idx : list, optional
            Output column indices to explain (``out[:, target_idx]``); all
            outputs when None.
        return_weights : bool, optional (default=False)
            Also return the score array alongside the DataFrame.
        target : str, optional (default='edge')
            ``'edge'`` or ``'node'`` level attributions.

        Returns
        -------
        pd.DataFrame or (pd.DataFrame, np.ndarray)
            ``target='edge'``: columns ``['source', 'target', 'score']``.
            ``target='node'``: columns ``['node', 'score']``.

        Raises
        ------
        ValueError
            If ``target`` is neither ``'edge'`` nor ``'node'``.
        """
        self._check_target(target)
        if target == 'edge':
            return self._explain_edges(x, target_idx, return_weights)
        return self._explain_nodes(x, target_idx, return_weights)

    def tune(self, x, target_ixs=None, min_r2=0.7, beta_step=1.5, max_trials=20,
             tolerance=1e-3, verbose=True, target='edge', **explain_kwargs):
        """Search for the largest ``beta`` that keeps subset R² >= ``min_r2``.

        Starts from the current ``self.beta`` (or the ``beta`` override) and
        moves multiplicatively:

        - starting R² >= ``min_r2``: multiply beta by ``beta_step`` until the
          thresholded-mask R² drops below ``min_r2``; keep the last passing beta.
        - starting R² < ``min_r2``: divide beta by ``beta_step`` until the R²
          recovers; if it never does, the starting evaluation is reported.

        The search also stops once a tested beta leaves ``[0.001, 100]``.
        On return ``self.beta`` is set to the selected beta; every attribute
        overridden through ``explain_kwargs`` is restored.

        Parameters
        ----------
        x : torch.Tensor
            Input data for explanation.
        target_ixs : list, optional
            Output column indices to explain.
        min_r2 : float, optional (default=0.7)
            Minimum R² of the thresholded (> 0.5) mask.
        beta_step : float, optional (default=1.5)
            Multiplicative step; must be > 1.
        max_trials : int, optional (default=20)
            Maximum number of beta adjustments after the starting evaluation.
        tolerance : float, optional (default=1e-3)
            Accepted for backward compatibility; not used by the search.
        verbose : bool, optional (default=True)
            Print search progress.
        target : str, optional (default='edge')
            ``'edge'`` or ``'node'``.
        **explain_kwargs
            Temporary overrides of instance attributes (e.g. ``iters``, ``lr``,
            ``weight_decay``, ``free_edges``, ``prior``, ``tau0``, ``min_tau``,
            ``hard``, ``entropy``, ``verbose``).  Unknown names are ignored.

        Returns
        -------
        dict
            Keys: ``starting_beta``, ``optimal_beta``, ``beta_change_factor``,
            ``achieved_r2``, ``num_elements``, ``total_elements``,
            ``sparsity_ratio``, ``result_df`` (None if no evaluation
            succeeded) and ``target``.

        Raises
        ------
        ValueError
            If ``target`` is invalid or ``beta_step <= 1``.
        """
        import warnings

        self._check_target(target)
        if beta_step <= 1:
            raise ValueError(f"beta_step must be > 1, got {beta_step}")

        starting_beta = self.beta
        if verbose:
            print("="*60)
            print("BETA TUNING - Starting from User's Beta")
            print("="*60)
            print(f"Target: Find max beta with R² >= {min_r2:.3f}")
            print(f"Explanation target: {target}")
            print(f"Starting beta: {starting_beta:.4f}")
            print(f"Step size: {beta_step:.2f}x")
            if explain_kwargs:
                print(f"Parameter overrides: {explain_kwargs}")
            print("="*60)

        # Snapshot exactly what gets overridden so every override is undone,
        # not just a fixed subset of hyper-parameters.
        overridden = {}
        for name, value in explain_kwargs.items():
            if name in vars(self):
                overridden[name] = getattr(self, name)
                setattr(self, name, value)
            elif verbose:
                print(f"Warning: Unknown parameter '{name}' ignored")

        # Per-iteration logging stays off unless the caller asked for it explicitly.
        tuning_verbose = self.verbose if 'verbose' in explain_kwargs else False

        def report_failure(message):
            # Failures must not vanish silently when verbose is off.
            if verbose:
                print(message)
            else:
                warnings.warn(message, RuntimeWarning, stacklevel=3)

        def evaluate_beta(beta_val):
            """Fit a mask with ``beta_val`` and score its thresholded subset."""
            log_fn = None
            if tuning_verbose:
                def log_fn(step, loss, mse, ent, mask, out, target_preds):
                    if step % 50 == 0:
                        r2 = self._r2(target_preds, out)
                        print(f' iter: {step} | loss: {loss.item():.4f} | r2: {r2:.3f} | beta: {beta_val:.4f}')
            params, target_preds = self._fit_mask(x_dev, target_ixs, target, beta_val, log_fn)
            subset_out, num_selected = self._subset_predictions(x_dev, params, target_ixs, target)
            return self._r2(target_preds, subset_out), num_selected, params

        element_type = "Edges" if target == 'edge' else "Nodes"
        current_beta = self.beta
        best_beta = current_beta
        best_r2 = 0.0
        best_params = None
        total_elements = None
        best_elements = None

        try:
            x_dev = self._to_device(x)
            total_elements = self._num_elements(target)
            best_elements = total_elements

            if verbose:
                print(f"\nStep 1: Evaluating starting beta = {current_beta:.4f}")
            try:
                best_r2, best_elements, best_params = evaluate_beta(current_beta)
            except Exception as err:  # best-effort search: report and fall back
                report_failure(f"Error with initial beta={current_beta:.4f}: {err}")
            else:
                if verbose:
                    print(f" → R² = {best_r2:.4f}, {element_type} = {best_elements}")
                search_up = best_r2 >= min_r2
                search_direction = "up" if search_up else "down"
                if verbose:
                    if search_up:
                        print(" ✓ Good performance! Searching upward for more sparsity...")
                    else:
                        print(" ✗ Poor performance! Searching downward for better performance...")

                for trial in range(max_trials):
                    test_beta = current_beta * beta_step if search_up else current_beta / beta_step
                    if verbose:
                        print(f"\nTrial {trial + 1}: Testing beta = {test_beta:.4f} (direction: {search_direction})")
                    try:
                        test_r2, test_elements, test_params = evaluate_beta(test_beta)
                    except Exception as err:
                        report_failure(f" Error with beta={test_beta:.4f}: {err}")
                        break
                    if verbose:
                        print(f" → R² = {test_r2:.4f}, {element_type} = {test_elements}")

                    passed = test_r2 >= min_r2
                    if passed:
                        best_beta, best_r2 = test_beta, test_r2
                        best_elements, best_params = test_elements, test_params

                    if search_up:
                        if not passed:
                            if verbose:
                                print(" ✗ Performance dropped, boundary found!")
                            break
                        current_beta = test_beta
                        if verbose:
                            print(f" ✓ Still good! New best: β={best_beta:.4f}")
                    else:
                        if passed:
                            if verbose:
                                print(f" ✓ Performance recovered! Optimal: β={best_beta:.4f}")
                            break
                        current_beta = test_beta
                        if verbose:
                            print(" ✗ Still poor, continuing downward...")

                    # Safety check - don't let beta get too extreme
                    if test_beta > 100 or test_beta < 0.001:
                        if verbose:
                            print(f" ⚠ Beta limit reached ({test_beta:.4f}), stopping search")
                        break
        finally:
            for name, value in overridden.items():
                setattr(self, name, value)

        self.beta = best_beta

        final_df = None
        if best_params is not None:
            final_df, _ = self._scores_frame(best_params, target)

        # Guard the ratios: a starting beta of 0 or an empty graph must not raise.
        change_factor = best_beta / starting_beta if starting_beta else float('nan')
        sparsity_ratio = best_elements / total_elements if total_elements else float('nan')

        if verbose:
            print("\n" + "="*60)
            print("TUNING COMPLETE")
            print("="*60)
            print(f"Starting beta: {starting_beta:.4f}")
            print(f"Optimal beta: {best_beta:.4f}")
            print(f"Change: {change_factor:.2f}x")
            print(f"Final R²: {best_r2:.4f}")
            element_type_lower = "edges" if target == 'edge' else "nodes"
            print(f"Selected {element_type_lower}: {best_elements} / {total_elements} ({100*sparsity_ratio:.1f}%)")
            print("="*60)

        return {
            'starting_beta': starting_beta,
            'optimal_beta': best_beta,
            'beta_change_factor': change_factor,
            'achieved_r2': best_r2,
            'num_elements': best_elements,
            'total_elements': total_elements,
            'sparsity_ratio': sparsity_ratio,
            'result_df': final_df,
            'target': target,
        }

    # ------------------------------------------------------------------
    # Private entry points kept for backward compatibility
    # ------------------------------------------------------------------
    def _explain_edges(self, x, targets=None, return_weights=False):
        """Edge-level attributions; returns columns ['source', 'target', 'score']."""
        return self._explain_target(x, targets, return_weights, 'edge')

    def _explain_nodes(self, x, targets=None, return_weights=False):
        """Node-level attributions; returns columns ['node', 'score']."""
        return self._explain_target(x, targets, return_weights, 'node')

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _check_target(target):
        """Raise ValueError unless ``target`` is 'edge' or 'node'."""
        if target not in ('edge', 'node'):
            raise ValueError(f"target must be 'edge' or 'node', got '{target}'")

    @staticmethod
    def _select(out, targets):
        """Restrict model output to the requested output columns (all if None)."""
        return out if targets is None else out[:, targets]

    @staticmethod
    def _r2(y_true, y_pred):
        """R² between two tensors; NaN when fewer than two values are available."""
        true_flat = y_true.detach().cpu().numpy().ravel()
        pred_flat = y_pred.detach().cpu().numpy().ravel()
        if true_flat.size < 2:
            return float('nan')
        return r2_score(true_flat, pred_flat)

    def _to_device(self, x):
        """Move tensor inputs next to the (already moved) model copy."""
        return x.to(self.device) if isinstance(x, torch.Tensor) else x

    def _num_elements(self, target):
        """Number of maskable elements for the given explanation target."""
        if target == 'edge':
            return self.model.edge_index.size(1)
        return self.model.num_nodes

    def _masked_forward(self, x, mask, target):
        """Run the model with ``mask`` (reshaped to (1, n)) as edge or node mask."""
        if target == 'edge':
            return self.model(x, edge_mask=mask.view(1, -1))
        return self.model(x, node_mask=mask.view(1, -1))

    def _sample_mask(self, params, tau):
        """Return the relaxed "keep" mask (row 0 of the 2 x n relaxation)."""
        if self.gumbel_softmax:
            return torch.nn.functional.gumbel_softmax(params, tau=tau, hard=self.hard, dim=0)[0]
        soft = torch.softmax(params / tau, dim=0)
        if self.hard:
            # Straight-through: one-hot forward pass, soft gradients.
            index = soft.argmax(dim=0, keepdim=True)
            one_hot = torch.zeros_like(soft).scatter_(0, index, 1.0)
            soft = one_hot - soft.detach() + soft
        return soft[0]

    def _fit_mask(self, x, targets, target, beta, log_fn=None):
        """Optimise the (2, n) mask logits and return ``(params, target_preds)``.

        ``log_fn`` (optional) is called after every step with
        ``(step, loss, mse, ent, mask, out, target_preds)``.
        """
        n = self._num_elements(target)
        keep = self.prior * torch.ones(n, dtype=torch.float32, device=self.device)
        params = torch.nn.Parameter(torch.stack((keep, -keep), dim=0))

        with torch.no_grad():
            target_preds = self._select(self.model(x), targets)

        if self.iters <= 0:
            # Nothing to optimise; also avoids 1 / iters below.
            return params, target_preds

        optim = self.optimizer([params], lr=self.lr, weight_decay=self.weight_decay)
        # Exponential schedule that reaches min_tau after `iters` steps.
        tau_decay_rate = (self.min_tau / self.tau0) ** (1 / self.iters)

        for step in range(self.iters):
            optim.zero_grad()
            tau = max(self.tau0 * tau_decay_rate**step, self.min_tau)
            mask = self._sample_mask(params, tau)
            out = self._select(self._masked_forward(x, mask, target), targets)

            mse = torch.nn.functional.mse_loss(out, target_preds)
            keep_probs = torch.softmax(params, dim=0)[0]
            ent = torch.distributions.Bernoulli(probs=keep_probs).entropy().mean()
            sparsity = torch.clamp(mask.sum() - self.free_edges, min=0)
            loss = mse + beta * sparsity - self.entropy * ent

            loss.backward()
            if self.grad_norm_clip > 0:
                # clip_grad_norm_ expects the parameters, not their gradients.
                torch.nn.utils.clip_grad_norm_([params], self.grad_norm_clip)
            optim.step()

            if log_fn is not None:
                with torch.no_grad():
                    log_fn(step, loss, mse, ent, mask, out, target_preds)

        return params, target_preds

    def _subset_predictions(self, x, params, targets, target):
        """Predict with the hard mask (keep probability > 0.5); return ``(out, n_kept)``."""
        with torch.no_grad():
            keep_probs = torch.softmax(params.detach(), dim=0)[0]
            subset_mask = (keep_probs > 0.5).float()
            out = self._select(self._masked_forward(x, subset_mask, target), targets)
        return out, (subset_mask > 0.5).sum().item()

    def _scores_frame(self, params, target):
        """Build the attribution DataFrame and the raw score array from the logits."""
        scores = torch.softmax(params.detach(), dim=0)[0].cpu().numpy()
        names = np.array(self.model.homo_names)
        if target == 'edge':
            src, dst = names[self.model.edge_index.detach().cpu().numpy()]
            frame = pd.DataFrame({'source': src, 'target': dst, 'score': scores})
        else:
            frame = pd.DataFrame({'node': names, 'score': scores})
        return frame, scores

    def _print_subset_report(self, x, params, targets, target, target_preds):
        """Print fidelity of the thresholded mask against the full-graph prediction."""
        label = 'edges' if target == 'edge' else 'nodes'
        subset_out, num_selected = self._subset_predictions(x, params, targets, target)
        total = self._num_elements(target)

        with torch.no_grad():
            subset_mse = torch.nn.functional.mse_loss(subset_out, target_preds).item()
        subset_r2 = self._r2(target_preds, subset_out)

        # R² can be negative, so the squared correlation is shown as well.
        target_flat = target_preds.detach().cpu().numpy().ravel()
        pred_flat = subset_out.detach().cpu().numpy().ravel()
        if target_flat.size < 2:
            correlation = float('nan')
        else:
            with np.errstate(divide='ignore', invalid='ignore'):
                correlation = np.corrcoef(target_flat, pred_flat)[0, 1]
        variance_explained = correlation ** 2 if not np.isnan(correlation) else 0.0
        percent = 100 * num_selected / total if total else 0.0

        print("="*50)
        print(f"POST-TRAINING EVALUATION ({label} > 0.5)")
        print("="*50)
        print(f"Selected {label}: {num_selected} / {total} ({percent:.1f}%)")
        print(f"MSE (subset): {subset_mse:.6f}")
        print(f"R² (subset): {subset_r2:.4f}")
        print(f"Variance explained: {variance_explained:.4f}")
        print(f"Correlation: {correlation:.4f}")
        print("="*50)

    def _explain_target(self, x, targets, return_weights, target):
        """Fit the mask for ``target`` and package the scores as a DataFrame."""
        x = self._to_device(x)
        log_fn = None
        if self.verbose:
            label = 'edges' if target == 'edge' else 'nodes'
            total = self._num_elements(target)

            def log_fn(step, loss, mse, ent, mask, out, target_preds):
                r2 = self._r2(target_preds, out)
                active = (mask > 0.5).sum().item()
                print(f'iter: {step} | loss: {loss.item():.4f} | mse: {mse.item():.4f} | r2: {r2:.3f} '
                      f'| active {label}: {active} / {total} | entropy: {ent.item():.4f}', end='\r')

        params, target_preds = self._fit_mask(x, targets, target, self.beta, log_fn)

        if self.verbose:
            print()  # finish the carriage-return progress line
            self._print_subset_report(x, params, targets, target, target_preds)

        frame, scores = self._scores_frame(params, target)
        if return_weights:
            return frame, scores
        return frame