gsnn.gsnn.interpret.GSNNExplainer
- class gsnn.gsnn.interpret.GSNNExplainer(model, data, ignore_cuda=False, gumbel_softmax=True, hard=False, tau0=3, min_tau=0.5, prior=1, iters=250, lr=0.01, weight_decay=1e-05, free_edges=0, grad_norm_clip=0, beta=1, verbose=True, optimizer=torch.optim.Adam, entropy=0)
Bases: object
Edge/node mask optimiser that produces sparse explanations.
The explainer learns a binary mask m ∈ {0,1}^E (over the E edges) or m ∈ {0,1}^N (over the N nodes) that maximises fidelity between the model’s prediction on the masked graph and the prediction on the full graph while simultaneously penalising mask size:
L = \mathrm{MSE}\bigl(f(x; m), f(x; 1)\bigr) + \beta \max\bigl(0, \|m\|_1 - \text{free\_elements}\bigr) - \lambda H(m)
The entropy term −λ H(m) is optional. Here β corresponds to beta, free_elements to free_edges and λ to entropy in the parameter list. The mask m is obtained via a differentiable Gumbel-Softmax relaxation, so the optimisation can be performed with vanilla back-prop. After convergence the importance score is the softmax probability p_i = P(m_i = 1):
score_i → 1: element i is essential for reproducing the original prediction.
score_i → 0: element i can be removed with little impact.
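To make the three terms of the objective concrete, the following is a minimal sketch in plain Python, not the library's implementation. The function names, the use of a summed Bernoulli entropy for H(m), and the list-based inputs are assumptions made for illustration only.

```python
import math


def bernoulli_entropy(p, eps=1e-12):
    """Entropy (in nats) of a Bernoulli(p) variable, clamped for stability."""
    p = min(max(p, eps), 1.0 - eps)
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))


def explainer_loss(pred_masked, pred_full, mask, beta=1.0, free_elements=0, lam=0.0):
    """MSE fidelity term + hinge sparsity penalty - entropy bonus.

    Assumption: H(m) is the sum of per-element Bernoulli entropies.
    """
    # Fidelity: how far the masked prediction is from the full-graph prediction.
    mse = sum((a - b) ** 2 for a, b in zip(pred_masked, pred_full)) / len(pred_full)
    # Sparsity: the L1 norm of the mask is only penalised beyond free_elements.
    sparsity = max(0.0, sum(abs(m) for m in mask) - free_elements)
    # Entropy bonus: subtracted, so higher entropy lowers the loss.
    entropy = sum(bernoulli_entropy(m) for m in mask)
    return mse + beta * sparsity - lam * entropy
```

With identical predictions, a mask of [1, 1, 0], beta=2 and free_elements=1, only one element exceeds the free budget, so the loss is 2.0.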
- Parameters:
model (torch.nn.Module) – Trained GSNN model (its parameters are frozen during explanation).
data (torch_geometric.data.Data) – Graph data object (only metadata are used).
ignore_cuda (bool, optional (default=False)) – Force CPU even if CUDA is available.
gumbel_softmax (bool, optional (default=True)) – Use the Gumbel-Softmax re-parameterisation; otherwise plain Softmax.
hard (bool, optional (default=False)) – Use the straight-through estimator to obtain discrete masks at test time while keeping gradients continuous.
tau0 (float, optional (default=3.0)) – Initial temperature for the (hard) Gumbel-Softmax.
min_tau (float, optional (default=0.5)) – Minimum temperature reached after exponential decay.
prior (float, optional (default=1.0)) – Initial bias added to the positive/negative logits.
iters (int, optional (default=250)) – Number of optimisation steps.
lr (float, optional (default=1e-2)) – Learning rate for the optimiser.
weight_decay (float, optional (default=1e-5)) – Weight decay applied to the mask logits.
free_edges (int, optional (default=0)) – Number of elements allowed before the sparsity penalty activates.
beta (float, optional (default=1.0)) – Coefficient of the sparsity term.
entropy (float, optional (default=0.0)) – Strength of the entropy bonus (encourages exploration).
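The tau0 and min_tau parameters describe a temperature that decays exponentially during optimisation, and gumbel_softmax controls how the relaxed mask is sampled. The sketch below illustrates both ideas in plain Python. The exact decay formula, the helper names, and the two-logit (keep/drop) parameterisation are assumptions, not the library's code.

```python
import math
import random


def temperature(step, iters, tau0=3.0, min_tau=0.5):
    """Exponentially interpolate from tau0 (first step) to min_tau (last step)."""
    if iters <= 1:
        return min_tau
    frac = step / (iters - 1)
    return tau0 * (min_tau / tau0) ** frac


def gumbel_softmax_keep_prob(logit_keep, logit_drop, tau, rng=random):
    """Relaxed sample of one mask element m_i at temperature tau."""

    def gumbel():
        # Clamp the uniform draw to avoid log(0).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        return -math.log(-math.log(u))

    a = (logit_keep + gumbel()) / tau
    b = (logit_drop + gumbel()) / tau
    # Two-class softmax written as a sigmoid of the logit difference.
    return 1.0 / (1.0 + math.exp(b - a))
```

A high temperature early on keeps the samples soft (close to 0.5), while a low temperature near the end pushes them towards 0 or 1.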
Example
>>> explainer = GSNNExplainer(model, data, iters=400, beta=5)
>>> # Edge-level attributions
>>> edge_df = explainer.explain(x, target_idx=[0], target='edge')
>>> edge_df.sort_values('score', ascending=False).head()
>>> # Node-level attributions
>>> node_df = explainer.explain(x, target_idx=[0], target='node')
>>> node_df.sort_values('score', ascending=False).head()
- __init__(model, data, ignore_cuda=False, gumbel_softmax=True, hard=False, tau0=3, min_tau=0.5, prior=1, iters=250, lr=0.01, weight_decay=1e-05, free_edges=0, grad_norm_clip=0, beta=1, verbose=True, optimizer=torch.optim.Adam, entropy=0)
Adapted from the methods presented in GNNExplainer (https://arxiv.org/abs/1903.03894).
- Parameters:
model (torch.nn.Module) – GSNN model.
data (pyg.Data) – GSNN processed graph.
beta (float) – Regularization scalar encouraging a minimal subset of edges.
ignore_cuda (bool) – Whether to ignore CUDA and run on the CPU even if CUDA is available.
min_tau (float) – Minimum temperature value for Gumbel-Softmax.
prior (float) – Prior strength used to initialize theta; a value of 0 gives each element a 0.5 probability of being selected, and a value > 0 makes selection more likely.
grad_norm_clip (float) – Gradient norm clipping value.
verbose (bool) – Whether to print progress information during optimisation.
optimizer (torch.optim.Optimizer) – Optimizer to use for training.
entropy (float) – Entropy bonus strength.
iters (int) – Number of optimisation steps.
weight_decay (float) – Weight decay for the optimiser.
free_edges (int) – Number of edges allowed before the sparsity penalty activates.
The remaining arguments (gumbel_softmax, hard, tau0, lr) are described in the class-level parameter list.
- Returns
None
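The description of prior (0 gives a 0.5 selection probability, larger values favour selection) can be illustrated with a short sketch. It assumes, hypothetically, that the bias enters symmetrically as logits (+prior, −prior) for the keep and drop classes; the library's actual initialisation of theta may differ.

```python
import math


def initial_keep_probability(prior):
    """Two-class softmax over the logits (+prior, -prior).

    Assumption: the prior is added to the keep logit and subtracted from
    the drop logit, which is equivalent to sigmoid(2 * prior).
    """
    keep, drop = math.exp(prior), math.exp(-prior)
    return keep / (keep + drop)
```

Under this assumption, prior=0 starts every element at exactly 0.5, and a positive prior starts the optimisation from a nearly full mask that the sparsity term then prunes.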
Methods
__init__(model, data[, ignore_cuda, ...]) – Adapted from the methods presented in GNNExplainer (https://arxiv.org/abs/1903.03894).
explain(x[, target_idx, return_weights, target]) – Initializes and runs gradient descent to select a minimal subset of edges or nodes that produce comparable predictions to the full graph.
tune(x[, target_ixs, min_r2, beta_step, ...]) – Tune the beta parameter, starting from its current value, to find the maximum sparsity that still maintains a minimum performance.
- explain(x, target_idx=None, return_weights=False, target='edge')
Initializes and runs gradient descent to select a minimal subset of edges or nodes that produce comparable predictions to the full graph.
- Parameters:
x (torch.Tensor) – Input features to explain, of shape (B, I).
target_idx (list, optional) – Target output indices to explain.
return_weights (bool, optional (default=False)) – Whether to return raw weights along with the DataFrame.
target (str, optional (default='edge')) – Whether to return 'edge' or 'node' level attributions.
- Returns:
If target='edge': columns ['source', 'target', 'score'] for edge attributions.
If target='node': columns ['node', 'score'] for node attributions.
- Return type:
pd.DataFrame
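Once the attributions are available, a common next step is to keep only the highest-scoring elements. The sketch below works on plain row dictionaries with the documented edge columns, such as those pandas produces with `edge_df.to_dict('records')`. The helper names, the 0.5 threshold, and the example rows are made up for illustration and are not part of the library.

```python
def top_k(records, k):
    """Return the k highest-scoring rows, best first."""
    return sorted(records, key=lambda row: row["score"], reverse=True)[:k]


def keep_above(records, threshold=0.5):
    """Return rows whose score meets the threshold, preserving input order."""
    return [row for row in records if row["score"] >= threshold]


# Made-up rows for illustration only; real rows would come from explain().
edge_rows = [
    {"source": "a", "target": "b", "score": 0.93},
    {"source": "b", "target": "c", "score": 0.08},
    {"source": "a", "target": "c", "score": 0.61},
]

explanation_subgraph = keep_above(edge_rows, threshold=0.5)
```

Thresholding gives a subgraph whose size depends on the scores, whereas top-k fixes the size in advance; which is preferable depends on how the explanation will be used.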
- tune(x, target_ixs=None, min_r2=0.7, beta_step=1.5, max_trials=20, tolerance=0.001, verbose=True, target='edge', **explain_kwargs)
Tune the beta parameter, starting from its current value, to find the maximum sparsity that still maintains a minimum performance.
Starts from the user’s initial beta and adjusts it up or down based on performance:
- If R² >= min_r2: increase beta (more sparsity) until performance drops.
- If R² < min_r2: decrease beta (less sparsity) until performance recovers.
This is much more efficient than a wide search, provided the user supplies a good starting point.
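The up/down adjustment described above can be sketched as a simple multiplicative search. This is an illustrative outline only: `evaluate` is a hypothetical callable standing in for one explanation run that returns (R², number of selected elements), dividing by beta_step on the way down is an assumption, and the tolerance-based fine search of the real method is omitted.

```python
def tune_beta(evaluate, beta=1.0, min_r2=0.7, beta_step=1.5, max_trials=20):
    """Coarse multiplicative search for the largest beta with R² >= min_r2.

    evaluate(beta) -> (r2, n_elements) is a hypothetical stand-in for one
    explanation run. Returns (beta, r2, n_elements), or None if no trial
    reached min_r2.
    """
    r2, n = evaluate(beta)
    increasing = r2 >= min_r2
    best = (beta, r2, n) if increasing else None
    for _ in range(max_trials):
        beta = beta * beta_step if increasing else beta / beta_step
        r2, n = evaluate(beta)
        if increasing:
            if r2 < min_r2:
                break  # too sparse: keep the last acceptable beta
            best = (beta, r2, n)
        elif r2 >= min_r2:
            best = (beta, r2, n)
            break  # performance recovered
    return best
```

Starting near a reasonable beta matters because each trial costs one full optimisation run, so the number of steps needed to cross the R² threshold dominates the runtime.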
- Parameters:
x (torch.Tensor) – Input data for explanation.
target_ixs (list, optional) – Target output indices to explain.
min_r2 (float, optional (default=0.7)) – Minimum R² threshold to maintain.
beta_step (float, optional (default=1.5)) – Multiplicative step size for beta adjustment (1.5 = 50% increase/decrease).
max_trials (int, optional (default=20)) – Maximum number of beta adjustments to try.
tolerance (float, optional (default=1e-3)) – Convergence tolerance for fine search.
verbose (bool, optional (default=True)) – Whether to print search progress.
target (str, optional (default='edge')) – Whether to tune for 'edge' or 'node' level attributions.
**explain_kwargs (dict, optional) – Override any explainer parameters during tuning:
- iters: number of optimization steps
- lr: learning rate
- weight_decay: weight decay
- free_edges: elements allowed before penalty
- prior: initial bias for element selection
- tau0: initial temperature
- min_tau: minimum temperature
- hard: use straight-through estimator
- entropy: entropy bonus strength
- Returns:
Results containing optimal beta, achieved R², number of elements, and final DataFrame
- Return type: