gsnn.interpret.ContrastiveGSNNExplainer

Functions

normalize_model_kwargs(model_kwargs)

Return a fresh dict (possibly empty) suitable for **kwargs splatting.
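A minimal sketch of the behaviour this one-line description suggests. It assumes the helper maps None to an empty dict and otherwise returns a shallow copy. The name normalize_model_kwargs_sketch is hypothetical, and the real helper may do more (for example, key validation).

```python
from typing import Any, Mapping, Optional


def normalize_model_kwargs_sketch(model_kwargs: Optional[Mapping[str, Any]]) -> dict:
    """Hypothetical stand-in: None -> {}, any mapping -> a new shallow-copied dict."""
    if model_kwargs is None:
        return {}
    # dict(...) builds a new top-level dict, so keys added later do not leak
    # back into the caller's object.
    return dict(model_kwargs)


user_kwargs = {"x_fn": [0.1, 0.2]}
fresh = normalize_model_kwargs_sketch(user_kwargs)
fresh["extra"] = 1                             # modifies only the copy
print(normalize_model_kwargs_sketch(None))     # {}
print("extra" in user_kwargs)                  # False
print(fresh)                                   # {'x_fn': [0.1, 0.2], 'extra': 1}
```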

Classes

ContrastiveGSNNExplainer(model, data[, ...])

Edge/node mask optimiser for contrastive explanations.

class gsnn.interpret.ContrastiveGSNNExplainer.ContrastiveGSNNExplainer(model, data, ignore_cuda: bool = False, gumbel_softmax: bool = True, hard: bool = False, tau0: float = 3.0, min_tau: float = 0.5, prior: float = 1.0, iters: int = 250, lr: float = 0.01, weight_decay: float = 1e-05, free_edges: int = 0, beta: float = 1.0, verbose: bool = True, optimizer=torch.optim.Adam, entropy: float = 0.0, scale_mse_by_variance: bool = True)

Bases: object

Edge/node mask optimiser for contrastive explanations.

This explainer learns a binary mask m ∈ {0,1}^E over edges (target='edge') or m ∈ {0,1}^N over nodes (target='node'). The mask is chosen so that the absolute prediction difference on the masked graph stays as close as possible to the absolute difference on the full graph, while its size is penalised:

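|Δf|(m) = | f(x₁; m)[target_idx] − f(x₂; m)[target_idx] |   (element-wise, shape B×T)

L = MSE(|Δf|(m), |Δf|(1))          # averaged over all B×T elements
    + β · max(0, ‖m‖₁ − free_edges)
    − λ · H(m)                       # optional entropy term

where B is the number of (x₁, x₂) pairs, T is the number of outputs selected by target_idx, 1 denotes the all-ones (full) mask, β = beta and λ = entropy.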
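The objective can be sketched in NumPy for a single optimisation step. This is an illustration under assumptions, not the library's code: H(m) is taken as the mean Bernoulli entropy of the relaxed mask values, the sparsity term is applied to the relaxed mask directly, and the free budget is the constructor argument free_edges. The function name contrastive_objective is hypothetical, and variance scaling of the MSE (scale_mse_by_variance) is left out.

```python
import numpy as np


def contrastive_objective(diff_masked, diff_full, mask, beta=1.0,
                          free_edges=0, entropy=0.0, eps=1e-12):
    """Sketch of L = MSE + beta * hinge(||m||_1 - free_edges) - entropy * H(m).

    diff_masked, diff_full: arrays of shape (B, T) holding |Δf|(m) and |Δf|(1).
    mask: relaxed mask values in [0, 1], one per edge (or node).
    """
    # Fidelity: plain MSE averaged over all B*T elements.
    fidelity = float(np.mean((diff_masked - diff_full) ** 2))
    # Hinge sparsity penalty: only mask mass above the free budget is charged.
    sparsity = beta * max(0.0, float(np.sum(np.abs(mask))) - free_edges)
    # Mean Bernoulli entropy of the mask values (assumed definition of H(m)).
    p = np.clip(mask, eps, 1.0 - eps)
    h = float(np.mean(-(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))))
    return fidelity + sparsity - entropy * h


diff_full = np.array([[0.8, 0.2]])      # |Δf|(1): one pair, two outputs
diff_masked = np.array([[0.6, 0.2]])    # |Δf|(m) under the candidate mask
mask = np.array([1.0, 1.0, 0.0, 0.0])   # 2 of 4 elements kept
print(contrastive_objective(diff_masked, diff_full, mask, free_edges=0))  # ~2.02
print(contrastive_objective(diff_masked, diff_full, mask, free_edges=2))  # ~0.02
```

Raising free_edges to the number of kept elements removes the sparsity charge entirely, leaving only the fidelity error.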

Here m is obtained through a differentiable Gumbel-Softmax relaxation (or a plain softmax when gumbel_softmax=False), so the objective can be optimised with standard backpropagation. After convergence, the importance score of element i is its keep probability p_i = P(m_i = 1).

  • score_i ≈ 1: element i is essential for reproducing the magnitude of the prediction difference |Δf|.

  • score_i ≈ 0: element i can be removed without affecting |Δf|.
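The relaxation can be sketched in NumPy (forward pass only, no gradients). All of the following are assumptions of the sketch: each element has two logits with column 1 meaning "keep", the temperature decays exponentially from tau0 to min_tau over iters steps, and hard=True is shown as the forward value of the straight-through estimator. The function names tau_schedule, relaxed_mask and keep_probability are illustrative.

```python
import numpy as np


def tau_schedule(step, iters, tau0=3.0, min_tau=0.5):
    """Exponential decay from tau0 that reaches min_tau at the last step (assumed form)."""
    rate = (min_tau / tau0) ** (1.0 / max(iters - 1, 1))
    return max(min_tau, tau0 * rate ** step)


def relaxed_mask(logits, tau, rng, hard=False):
    """Sample a relaxed binary mask from two-class logits of shape (K, 2).

    Column 1 is treated as "keep" and column 0 as "drop"; this column
    convention is an assumption of the sketch.
    """
    # Gumbel(0, 1) noise via -log(-log(U)); U is kept away from 0.
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    y = (logits + gumbel) / tau
    y = y - y.max(axis=1, keepdims=True)           # numerical stability
    soft = np.exp(y) / np.exp(y).sum(axis=1, keepdims=True)
    keep = soft[:, 1]
    if hard:
        # Forward value of the straight-through estimator. In an autograd
        # framework one would return hard + (soft - soft.detach()) so that
        # gradients still flow through the soft sample.
        keep = (keep > 0.5).astype(float)
    return keep


def keep_probability(logits):
    """Noise-free importance score p_i = P(m_i = 1) from the learned logits."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e[:, 1] / e.sum(axis=1)


rng = np.random.default_rng(0)
logits = np.array([[0.0, 3.0],    # element strongly favoured to be kept
                   [3.0, 0.0]])   # element strongly favoured to be dropped
tau = tau_schedule(step=0, iters=250)
print(relaxed_mask(logits, tau, rng))             # stochastic values in [0, 1]
print(relaxed_mask(logits, tau, rng, hard=True))  # values in {0.0, 1.0}
print(keep_probability(logits))                   # approx. [0.953, 0.047]
```

Lower temperatures push the soft samples towards 0/1, which is why the temperature is annealed rather than fixed.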

Note

The objective targets the magnitude |Δf| rather than the signed difference Δf, so a mask that flips the sign of the prediction difference while preserving its magnitude is considered faithful. In practice this is rarely an issue: optimisation starts from the full mask, where Δf matches the full-graph value by construction, and sparsifies progressively. It is still worth keeping in mind for pathological models or an aggressively large beta.
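A worked example of this caveat with made-up numbers for one pair and one target output: the mask flips the sign of Δf but keeps its magnitude, so the magnitude-based fidelity term is zero. A signed squared error, shown only for comparison (it is not what this explainer optimises), would be 2.56.

```python
# Made-up predictions for a single pair and a single target output.
f1_full, f2_full = 1.3, 0.5        # f(x1; 1), f(x2; 1)
f1_mask, f2_mask = 0.4, 1.2        # f(x1; m), f(x2; m) under some mask m

signed_full = f1_full - f2_full    # Δf(1) = +0.8
signed_mask = f1_mask - f2_mask    # Δf(m) = -0.8: sign flipped

fidelity_loss = (abs(signed_mask) - abs(signed_full)) ** 2   # magnitude-based
signed_loss = (signed_mask - signed_full) ** 2               # comparison only
print(round(fidelity_loss, 6), round(signed_loss, 6))        # 0.0 2.56
```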

Parameters:
  • model (torch.nn.Module) – Trained GSNN model (its parameters are frozen during explanation).

  • data (torch_geometric.data.Data) – Graph data object (only metadata are used).

  • ignore_cuda (bool, optional (default=False)) – Force CPU even if CUDA is available.

  • gumbel_softmax (bool, optional (default=True)) – Use the Gumbel-Softmax re-parameterisation; otherwise plain Softmax.

  • hard (bool, optional (default=False)) – Use the straight-through estimator: the forward pass uses a discrete {0,1} mask while gradients flow through the continuous relaxation.

  • tau0 (float, optional (default=3.0)) – Initial Gumbel-Softmax temperature.

  • min_tau (float, optional (default=0.5)) – Minimum temperature reached after exponential decay.

  • prior (float, optional (default=1.0)) – Initial bias added to the positive/negative logits.

  • iters (int, optional (default=250)) – Number of optimisation steps.

  • lr (float, optional (default=1e-2)) – Learning rate for the optimiser.

  • weight_decay (float, optional (default=1e-5)) – Weight decay applied to the mask logits.

  • free_edges (int, optional (default=0)) – Number of active mask elements (edges, or nodes when target='node') allowed before the sparsity penalty applies.

  • beta (float, optional (default=1.0)) – Coefficient of the sparsity term.

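  • entropy (float, optional (default=0.0)) – Weight λ of the entropy term. Because the term is subtracted from the loss, a positive value rewards high-entropy masks and so encourages exploration.

  • optimizer (optional (default=torch.optim.Adam)) – Optimiser class used to update the mask logits.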

  • verbose (bool, optional (default=True)) – Print progress information during optimisation.

  • scale_mse_by_variance (bool, optional (default=True)) – If True, divide the MSE term by Var(target_diffs), the variance of the full-graph differences |Δf|(1), so that the fidelity loss is scale-invariant across pairs (a 1 − R²-style objective). This makes beta comparable across batches whose Δf magnitudes differ. Falls back to plain MSE when the target difference tensor has fewer than 2 elements.
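Dividing the MSE by the population variance of the full-graph differences gives SS_res / SS_tot, i.e. 1 − R² with |Δf|(1) as the reference. The NumPy sketch below checks this identity and the scale invariance on made-up values. It assumes the population variance (ddof=0) and adds a small eps against division by zero; torch.var defaults to the unbiased estimator, which would rescale the ratio by the constant (n − 1)/n. The name fidelity_term is hypothetical.

```python
import numpy as np


def fidelity_term(diff_masked, diff_full, scale_by_variance=True, eps=1e-12):
    """MSE between |Δf|(m) and |Δf|(1), optionally divided by Var(|Δf|(1))."""
    mse = float(np.mean((diff_masked - diff_full) ** 2))
    if scale_by_variance and diff_full.size >= 2:
        return mse / (float(np.var(diff_full)) + eps)   # population variance
    return mse   # fallback: plain MSE for fewer than 2 target elements


diff_full = np.array([[0.8, 0.2], [0.5, 0.1]])      # made-up |Δf|(1), B=2, T=2
diff_masked = np.array([[0.7, 0.25], [0.45, 0.1]])  # made-up |Δf|(m)

ss_res = np.sum((diff_masked - diff_full) ** 2)
ss_tot = np.sum((diff_full - diff_full.mean()) ** 2)
one_minus_r2 = ss_res / ss_tot

print(np.isclose(fidelity_term(diff_masked, diff_full), one_minus_r2))   # True
# Multiplying every difference by 10 changes plain MSE by 100x but not the scaled term.
print(np.isclose(fidelity_term(10 * diff_masked, 10 * diff_full),
                 fidelity_term(diff_masked, diff_full)))                 # True
```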

Example

>>> explainer = ContrastiveGSNNExplainer(model, data, iters=400, beta=5)
>>> # Edge-level attributions
>>> edge_df = explainer.explain(x1, x2, target_idx=0, target='edge')
>>> edge_df.sort_values('score', ascending=False).head()
>>> # Node-level attributions
>>> node_df = explainer.explain(x1, x2, target_idx=0, target='node')
>>> node_df.sort_values('score', ascending=False).head()

explain(x1: torch.Tensor, x2: torch.Tensor, target_idx: Union[int, List[int]], *, return_weights: bool = False, target: str = 'edge', model_kwargs1=None, model_kwargs2=None) → pandas.DataFrame

Compute attributions for f(x₁) − f(x₂).

Initialises the mask logits and runs gradient descent to select a minimal subset of elements that preserves the prediction difference between x1 and x2.

When given multiple pairs (a batch), it learns ONE mask that works well across ALL pairs by treating the per-pair differences as a multi-output objective. This needs a single optimisation run rather than one per pair, which is much faster than per-sample optimisation.
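To make the batch behaviour concrete, here is a toy NumPy sketch in which a hypothetical linear model f(x; m) = x · (W ⊙ m) has one gated weight per input-to-output edge. The full-graph target |Δf|(1) has shape (B, T), and a single shared mask is scored by one scalar averaged over all pairs and outputs. The model, the names toy_model and batch_fidelity, and the random data are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N_in, T = 4, 3, 2                    # pairs, input features, target outputs
W = rng.normal(size=(N_in, T))          # toy weights, one per input->output edge


def toy_model(x, edge_mask):
    """Hypothetical linear 'GSNN': each weight is gated by its edge mask value."""
    return x @ (W * edge_mask)


x1 = rng.normal(size=(B, N_in))
x2 = rng.normal(size=(B, N_in))
full_mask = np.ones_like(W)

# |Δf|(1): full-graph target, one row per pair and one column per output.
target = np.abs(toy_model(x1, full_mask) - toy_model(x2, full_mask))


def batch_fidelity(mask):
    """One scalar for ONE shared mask, averaged over all B x T elements."""
    diff = np.abs(toy_model(x1, mask) - toy_model(x2, mask))
    return float(np.mean((diff - target) ** 2))


print(target.shape)                          # (4, 2)
print(batch_fidelity(full_mask))             # 0.0
print(batch_fidelity(np.zeros_like(W)) > 0)  # True: removing every edge loses |Δf|
```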

Parameters:
  • x1 (torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in] for batch)) – First input feature tensor. Must have the same batch size as x2. When B > 1, a single mask is learned that preserves the prediction difference across all pairs simultaneously.

  • x2 (torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in] for batch)) – Second input feature tensor, paired row by row with x1.

  • target_idx (int or list[int]) – Output dimension(s) to explain. If a list is provided the attributions refer to the sum of those outputs.

  • return_weights (bool, optional (default=False)) – Whether to return raw weights along with the DataFrame.

  • target (str, optional (default='edge')) – Whether to return ‘edge’ or ‘node’ level attributions.

  • model_kwargs1 (dict, optional (default=None)) – Keyword arguments forwarded to every self.model(x1, ...) call (e.g. {'x_fn': x_fn_1} for models trained with node_activity=True). Tensor values must have leading dim equal to x1.shape[0]. The keys edge_mask and node_mask are reserved and must not be supplied.

  • model_kwargs2 (dict, optional (default=None)) – Same as model_kwargs1, but forwarded to every self.model(x2, ...) call (e.g. {'x_fn': x_fn_2}). Tensor values must have leading dim equal to x2.shape[0]; edge_mask and node_mask are likewise reserved.
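A sketch of the input conventions described above, using NumPy arrays as stand-ins for torch tensors. The helpers as_batch and check_side are hypothetical; they only illustrate the documented rules (1-D inputs act as a batch of one, both sides share a batch size, tensor-valued kwargs match their side's leading dimension, edge_mask and node_mask are reserved). The explainer's own checks may differ.

```python
import numpy as np

RESERVED_KEYS = {"edge_mask", "node_mask"}   # reserved for the explainer's own masks


def as_batch(x):
    """Promote an [N_in] input to [1, N_in]; leave [B, N_in] unchanged."""
    x = np.asarray(x)
    if x.ndim == 1:
        return x[None, :]
    if x.ndim == 2:
        return x
    raise ValueError(f"expected shape [N_in] or [B, N_in], got {x.shape}")


def check_side(x, model_kwargs, name):
    """Hypothetical validation of one side's model kwargs against its batch size."""
    kwargs = {} if model_kwargs is None else dict(model_kwargs)
    clash = RESERVED_KEYS.intersection(kwargs)
    if clash:
        raise ValueError(f"{name} uses reserved keys {sorted(clash)}")
    for key, value in kwargs.items():
        # Only array-like ("tensor") values need a matching leading dimension.
        if isinstance(value, np.ndarray) and value.shape[0] != x.shape[0]:
            raise ValueError(
                f"{name}[{key!r}] has leading dim {value.shape[0]}, expected {x.shape[0]}"
            )
    return kwargs


x1 = as_batch(np.zeros(5))          # [N_in]    -> (1, 5)
x2 = as_batch(np.ones((1, 5)))      # [1, N_in] -> (1, 5)
if x1.shape[0] != x2.shape[0]:
    raise ValueError("x1 and x2 must have identical batch size")

kw1 = check_side(x1, {"x_fn": np.zeros((1, 3))}, "model_kwargs1")  # passes
kw2 = check_side(x2, None, "model_kwargs2")                        # -> {}
```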

Returns:

If target=’edge’: columns [‘source’, ‘target’, ‘score’] for edge attributions. If target=’node’: columns [‘node’, ‘score’] for node attributions.

Return type:

pd.DataFrame
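Since each score is a keep probability P(m_i = 1), one simple way to extract an explanatory subgraph is to keep elements whose score is at least 0.5. The DataFrame below is made-up data in the documented edge format, and the 0.5 cut-off is a choice, not something the explainer prescribes.

```python
import pandas as pd

# Made-up output in the format of explain(..., target='edge').
edge_df = pd.DataFrame({
    "source": ["A", "A", "B", "C"],
    "target": ["B", "C", "D", "D"],
    "score": [0.97, 0.08, 0.88, 0.51],
})

# Keep elements whose keep probability is at least 0.5, strongest first.
kept = edge_df[edge_df["score"] >= 0.5].sort_values("score", ascending=False)
kept_pairs = list(zip(kept["source"], kept["target"]))
print(kept_pairs)   # [('A', 'B'), ('B', 'D'), ('C', 'D')]
```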

tune(x1: torch.Tensor, x2: torch.Tensor, target_idx: Union[int, List[int]] = None, min_fidelity: float = 0.9, beta_step: float = 1.5, max_trials: int = 20, verbose: bool = True, target: str = 'edge', **explain_kwargs)

Tune beta to find the sparsest explanation that still meets the required fidelity.

For contrastive explanations, fidelity measures how well the selected subset preserves the prediction difference |f(x1) − f(x2)| across all pairs; it is defined as 1 − mean relative error (see min_fidelity).
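Using the definition given for min_fidelity (1 − mean relative error), fidelity can be sketched as below. Assumptions: the relative error is taken element-wise over the B×T differences, and a small eps guards against zero full-graph differences (how the library handles that case is not documented here). The name contrastive_fidelity is hypothetical.

```python
import numpy as np


def contrastive_fidelity(diff_subset, diff_full, eps=1e-8):
    """1 - mean relative error between |Δf| on the subset and on the full graph."""
    rel_err = np.abs(diff_subset - diff_full) / (np.abs(diff_full) + eps)
    return 1.0 - float(np.mean(rel_err))


diff_full = np.array([[0.8, 0.4]])     # |Δf| on the full graph (made up)
diff_subset = np.array([[0.72, 0.4]])  # |Δf| using only the selected elements
print(round(contrastive_fidelity(diff_subset, diff_full), 4))   # 0.95
```

Here the first output loses 10% of its magnitude and the second is preserved exactly, so the mean relative error is 0.05.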

Parameters:
  • x1 (torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in])) – First inputs of the pairs to explain. When B > 1, a single mask is learned that works well across all pairs simultaneously.

  • x2 (torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in])) – Second inputs of the pairs, matched row by row with x1.

  • target_idx (int or list[int], optional) – Target output indices to explain.

  • min_fidelity (float, optional (default=0.9)) – Minimum fidelity threshold (1 - mean_relative_error) to maintain.

  • beta_step (float, optional (default=1.5)) – Multiplicative step size for beta adjustment.

  • max_trials (int, optional (default=20)) – Maximum number of beta adjustments to try.

  • verbose (bool, optional (default=True)) – Whether to print search progress.

  • target (str, optional (default='edge')) – Whether to tune for ‘edge’ or ‘node’ level attributions.

  • **explain_kwargs (dict, optional) – Override any explainer parameters during tuning.
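The parameters describe a multiplicative search over beta but not its exact control flow. The sketch below assumes one plausible strategy: increase beta by beta_step while fidelity stays at or above min_fidelity, back off when it does not, and keep the sparsest feasible result. tune_beta_sketch is hypothetical, evaluate stands in for one explain() run returning (fidelity, number of selected elements), and toy_evaluate is made up.

```python
def tune_beta_sketch(evaluate, beta0=1.0, min_fidelity=0.9, beta_step=1.5, max_trials=20):
    """Hypothetical multiplicative search for the sparsest feasible beta."""
    beta, best = beta0, None
    for _ in range(max_trials):
        fidelity, n_elements = evaluate(beta)
        if fidelity >= min_fidelity:
            # Feasible: record it if it is the sparsest so far, then push harder.
            if best is None or n_elements < best["n_elements"]:
                best = {"beta": beta, "fidelity": fidelity, "n_elements": n_elements}
            beta *= beta_step
        elif best is not None:
            # First failure after a feasible result: stop searching.
            break
        else:
            # No feasible result yet: relax the sparsity pressure.
            beta /= beta_step
    return best


def toy_evaluate(beta):
    # Made-up trade-off: larger beta -> fewer elements and lower fidelity.
    return 1.0 - 0.02 * beta, round(100 / (1 + beta))


result = tune_beta_sketch(toy_evaluate)
print(result["beta"], result["n_elements"])   # 3.375 23
```

Stopping at the first failure assumes fidelity decreases roughly monotonically in beta; with noisy optimisation a real search may need to be more forgiving.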

Returns:

Results containing optimal beta, achieved fidelity, number of elements, and final DataFrame.

Return type:

dict