gsnn.interpret.CounterfactualExplainer

class gsnn.interpret.CounterfactualExplainer(model: torch.nn.Module, data=None, ignore_cuda: bool = False)

Bases: object

Feature-level counterfactual explainer using gradient descent.

This class learns a minimal perturbation δ to an input x such that:

f(x + δ) ≈ target_value

The perturbation is learned via gradient descent with L2 regularization to enforce minimality. The optimization objective is:

\[\min_{\delta} \; \lVert f(x + \delta) - \text{target} \rVert^2 + \lambda \lVert \delta \rVert^2\]

where λ is the weight decay coefficient (the weight_decay argument of explain()), which controls the trade-off between reaching the target and keeping the perturbation small.
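To make the trade-off concrete, the sketch below swaps the GSNN for a hypothetical linear stand-in f(x) = Wx + b. For that model the objective above has a closed-form minimizer, δ* = (WᵀW + λI)⁻¹ Wᵀ (target − f(x)), obtained by setting the gradient 2Wᵀ(f(x + δ) − target) + 2λδ to zero. The names W, b and closed_form_delta are illustrative assumptions; the explainer itself optimizes the real, nonlinear model by gradient descent and does not compute this formula.

```python
import numpy as np


def closed_form_delta(W, b, x, target, lam):
    """Minimize ||W(x + d) + b - target||^2 + lam * ||d||^2 over d (linear stand-in model only)."""
    n_in = W.shape[1]
    rhs = W.T @ (target - (W @ x + b))   # W^T (target - f(x))
    lhs = W.T @ W + lam * np.eye(n_in)   # W^T W + lam * I
    return np.linalg.solve(lhs, rhs)


W = np.array([[1.0, 2.0], [0.5, -1.0]])  # hypothetical model weights
b = np.zeros(2)
x = np.array([0.2, 0.1])
target = np.array([1.0, 0.0])

for lam in (0.0, 0.1, 1.0):
    d = closed_form_delta(W, b, x, target, lam)
    residual = np.linalg.norm(W @ (x + d) + b - target)
    # Larger lam -> smaller ||d||, but a larger remaining gap to the target.
    print(f"lam={lam:<4} ||delta||={np.linalg.norm(d):.4f} residual={residual:.4f}")
```

With λ = 0 the target is reached exactly (W is invertible here); increasing λ shrinks the perturbation at the cost of a growing residual, which is the trade-off weight_decay controls.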

The sign of each learned component δ_i indicates how feature i must change:

  • δ_i > 0: feature i needs to be increased to reach the target.

  • δ_i < 0: feature i needs to be decreased to reach the target.

  • δ_i ≈ 0: feature i is irrelevant to the counterfactual.
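Because δ_i is a real number, "≈ 0" needs a cut-off in practice. The sketch below labels each feature from its learned perturbation using a hypothetical threshold eps; both the threshold value and the helper name describe_direction are assumptions for illustration, not part of the gsnn API.

```python
def describe_direction(perturbations, names, eps=1e-3):
    """Map each feature name to 'increase', 'decrease', or 'irrelevant' (hypothetical helper)."""
    labels = {}
    for name, d in zip(names, perturbations):
        if d > eps:
            labels[name] = "increase"
        elif d < -eps:
            labels[name] = "decrease"
        else:
            labels[name] = "irrelevant"  # |delta_i| <= eps is treated as delta_i ≈ 0
    return labels


print(describe_direction([0.45, -0.23, 0.0004], ["f0", "f1", "f2"]))
```

Applied to the 'perturbation' column of the DataFrame returned by explain(), this gives the same reading as the bullets above; the right eps depends on the scale of your input features.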

Parameters:
  • model (torch.nn.Module) – Trained GSNN model (evaluation mode is enforced internally).

  • data (torch_geometric.data.Data, optional) – Graph data object; used for human-readable feature names.

  • ignore_cuda (bool, optional (default=False)) – Force the explainer to run on CPU even if CUDA is available.

Example

>>> explainer = CounterfactualExplainer(model, data)
>>> # Single observation
>>> df = explainer.explain(x, target_value=0.8, target_idx=0, max_iter=500)
>>> # Multiple observations (same perturbation applied to all)
>>> df = explainer.explain(x_batch, target_value=0.8, target_idx=0, max_iter=500)
>>> df.sort_values('perturbation', key=abs, ascending=False).head()
feature    original  perturbation  counterfactual
in0        0.12      0.45          0.57
in1        0.89     -0.23          0.66
in2        0.34      0.11          0.45
__init__(model: torch.nn.Module, data=None, ignore_cuda: bool = False) → None

Methods

__init__(model[, data, ignore_cuda])

explain(x, target_value[, target_idx, ...]) – Learn minimal perturbation to achieve target model output.

explain(x: torch.Tensor, target_value: Union[float, torch.Tensor], target_idx: Optional[Union[int, List[int]]] = None, trainable_mask: Optional[torch.Tensor] = None, lr: float = 0.01, weight_decay: float = 0.01, dropout: float = 0.0, min_iter: int = 25, max_iter: int = 1000, tolerance: float = 1e-05, verbose: bool = True, transform: Optional[Callable] = torch.nn.Identity) → pandas.DataFrame

Learn minimal perturbation to achieve target model output.

Parameters:
  • x (torch.Tensor (shape: [N_in] or [B, N_in])) – Input feature tensor. A 1D input is unsqueezed to a batch of size 1. For a batch of observations, a single perturbation is learned and applied to all of them.

  • target_value (float or torch.Tensor) – Desired model output. If target_idx is specified, this should be a scalar or tensor matching the number of target indices. If target_idx is None, this should match the full output dimension. The same target value is used for all observations in the batch.

  • target_idx (int, list[int], or None) – Output dimension(s) to target. If None, targets all outputs.

  • trainable_mask (torch.Tensor, optional (shape: [N_in])) – Boolean mask specifying which features can be perturbed. If None, all features are trainable.

  • lr (float, optional (default=0.01)) – Learning rate for gradient descent.

  • weight_decay (float, optional (default=0.01)) – L2 regularization coefficient for minimizing perturbation magnitude.

  • dropout (float, optional (default=0.0)) – Dropout rate for the model.

  • min_iter (int, optional (default=25)) – Minimum number of optimization iterations.

  • max_iter (int, optional (default=1000)) – Maximum number of optimization iterations.

  • tolerance (float, optional (default=1e-5)) – Convergence tolerance on the change in loss between consecutive iterations.

  • verbose (bool, optional (default=True)) – Print optimization progress.

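  • transform (Callable, optional (default=torch.nn.Identity)) – Differentiable function applied to the perturbation, e.g. torch.relu (restricts it to non-negative values) or torch.tanh (bounds it to (−1, 1)).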
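The parameters above interact roughly as sketched below. This is a minimal sketch under stated assumptions, not the library's implementation: it replaces the GSNN with a hypothetical linear model f(x) = Wx + b (so the gradient can be written by hand), uses plain gradient descent with an explicit λ‖δ‖² penalty instead of whatever optimizer explain() uses internally, averages the loss over the batch, and omits dropout and transform. It does mirror the documented behavior: one perturbation shared by all observations, only the target_idx outputs in the loss, features outside trainable_mask held fixed, and stopping once the loss change falls below tolerance after at least min_iter iterations.

```python
import numpy as np


def learn_perturbation(W, b, X, target_value, target_idx=None, trainable_mask=None,
                       lr=0.01, weight_decay=0.01, min_iter=25, max_iter=1000,
                       tolerance=1e-5):
    """Hypothetical sketch of explain()'s loop for a linear model f(x) = W @ x + b."""
    X = np.atleast_2d(np.asarray(X, dtype=float))     # [N_in] -> [1, N_in]
    n_in = X.shape[1]
    # Keep only the targeted output rows; None means all outputs.
    idx = np.arange(W.shape[0]) if target_idx is None else np.atleast_1d(target_idx)
    W_t, b_t = W[idx], b[idx]
    target = np.broadcast_to(np.asarray(target_value, dtype=float), idx.shape)
    mask = (np.ones(n_in, dtype=bool) if trainable_mask is None
            else np.asarray(trainable_mask, dtype=bool))

    delta = np.zeros(n_in)                             # one delta shared by every row of X
    prev_loss = np.inf
    for it in range(max_iter):
        resid = (X + delta) @ W_t.T + b_t - target     # [B, n_targets]
        loss = np.mean(np.sum(resid ** 2, axis=1)) + weight_decay * np.sum(delta ** 2)
        # Gradient of the batch-mean squared error plus the L2 penalty.
        grad = 2.0 * W_t.T @ resid.mean(axis=0) + 2.0 * weight_decay * delta
        delta -= lr * np.where(mask, grad, 0.0)        # masked-out features never move
        if it + 1 >= min_iter and abs(prev_loss - loss) < tolerance:
            break
        prev_loss = loss
    return delta


W = np.eye(3)
b = np.zeros(3)
x = np.array([0.1, 0.2, 0.3])
delta = learn_perturbation(W, b, x, target_value=1.0, target_idx=0,
                           trainable_mask=[True, False, True], lr=0.1, tolerance=1e-12)
# delta[0] approaches (1 - 0.1) / (1 + weight_decay); delta[1] is masked and delta[2]
# is untargeted, so both stay at zero.
print(delta)
```

In this toy setting, the penalty pulls every untargeted feature toward zero, which is why δ_i ≈ 0 reads as "not needed for the counterfactual".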

Returns:

DataFrame with one row per input feature and columns ‘feature’, ‘original’, ‘perturbation’, and ‘counterfactual’, showing the learned perturbation for each feature.

Return type:

pd.DataFrame