gsnn.gsnn.interpret.ContrastiveGSNNExplainer

Classes

ContrastiveGSNNExplainer(model, data[, ...])

Edge/node mask optimiser for contrastive explanations.

class gsnn.gsnn.interpret.ContrastiveGSNNExplainer.ContrastiveGSNNExplainer(model, data, ignore_cuda: bool = False, gumbel_softmax: bool = True, hard: bool = False, tau0: float = 3.0, min_tau: float = 0.5, prior: float = 1.0, iters: int = 250, lr: float = 0.01, weight_decay: float = 1e-05, free_edges: int = 0, beta: float = 1.0, verbose: bool = True, optimizer=torch.optim.Adam, entropy: float = 0.0)

Bases: object

Edge/node mask optimiser for contrastive explanations.

This explainer learns a binary mask m ∈ {0,1}^E over edges (or m ∈ {0,1}^N over nodes) that reproduces the prediction difference of the full graph as closely as possible while penalising mask size:

Δf(m) = f(x₁; m)[target_idx] − f(x₂; m)[target_idx]   (multivariate)

L = MSE(Δf(m), Δf(1))      # over all B×T elements (B pairs, T targets); 1 = all-ones mask, i.e. the full graph
    + β max(0, ‖m‖₁ − free_edges)
    − λ H(m)               (optional entropy term, λ = entropy)
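As a minimal NumPy sketch of the loss above (not the library's implementation), the function below evaluates each term for given differences and a relaxed mask. The name `contrastive_loss`, the use of relaxed mask values for ‖m‖₁, and the summed binary-entropy form of H(m) are assumptions; `free_edges` stands in for the free-element budget and `entropy` for λ.

```python
import numpy as np

def contrastive_loss(delta_masked, delta_full, mask, beta=1.0,
                     free_edges=0, entropy=0.0, eps=1e-12):
    """Hypothetical sketch: L = MSE + beta * hinge(||m||_1 - free_edges) - entropy * H(m).

    delta_masked, delta_full: arrays of shape [B, T] holding Δf(m) and Δf(1).
    mask: relaxed mask values in [0, 1], one per edge (or node).
    """
    delta_masked = np.asarray(delta_masked, dtype=float)
    delta_full = np.asarray(delta_full, dtype=float)
    mask = np.asarray(mask, dtype=float)

    # Fidelity term: mean squared error over all B x T elements.
    mse = np.mean((delta_masked - delta_full) ** 2)

    # Sparsity term: only mask mass above the free budget is penalised.
    sparsity = beta * max(0.0, mask.sum() - free_edges)

    # Assumed form of H(m): binary entropy of each element, summed.
    p = np.clip(mask, eps, 1 - eps)
    h = -np.sum(p * np.log(p) + (1 - p) * np.log(1 - p))

    # Subtracting the entropy turns it into a bonus for uncertain masks.
    return mse + sparsity - entropy * h
```

With `entropy=0` the loss reduces to fidelity plus the hinge penalty, so a mask within the free budget is charged only for how badly it reproduces Δf(1).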

Here m is obtained via a differentiable Gumbel-Softmax relaxation (or a plain softmax when gumbel_softmax=False), so the optimisation can be performed with standard backpropagation. After convergence, the importance score of element i is the softmax probability p_i = P(m_i = 1).

  • score_i ≈ 1 → element i is essential for reproducing the prediction difference.

  • score_i ≈ 0 → element i can be removed without affecting the difference.
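The relaxation and the score can be sketched as follows, assuming each element carries two logits ([drop, keep]) as suggested by the `prior` parameter's "positive/negative logits". `sample_mask` is a hypothetical NumPy helper, so it only shows forward values; in PyTorch, `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=...)` provides the same relaxation with gradients. Computing the score from the noise-free logits at temperature 1 is also an assumption. A two-way softmax reduces to p_i = σ(ℓ_keep − ℓ_drop).

```python
import numpy as np

def sample_mask(logits, tau, hard=False, gumbel=True, rng=None):
    """Hypothetical sketch of a relaxed binary mask from [drop, keep] logits.

    logits: array of shape [E, 2]; column 1 is the "keep" (m_i = 1) logit.
    Returns (mask, p_keep): the relaxed (or hard) mask and p_i = P(m_i = 1).
    """
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=float)

    # Importance score: softmax probability of the "keep" class (no noise).
    z = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    p_keep = probs[:, 1]

    # Gumbel-Softmax: perturb logits with Gumbel(0, 1) noise, soften by tau.
    noise = rng.gumbel(size=logits.shape) if gumbel else 0.0
    y = (logits + noise) / tau
    y = y - y.max(axis=1, keepdims=True)
    soft = np.exp(y) / np.exp(y).sum(axis=1, keepdims=True)
    mask = soft[:, 1]

    if hard:
        # Straight-through forward value: round to {0, 1}. With autograd,
        # the backward pass would use the gradient of `soft` instead.
        mask = (soft[:, 1] >= soft[:, 0]).astype(float)
    return mask, p_keep
```

Lower temperatures push the soft mask towards 0/1, which is why the temperature is annealed during optimisation.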

Parameters:
  • model (torch.nn.Module) – Trained GSNN model (its parameters are frozen during explanation).

  • data (torch_geometric.data.Data) – Graph data object (only metadata are used).

  • ignore_cuda (bool, optional (default=False)) – Force CPU even if CUDA is available.

  • gumbel_softmax (bool, optional (default=True)) – Use the Gumbel-Softmax re-parameterisation; otherwise plain Softmax.

  • hard (bool, optional (default=False)) – Use the straight-through estimator: masks are discretised to {0, 1} in the forward pass while gradients flow through the continuous relaxation.

  • tau0 (float, optional (default=3.0)) – Initial temperature for the (hard) Gumbel-Softmax.

  • min_tau (float, optional (default=0.5)) – Minimum temperature reached after exponential decay.

  • prior (float, optional (default=1.0)) – Initial bias added to the positive/negative logits.

  • iters (int, optional (default=250)) – Number of optimisation steps.

  • lr (float, optional (default=1e-2)) – Learning rate for the optimiser.

  • weight_decay (float, optional (default=1e-5)) – Weight decay applied to the mask logits.

  • free_edges (int, optional (default=0)) – Number of mask elements (edges or nodes) that may be retained before the sparsity penalty activates.

  • beta (float, optional (default=1.0)) – Coefficient of the sparsity term.

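  • optimizer (optional (default=torch.optim.Adam)) – Optimiser class used to update the mask logits.

  • entropy (float, optional (default=0.0)) – Strength of the entropy bonus (encourages exploration).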

  • verbose (bool, optional (default=True)) – Print progress information during optimisation.
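The interaction of tau0, min_tau and iters can be illustrated with a geometric schedule. The exact decay rule used by the library is not documented here, so `temperature_schedule` is a hypothetical sketch that assumes the temperature shrinks by a constant factor per step and reaches min_tau on the final iteration.

```python
def temperature_schedule(tau0=3.0, min_tau=0.5, iters=250):
    """Hypothetical exponential (geometric) decay from tau0 down to min_tau."""
    if iters < 2:
        return [tau0]
    # Constant per-step factor chosen so that step iters-1 lands on min_tau.
    ratio = (min_tau / tau0) ** (1.0 / (iters - 1))
    # Clamp at min_tau to absorb floating-point undershoot.
    return [max(min_tau, tau0 * ratio ** t) for t in range(iters)]
```

Early iterations use a high temperature (smooth, exploratory masks); later ones approach near-binary masks.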

Example

>>> explainer = ContrastiveGSNNExplainer(model, data, iters=400, beta=5)
>>> # Edge-level attributions
>>> edge_df = explainer.explain(x1, x2, target_idx=0, target='edge')
>>> edge_df.sort_values('score', ascending=False).head()
>>> # Node-level attributions
>>> node_df = explainer.explain(x1, x2, target_idx=0, target='node')
>>> node_df.sort_values('score', ascending=False).head()
explain(x1: torch.Tensor, x2: torch.Tensor, target_idx: Union[int, List[int]], *, return_weights: bool = False, target: str = 'edge') → pandas.DataFrame

Compute attributions for f(x₁) − f(x₂).

Initializes the mask logits and runs gradient descent to select a minimal subset of elements (edges or nodes) that preserves the prediction difference between x1 and x2.

When given multiple pairs (a batch), learns ONE mask that works well across ALL pairs by treating the per-pair differences as a multi-output objective. This is much faster than per-sample optimization, which would run a separate optimization for every pair.
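To illustrate the single-mask batch objective, the toy example below replaces the GSNN with one linear layer whose nonzero weights act as edges. Everything here (`W`, `f`, `batch_objective`) is hypothetical; it follows the B×T mean-squared-error formulation from the class description and shows the shape handling for [N_in] versus [B, N_in] inputs.

```python
import numpy as np

# Hypothetical toy stand-in for a GSNN: one linear layer whose nonzero
# weights play the role of graph edges (a real GSNN is much deeper).
W = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 0.0]])          # T=2 outputs, N_in=3 inputs
edges = np.argwhere(W != 0)              # one mask element per edge

def f(x, m):
    """Forward pass with edge mask m (one value per edge); x is [B, N_in]."""
    Wm = np.zeros_like(W)
    Wm[edges[:, 0], edges[:, 1]] = W[edges[:, 0], edges[:, 1]] * m
    return x @ Wm.T                      # shape [B, T]

def batch_objective(x1, x2, m, target_idx):
    """MSE between masked and full differences over all B x T elements."""
    x1, x2 = np.atleast_2d(x1), np.atleast_2d(x2)   # [N_in] -> [1, N_in]
    if x1.shape[0] != x2.shape[0]:
        raise ValueError("x1 and x2 must have the same batch size")
    idx = np.atleast_1d(target_idx)                 # int -> [int]
    full = np.ones(len(edges))                      # all-ones mask
    d_masked = (f(x1, m) - f(x2, m))[:, idx]
    d_full = (f(x1, full) - f(x2, full))[:, idx]
    return np.mean((d_masked - d_full) ** 2)
```

Dropping the edge that feeds only output 1 leaves the objective at zero for `target_idx=0`, but not for `target_idx=[0, 1]`, which is how one shared mask is judged against every pair and target at once.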

Parameters:
  • x1 (torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in] for batch)) – First input feature tensor. Must have the same batch size as x2. When B > 1, a single mask is learned that preserves the prediction difference across all pairs simultaneously.

  • x2 (torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in] for batch)) – Second input feature tensor, paired row-wise with x1. Must have the same batch size as x1.

  • target_idx (int or list[int]) – Output dimension(s) to explain. If a list is provided the attributions refer to the sum of those outputs.

  • return_weights (bool, optional (default=False)) – Whether to return raw weights along with the DataFrame.

  • target (str, optional (default='edge')) – Whether to return ‘edge’ or ‘node’ level attributions.

Returns:

If target='edge': columns ['source', 'target', 'score'] for edge attributions. If target='node': columns ['node', 'score'] for node attributions.

Return type:

pd.DataFrame
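The documented DataFrame layout can be reproduced with pandas as below. `scores_to_frame` is a hypothetical helper; it assumes edges are given as a (sources, targets) pair of index sequences and that node indices map to names through a list, which may differ from how the library stores graph metadata.

```python
import pandas as pd

def scores_to_frame(edge_index, scores, node_names, target="edge"):
    """Hypothetical: assemble scores into the documented DataFrame layout.

    edge_index: (sources, targets) sequences of node indices (edge mode only).
    scores: one score per edge (target='edge') or per node (target='node').
    node_names: list mapping node index -> name.
    """
    if target == "edge":
        src, dst = edge_index
        return pd.DataFrame({
            "source": [node_names[i] for i in src],
            "target": [node_names[j] for j in dst],
            "score": list(scores),
        })
    if target == "node":
        return pd.DataFrame({"node": list(node_names), "score": list(scores)})
    raise ValueError("target must be 'edge' or 'node'")
```

Ranking then works exactly as in the class example, e.g. `df.sort_values('score', ascending=False).head()`.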

tune(x1: torch.Tensor, x2: torch.Tensor, target_idx: Union[int, List[int]] = None, min_fidelity: float = 0.9, beta_step: float = 1.5, max_trials: int = 20, verbose: bool = True, target: str = 'edge', **explain_kwargs)

Tune the beta parameter to find the sparsest explanation that still meets the required fidelity.

For contrastive explanations, fidelity is measured as how well the subset preserves the prediction difference |f(x1) - f(x2)| across all pairs.
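A sketch of the fidelity measure referenced by min_fidelity (1 − mean relative error), assuming the relative error is taken element-wise against the full-graph difference Δf(1) with a small epsilon to avoid division by zero; `contrastive_fidelity` and `eps` are hypothetical.

```python
import numpy as np

def contrastive_fidelity(delta_masked, delta_full, eps=1e-8):
    """Hypothetical: 1 - mean relative error between Δf(m) and Δf(1)."""
    delta_masked = np.asarray(delta_masked, dtype=float)
    delta_full = np.asarray(delta_full, dtype=float)
    # Element-wise relative error over all pairs and targets.
    rel_err = np.abs(delta_masked - delta_full) / (np.abs(delta_full) + eps)
    return 1.0 - float(np.mean(rel_err))
```

A mask that reproduces every difference exactly scores 1.0; large errors can push the value below zero.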

Parameters:
  • x1 (torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in])) – First input of each pair to explain. When B > 1, a single mask is learned that works well across all pairs simultaneously.

  • x2 (torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in])) – Second input of each pair, paired row-wise with x1.

  • target_idx (int or list[int], optional) – Target output indices to explain.

  • min_fidelity (float, optional (default=0.9)) – Minimum fidelity threshold (1 - mean_relative_error) to maintain.

  • beta_step (float, optional (default=1.5)) – Multiplicative step size for beta adjustment.

  • max_trials (int, optional (default=20)) – Maximum number of beta adjustments to try.

  • verbose (bool, optional (default=True)) – Whether to print search progress.

  • target (str, optional (default='edge')) – Whether to tune for ‘edge’ or ‘node’ level attributions.

  • **explain_kwargs (dict, optional) – Override any explainer parameters during tuning.

Returns:

Results containing optimal beta, achieved fidelity, number of elements, and final DataFrame.

Return type:

dict
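One plausible multiplicative search over beta is sketched below; the library's actual strategy may differ. `tune_beta` and its `run(beta)` callback (returning fidelity and the number of retained elements) are hypothetical, and the returned dict keys are illustrative rather than the keys `tune` actually uses.

```python
def tune_beta(run, beta=1.0, min_fidelity=0.9, beta_step=1.5, max_trials=20):
    """Hypothetical search for the sparsest mask that meets min_fidelity.

    run(beta) -> (fidelity, n_elements) fits a mask with the given beta
    and reports its fidelity and size.
    """
    best = None
    for _ in range(max_trials):
        fidelity, n_elements = run(beta)
        if fidelity >= min_fidelity:
            # Passing mask: keep it if sparser, then push sparsity harder.
            if best is None or n_elements < best["n_elements"]:
                best = {"beta": beta, "fidelity": fidelity,
                        "n_elements": n_elements}
            beta *= beta_step
        elif best is not None:
            break  # A passing beta exists; larger beta only lowers fidelity.
        else:
            beta /= beta_step  # Too sparse from the start: relax the penalty.
    return best  # None if no trial reached min_fidelity
```

Increasing beta strengthens the sparsity penalty, so the loop keeps multiplying it by beta_step while fidelity holds and only backs off when the starting value already fails.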