gsnn.gsnn.interpret.CounterfactualExplainer
- class gsnn.gsnn.interpret.CounterfactualExplainer(model: torch.nn.Module, data=None, ignore_cuda: bool = False)
Bases: object
Feature-level counterfactual explainer using gradient descent.
This module learns a minimal perturbation δ to an input x such that:
f(x + δ) ≈ target_value
The perturbation is learned via gradient descent with L2 regularization to enforce minimality. The optimization objective is:
\[\min_{\delta} \|f(x + \delta) - \text{target}\|^2 + \lambda\|\delta\|^2\]
where λ is the weight decay parameter controlling the trade-off between achieving the target and minimizing the perturbation.
δ_i > 0: feature i needs to be increased to reach the target.
δ_i < 0: feature i needs to be decreased to reach the target.
δ_i ≈ 0: feature i is irrelevant for the counterfactual.
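The objective above can be illustrated with a minimal PyTorch sketch. This is not the library's implementation: the linear `model`, its weights, the learning rate, and the iteration count are all hypothetical stand-ins chosen so the example runs on its own. It only shows how a perturbation δ might be fitted by gradient descent under an L2 penalty.

```python
import torch

# Hypothetical stand-in for a trained model: a fixed linear map (3 inputs, 1 output).
model = torch.nn.Linear(3, 1)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[0.5, -0.3, 0.2]]))
    model.bias.fill_(0.1)
model.eval()
for p in model.parameters():
    p.requires_grad_(False)  # only the perturbation is optimized

x = torch.tensor([[0.12, 0.89, 0.34]])  # one observation, shape [1, N_in]
target = torch.tensor([[0.8]])          # desired model output
lam = 0.01                              # λ in the objective (illustrative value)

delta = torch.zeros(1, 3, requires_grad=True)
opt = torch.optim.SGD([delta], lr=0.05)

for _ in range(500):
    opt.zero_grad()
    fit = ((model(x + delta) - target) ** 2).sum()  # ||f(x + δ) - target||^2
    penalty = lam * (delta ** 2).sum()              # λ ||δ||^2
    (fit + penalty).backward()
    opt.step()

with torch.no_grad():
    initial_output = model(x).item()
    final_output = model(x + delta).item()
```

In this toy setting the learned δ is positive for features with positive weight and negative for the feature with negative weight, which matches the sign interpretation listed above. Because of the penalty term, the final output approaches the target without reaching it exactly.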
- Parameters:
model (torch.nn.Module) – Trained GSNN model (evaluation mode is enforced internally).
data (torch_geometric.data.Data, optional) – Graph data object; used for human-readable feature names.
ignore_cuda (bool, optional (default=False)) – Force the explainer to run on CPU even if CUDA is available.
Example
>>> explainer = CounterfactualExplainer(model, data)
>>> # Single observation
>>> df = explainer.explain(x, target_value=0.8, target_idx=0, max_iter=500)
>>> # Multiple observations (same perturbation applied to all)
>>> df = explainer.explain(x_batch, target_value=0.8, target_idx=0, max_iter=500)
>>> df.sort_values('perturbation', key=abs, ascending=False).head()
feature  original  perturbation  counterfactual
in0          0.12          0.45            0.57
in1          0.89         -0.23            0.66
in2          0.34          0.11            0.45
- __init__(model: torch.nn.Module, data=None, ignore_cuda: bool = False) → None
Methods
__init__(model[, data, ignore_cuda])
explain(x, target_value[, target_idx, ...]) – Learn minimal perturbation to achieve target model output.
- explain(x: torch.Tensor, target_value: Union[float, torch.Tensor], target_idx: Optional[Union[int, List[int]]] = None, trainable_mask: Optional[torch.Tensor] = None, lr: float = 0.01, weight_decay: float = 0.01, dropout: float = 0.0, min_iter: int = 25, max_iter: int = 1000, tolerance: float = 1e-05, verbose: bool = True, transform: Optional[Callable] = torch.nn.Identity) → pandas.DataFrame
Learn minimal perturbation to achieve target model output.
- Parameters:
x (torch.Tensor (shape: [N_in] or [B, N_in])) – Input feature tensor. If 1D, it will be unsqueezed to batch size 1. For multiple observations, the same perturbation will be applied to all.
target_value (float or torch.Tensor) – Desired model output. If target_idx is specified, this should be a scalar or tensor matching the number of target indices. If target_idx is None, this should match the full output dimension. The same target value is used for all observations in the batch.
target_idx (int, list[int], or None) – Output dimension(s) to target. If None, targets all outputs.
trainable_mask (torch.Tensor, optional (shape: [N_in])) – Boolean mask specifying which features can be perturbed. If None, all features are trainable.
lr (float, optional (default=0.01)) – Learning rate for gradient descent.
weight_decay (float, optional (default=0.01)) – L2 regularization coefficient for minimizing perturbation magnitude.
dropout (float, optional (default=0.0)) – Dropout rate for the model.
min_iter (int, optional (default=25)) – Minimum number of optimization iterations.
max_iter (int, optional (default=1000)) – Maximum number of optimization iterations.
tolerance (float, optional (default=1e-5)) – Convergence tolerance for loss change between iterations.
verbose (bool, optional (default=True)) – Print optimization progress.
transform (Callable, optional (default=torch.nn.Identity)) – Function applied to the perturbation; it must be differentiable. Examples: relu(), tanh().
- Returns:
DataFrame with columns ‘feature’, ‘original’, ‘perturbation’, ‘counterfactual’ showing the learned perturbations for each input feature.
- Return type:
pd.DataFrame
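The sketch below shows one way the trainable_mask, transform, min_iter, max_iter, and tolerance parameters could interact. It is an assumption-laden illustration, not the library's code: the linear `model`, the helper `perturbation`, the order in which the transform and mask are applied, and the exact stopping rule are all hypothetical.

```python
import torch

# Hypothetical stand-in for a trained model with fixed weights.
model = torch.nn.Linear(3, 1)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[0.5, -0.3, 0.2]]))
    model.bias.fill_(0.1)
model.eval()
for p in model.parameters():
    p.requires_grad_(False)

x = torch.tensor([[0.12, 0.89, 0.34]])
target = torch.tensor([[0.8]])
mask = torch.tensor([True, False, True])  # feature 1 may not be perturbed
transform = torch.relu                    # restricts the perturbation to increases

def perturbation(raw):
    # Assumed composition: apply the transform, then zero out frozen features.
    return transform(raw) * mask

# Start slightly positive so relu passes a gradient on the first step.
raw = torch.full((1, 3), 0.01, requires_grad=True)
opt = torch.optim.SGD([raw], lr=0.05)
min_iter, max_iter, tolerance = 25, 1000, 1e-5

prev_loss = float("inf")
for step in range(max_iter):
    opt.zero_grad()
    delta = perturbation(raw)
    loss = ((model(x + delta) - target) ** 2).sum() + 0.01 * (delta ** 2).sum()
    loss.backward()
    opt.step()
    # Assumed stopping rule: stop once the loss change falls below the
    # tolerance, but never before min_iter iterations have run.
    if step + 1 >= min_iter and abs(prev_loss - loss.item()) < tolerance:
        break
    prev_loss = loss.item()

final_delta = perturbation(raw).detach()
with torch.no_grad():
    final_output = model(x + final_delta).item()
```

Here the masked feature keeps a perturbation of exactly zero and the relu transform keeps every remaining entry non-negative, so the counterfactual is reached only by increasing the two trainable features.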