gsnn.gsnn.interpret.GSNNExplainer

class gsnn.gsnn.interpret.GSNNExplainer(model, data, ignore_cuda=False, gumbel_softmax=True, hard=False, tau0=3, min_tau=0.5, prior=1, iters=250, lr=0.01, weight_decay=1e-05, free_edges=0, grad_norm_clip=0, beta=1, verbose=True, optimizer=torch.optim.Adam, entropy=0)

Bases: object

Edge/node mask optimiser that produces sparse explanations.

The explainer learns a binary mask m ∈ {0,1}^E over the edges (or m ∈ {0,1}^N over the nodes) that maximises fidelity between the model's prediction on the masked graph and its prediction on the full graph, while also penalising mask size:

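\mathcal{L}(m) = \mathrm{MSE}\bigl(f(x; m),\ f(x; \mathbf{1})\bigr)
    + \beta \max\bigl(0,\ \lVert m \rVert_1 - \mathrm{free\_edges}\bigr)
    - \lambda H(m)

The entropy term −λH(m) is optional; λ is set by the entropy parameter (default 0, which disables it).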
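To make the three terms concrete, here is a minimal NumPy sketch that evaluates the objective for given predictions and a relaxed mask. It is not the library's implementation: the function name `explainer_loss` is hypothetical, H(m) is assumed to be the summed Bernoulli entropy of the mask entries, and λ is taken to be the `entropy` parameter.

```python
import numpy as np


def explainer_loss(y_masked, y_full, mask, beta=1.0, free_edges=0, entropy=0.0, eps=1e-12):
    """Hypothetical sketch of the masked-fidelity objective (assumed form).

    y_masked : predictions f(x; m) on the masked graph
    y_full   : predictions f(x; 1) on the full graph
    mask     : relaxed mask values m_i in [0, 1]
    """
    y_masked = np.asarray(y_masked, dtype=float)
    y_full = np.asarray(y_full, dtype=float)
    m = np.asarray(mask, dtype=float)

    # Fidelity: mean squared error between masked and full-graph predictions.
    fidelity = np.mean((y_masked - y_full) ** 2)

    # Sparsity: hinge on the L1 norm of the (non-negative) mask; the first
    # `free_edges` units of mask mass are not penalised.
    sparsity = beta * max(0.0, m.sum() - free_edges)

    # Entropy: each m_i treated as a Bernoulli probability (assumption).
    m_c = np.clip(m, eps, 1.0 - eps)
    h = -np.sum(m_c * np.log(m_c) + (1.0 - m_c) * np.log(1.0 - m_c))

    # The entropy bonus is subtracted, so higher entropy lowers the loss.
    return float(fidelity + sparsity - entropy * h)


# Perfect fidelity, two elements kept, beta=1: the loss is just the sparsity term.
print(explainer_loss([1.0, 2.0], [1.0, 2.0], [1, 0, 1]))
```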

Here m is obtained through a differentiable Gumbel-Softmax relaxation, so the objective can be optimised with standard back-propagation. After convergence, the importance score of element i is the softmax probability p_i = P(m_i = 1):

  • score_i ≈ 1 → element i is essential for reproducing the original prediction.

  • score_i ≈ 0 → element i can be removed with little impact.
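The following NumPy sketch illustrates the binary Gumbel-Softmax relaxation and the noise-free importance score described above. It assumes each element carries two logits (column 0 = keep, column 1 = drop); `gumbel_softmax_mask` and `importance_scores` are hypothetical helpers, not part of the GSNN API. With `hard=True` it only reproduces the forward value of the straight-through estimator, since NumPy has no autograd.

```python
import numpy as np


def gumbel_softmax_mask(logits, tau, hard=False, rng=None):
    """Sample a relaxed binary mask from two-class logits of shape (n, 2).

    Column 0 is assumed to be the "keep" (m_i = 1) class.
    """
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=float)
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    z = (logits + g) / tau
    z -= z.max(axis=1, keepdims=True)  # numerical stability
    y = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    m = y[:, 0]
    if hard:
        # Forward value of the straight-through estimator (argmax of the two
        # classes); an autograd framework would return m_hard - stop_grad(m) + m.
        m = (m > 0.5).astype(float)
    return m


def importance_scores(logits):
    """Noise-free score p_i = P(m_i = 1) = softmax(logits)[:, 0]."""
    logits = np.asarray(logits, dtype=float)
    z = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    return p[:, 0]


logits = [[0.0, 0.0], [2.0, 0.0]]
print(gumbel_softmax_mask(logits, tau=1.0, rng=np.random.default_rng(0)))
print(importance_scores(logits))  # first entry is exactly 0.5
```

Lowering `tau` pushes the sampled mask values towards 0 or 1, which is why the temperature is decayed during optimisation.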

Parameters:
  • model (torch.nn.Module) – Trained GSNN model (its parameters are frozen during explanation).

  • data (torch_geometric.data.Data) – Graph data object (only metadata are used).

  • ignore_cuda (bool, optional (default=False)) – Force CPU even if CUDA is available.

  • gumbel_softmax (bool, optional (default=True)) – Use the Gumbel-Softmax re-parameterisation; otherwise plain Softmax.

  • hard (bool, optional (default=False)) – Use the straight-through estimator to obtain discrete masks in the forward pass while keeping gradients continuous.

  • tau0 (float, optional (default=3.0)) – Initial temperature for the (hard) Gumbel-Softmax.

  • min_tau (float, optional (default=0.5)) – Minimum temperature reached after exponential decay.

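  • prior (float, optional (default=1.0)) – Initial bias added to the positive/negative logits. A value of 0 gives each element a 0.5 probability of being selected; a value > 0 makes selection more likely.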

  • iters (int, optional (default=250)) – Number of optimisation steps.

  • lr (float, optional (default=1e-2)) – Learning rate for the optimiser.

  • weight_decay (float, optional (default=1e-5)) – Weight decay applied to the mask logits.

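  • free_edges (int, optional (default=0)) – Number of elements allowed before the sparsity penalty activates.

  • grad_norm_clip (float, optional (default=0)) – Gradient-norm clipping value.

  • beta (float, optional (default=1.0)) – Coefficient of the sparsity term.

  • verbose (bool, optional (default=True)) – Whether to print progress information during optimisation.

  • optimizer (torch.optim.Optimizer, optional (default=torch.optim.Adam)) – Optimiser used for training.

  • entropy (float, optional (default=0.0)) – Strength of the entropy bonus (encourages exploration).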
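The parameters tau0, min_tau, iters and prior interact roughly as in this hedged sketch. The exact decay law is not documented here, so it assumes a geometric decay that reaches min_tau on the last step; it also assumes prior is added to the "keep" logit only, which gives P(m_i = 1) = sigmoid(prior) and matches the documented 0.5 probability at prior = 0. Both helpers are hypothetical.

```python
import math


def tau_schedule(tau0=3.0, min_tau=0.5, iters=250):
    """Hypothetical geometric temperature decay from tau0 down to min_tau."""
    if iters == 1:
        return [tau0]
    # Per-step factor chosen so that tau0 * rate**(iters - 1) == min_tau.
    rate = (min_tau / tau0) ** (1.0 / (iters - 1))
    return [max(min_tau, tau0 * rate ** t) for t in range(iters)]


def keep_probability(prior):
    """Initial P(m_i = 1) if logits start at [prior, 0] (assumption)."""
    return 1.0 / (1.0 + math.exp(-prior))


taus = tau_schedule()
print(taus[0], taus[-1])      # starts at tau0, ends at (approximately) min_tau
print(keep_probability(0.0))  # 0.5
print(keep_probability(1.0))  # default prior: about 0.73
```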

Example

>>> explainer = GSNNExplainer(model, data, iters=400, beta=5)
>>> # Edge-level attributions
>>> edge_df = explainer.explain(x, target_idx=[0], target='edge')
>>> edge_df.sort_values('score', ascending=False).head()
>>> # Node-level attributions
>>> node_df = explainer.explain(x, target_idx=[0], target='node')
>>> node_df.sort_values('score', ascending=False).head()
__init__(model, data, ignore_cuda=False, gumbel_softmax=True, hard=False, tau0=3, min_tau=0.5, prior=1, iters=250, lr=0.01, weight_decay=1e-05, free_edges=0, grad_norm_clip=0, beta=1, verbose=True, optimizer=torch.optim.Adam, entropy=0)

Adapted from the methods presented in GNNExplainer (https://arxiv.org/abs/1903.03894).

Parameters:
  • model (torch.nn.Module) – Trained GSNN model.

  • data (torch_geometric.data.Data) – GSNN-processed graph.

  • ignore_cuda (bool) – Force CPU even if CUDA is available.

  • gumbel_softmax (bool) – Use the Gumbel-Softmax re-parameterisation; otherwise plain Softmax.

  • hard (bool) – Use the straight-through estimator to obtain discrete masks.

  • tau0 (float) – Initial temperature for the Gumbel-Softmax.

  • min_tau (float) – Minimum temperature value for the Gumbel-Softmax.

  • prior (float) – Prior strength used to initialise theta. A value of 0 gives each element a 0.5 probability of being selected; a value > 0 makes selection more likely.

  • iters (int) – Number of optimisation steps.

  • lr (float) – Learning rate for the optimiser.

  • weight_decay (float) – Weight decay for the optimiser.

  • free_edges (int) – Number of edges allowed before the sparsity penalty activates.

  • grad_norm_clip (float) – Gradient-norm clipping value.

  • beta (float) – Regularisation scalar encouraging a minimal subset of edges.

  • verbose (bool) – Whether to print progress information during optimisation.

  • optimizer (torch.optim.Optimizer) – Optimiser to use for training.

  • entropy (float) – Entropy bonus strength.

Returns:

None

Methods

  • __init__(model, data[, ignore_cuda, ...]) – Adapted from the methods presented in GNNExplainer (https://arxiv.org/abs/1903.03894).

  • explain(x[, target_idx, return_weights, target]) – Initializes and runs gradient descent to select a minimal subset of edges or nodes that produces predictions comparable to those of the full graph.

  • tune(x[, target_ixs, min_r2, beta_step, ...]) – Tunes the beta parameter, starting from its current value, to find the maximum sparsity that still meets a minimum performance threshold.

explain(x, target_idx=None, return_weights=False, target='edge')

Initializes and runs gradient descent to select a minimal subset of edges or nodes that produces predictions comparable to those of the full graph.
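The loop below is a self-contained PyTorch sketch of what such a gradient-descent explanation might look like on a toy linear model y = x @ (W * m), where each entry of W stands in for an edge. The function name `explain_toy`, the toy model, the geometric temperature decay and the default values are illustrative assumptions, not the GSNN implementation; the relaxation uses `torch.nn.functional.gumbel_softmax`.

```python
import torch
import torch.nn.functional as F


def explain_toy(W, x, beta=0.05, free_edges=0, prior=1.0, iters=300,
                lr=0.05, tau0=3.0, min_tau=0.5, seed=0):
    """Hypothetical mask optimisation for the toy model y = x @ (W * m)."""
    torch.manual_seed(seed)
    n = W.numel()
    # Two logits per element: column 0 = keep, column 1 = drop.
    logits = torch.zeros(n, 2)
    logits[:, 0] = prior
    logits.requires_grad_(True)
    opt = torch.optim.Adam([logits], lr=lr)

    with torch.no_grad():
        y_full = x @ W  # f(x; 1); W itself is never updated

    rate = (min_tau / tau0) ** (1.0 / max(iters - 1, 1))
    for t in range(iters):
        tau = max(min_tau, tau0 * rate ** t)
        m = F.gumbel_softmax(logits, tau=tau, hard=False, dim=-1)[:, 0]
        y_masked = x @ (W * m.reshape_as(W))  # f(x; m)
        loss = F.mse_loss(y_masked, y_full) + beta * torch.relu(m.sum() - free_edges)
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Importance score p_i = P(m_i = 1), computed without Gumbel noise.
    with torch.no_grad():
        scores = torch.softmax(logits, dim=-1)[:, 0]
    return scores.reshape_as(W)


# Toy "graph": the first edge carries all the signal, the second has zero weight.
W = torch.tensor([[2.0], [0.0]])
x = torch.randn(16, 2, generator=torch.Generator().manual_seed(1))
scores = explain_toy(W, x)
print(scores)
```

Because the second weight is zero, dropping that edge costs no fidelity, so the sparsity term drives its score towards 0 while the score of the non-zero edge stays high.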

Parameters:
  • x (torch.Tensor) – Input features to explain, of shape (B, I).

  • target_idx (list, optional) – Target output indices to explain.

  • return_weights (bool, optional (default=False)) – Whether to return raw weights along with the DataFrame.

  • target (str, optional (default='edge')) – Whether to return 'edge' or 'node' level attributions.

Returns:

If target='edge': a DataFrame with columns ['source', 'target', 'score'] for edge attributions. If target='node': a DataFrame with columns ['node', 'score'] for node attributions.

Return type:

pd.DataFrame
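Because the score is P(m_i = 1), one natural way to read the returned DataFrame is to keep the elements whose score is at least 0.5. The helper below is a hypothetical pandas sketch, not part of the GSNN API; it only relies on the documented `score` column, so it works for both edge and node DataFrames. The 0.5 cut-off and the optional top-k are illustrative choices, and the toy rows are made up.

```python
import pandas as pd


def select_elements(df, threshold=0.5, top_k=None):
    """Keep rows with score >= threshold, highest score first (hypothetical helper)."""
    kept = df[df["score"] >= threshold].sort_values("score", ascending=False)
    if top_k is not None:
        kept = kept.head(top_k)
    return kept.reset_index(drop=True)


# Made-up edge attributions in the documented ['source', 'target', 'score'] layout.
edge_df = pd.DataFrame({
    "source": ["a", "b", "c"],
    "target": ["b", "c", "d"],
    "score": [0.9, 0.2, 0.6],
})
print(select_elements(edge_df))
```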

tune(x, target_ixs=None, min_r2=0.7, beta_step=1.5, max_trials=20, tolerance=0.001, verbose=True, target='edge', **explain_kwargs)

Tune the beta parameter, starting from its current value, to find the maximum sparsity that still meets a minimum performance threshold.

Starts from the user's initial beta and adjusts it up or down based on performance:

  • If R² ≥ min_r2: increase beta (more sparsity) until performance drops below the threshold.

  • If R² < min_r2: decrease beta (less sparsity) until performance recovers.

This is much more efficient than a wide search, because the user supplies a good starting point.
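The search strategy just described can be sketched in plain Python. This is an illustration under stated assumptions rather than the library's code: `tune_beta` and the `evaluate(beta)` callable (which would run an explanation at that beta and return its R²) are hypothetical, beta is assumed to be divided by beta_step when decreasing, and the fine search is assumed to be a bisection that stops once the bracket is narrower than tolerance relative to beta.

```python
def tune_beta(evaluate, beta0, min_r2=0.7, beta_step=1.5, max_trials=20, tolerance=1e-3):
    """Hypothetical up/down search for the largest beta with R^2 >= min_r2.

    Returns (best_beta, best_r2), or (None, None) if no tried beta passed.
    """
    trials = 0

    def run(beta):
        nonlocal trials
        trials += 1
        return evaluate(beta)

    beta, r2 = beta0, run(beta0)
    if r2 >= min_r2:
        # Passing start: grow beta (more sparsity) until R^2 falls below min_r2.
        good, good_r2, bad = beta, r2, None
        while trials < max_trials:
            beta *= beta_step
            r2 = run(beta)
            if r2 < min_r2:
                bad = beta
                break
            good, good_r2 = beta, r2
    else:
        # Failing start: shrink beta (less sparsity) until R^2 recovers.
        good, good_r2, bad = None, None, beta
        while trials < max_trials:
            beta /= beta_step
            r2 = run(beta)
            if r2 >= min_r2:
                good, good_r2 = beta, r2
                break
            bad = beta

    # Fine search: bisect between the largest passing and smallest failing beta.
    while (good is not None and bad is not None and trials < max_trials
           and bad - good > tolerance * good):
        mid = 0.5 * (good + bad)
        r2 = run(mid)
        if r2 >= min_r2:
            good, good_r2 = mid, r2
        else:
            bad = mid
    return good, good_r2


# Stand-in evaluator: R^2 falls linearly as beta grows and crosses 0.7 at beta = 3.
print(tune_beta(lambda b: 1.0 - 0.1 * b, beta0=1.0))
print(tune_beta(lambda b: 1.0 - 0.1 * b, beta0=10.0))
```

With this stand-in evaluator, the search brackets the crossover at beta = 3 from either starting point and returns the largest passing beta it found, just under 3.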

Parameters:
  • x (torch.Tensor) – Input data for explanation.

  • target_ixs (list, optional) – Target output indices to explain.

  • min_r2 (float, optional (default=0.7)) – Minimum R² threshold to maintain.

  • beta_step (float, optional (default=1.5)) – Multiplicative step size for beta adjustment (1.5 scales beta up or down by a factor of 1.5).

  • max_trials (int, optional (default=20)) – Maximum number of beta adjustments to try.

  • tolerance (float, optional (default=1e-3)) – Convergence tolerance for the fine search.

  • verbose (bool, optional (default=True)) – Whether to print search progress.

  • target (str, optional (default='edge')) – Whether to tune for 'edge' or 'node' level attributions.

  • **explain_kwargs (dict, optional) – Override any explainer parameters during tuning:
      - iters: number of optimisation steps
      - lr: learning rate
      - weight_decay: weight decay
      - free_edges: elements allowed before the penalty activates
      - prior: initial bias for element selection
      - tau0: initial temperature
      - min_tau: minimum temperature
      - hard: use the straight-through estimator
      - entropy: entropy bonus strength
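The min_r2 threshold compares the masked-graph predictions with the full-graph predictions. A minimal NumPy sketch of that comparison follows, assuming the standard coefficient of determination pooled over all outputs with the full-graph prediction as the reference; whether the library pools or averages per output is not stated here, and `r2_score` is a hypothetical helper.

```python
import numpy as np


def r2_score(y_full, y_masked):
    """R^2 of masked predictions against full-graph predictions (pooled, assumed)."""
    y_full = np.asarray(y_full, dtype=float).ravel()
    y_masked = np.asarray(y_masked, dtype=float).ravel()
    ss_res = np.sum((y_full - y_masked) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_full - y_full.mean()) ** 2)    # total sum of squares
    return 1.0 - ss_res / ss_tot


print(r2_score([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # perfect reproduction: 1.0
print(r2_score([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))  # predicting the mean: 0.0
```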

Returns:

A dictionary containing the optimal beta, the achieved R², the number of elements, and the final DataFrame.

Return type:

dict