import copy
from typing import Union, List, Optional
import numpy as np
import pandas as pd
import torch
from gsnn.interpret._kwargs_utils import (
concat_pair,
normalize_model_kwargs,
repeat_batch,
slice_per_sample,
)
class ContrastiveIGExplainer:
    r"""Integrated-Gradients explainer for *contrastive* questions.

    This class attributes the prediction **difference**::

        Δf = f(x1)[target_idx] - f(x2)[target_idx]

    to every edge (or node) of the graph by integrating along a straight-line
    **mask path** *m(α)=α·1, α∈[0,1]* while keeping the two inputs ``x1`` and
    ``x2`` fixed.  The attribution for an element *e* equals

    .. math::

        \mathrm{IG}_e = \int_0^1 \frac{\partial}{\partial m_e}
                        |f(x_1;m(α)) - f(x_2;m(α))|\,dα.

    Interpretation
    --------------
    * ``IG_e > 0``  the presence of element *e* increases |Δf|.
    * ``IG_e < 0``  the presence of element *e* decreases |Δf|.
    * ``IG_e ≈ 0``  element *e* is irrelevant to the gap.

    Up to the discretisation error of the trapezoidal rule, the attributions
    sum to the gap at the full mask minus the gap at the all-zero mask, i.e.
    :math:`\sum_e \mathrm{IG}_e = |Δf|` (completeness) whenever the two
    predictions coincide under the all-zero mask.

    Parameters
    ----------
    model : torch.nn.Module
        Trained GSNN model.  A deep copy is taken, put in evaluation mode,
        moved to the working device and frozen, so the caller's model is not
        modified.  The copy must expose ``edge_index`` and ``homo_names``
        (and ``num_nodes`` for node-level attributions).
    data : torch_geometric.data.Data
        Graph data object.  It is stored on the instance as ``self.data``;
        element names in the output come from ``model.homo_names``.
    n_steps : int, optional (default=50)
        Number of integration intervals along the mask path; the path is
        evaluated at ``n_steps + 1`` points (baseline included).  Must be
        at least 1.
    ignore_cuda : bool, optional (default=False)
        Force the explainer to run on CPU even if CUDA is available.

    Raises
    ------
    ValueError
        If ``n_steps`` is smaller than 1.

    Example
    -------
    >>> explainer = ContrastiveIGExplainer(model, data, n_steps=64)
    >>> df = explainer.explain(x1, x2, target_idx=0)
    >>> df.sort_values('score', ascending=False).head()
      source target  score
         in0  func0   0.42
       func0  func3   0.40
       func3   out0   0.38

    >>> # Compute IG for only a subset of edges
    >>> edge_mask = np.array([True, False, True, False, True])  # Only integrate edges 0, 2, 4
    >>> df = explainer.explain(x1, x2, target_idx=0, element_mask=edge_mask)
    >>> # Edges 1 and 3 will have None scores; edges 0, 2, 4 have IG attributions
    >>> # Note: Completeness axiom won't hold when using element_mask
    """

    def __init__(
        self,
        model: torch.nn.Module,
        data,
        n_steps: int = 50,
        ignore_cuda: bool = False,
    ) -> None:
        # With fewer than one interval the trapezoidal average below would be
        # taken over an empty set and every score would silently become NaN.
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        self.data = data
        self.n_steps = n_steps
        self.device = (
            "cuda" if (torch.cuda.is_available() and not ignore_cuda) else "cpu"
        )

        # Work on a frozen copy of the model to avoid side-effects.
        model = copy.deepcopy(model).eval().to(self.device)
        for p in model.parameters():
            p.requires_grad = False
        self.model = model

        # Number of edges, i.e. the width of an edge mask.
        self.E = model.edge_index.size(1)
        # Constant all-ones edge mask.  Not read by the methods of this class;
        # kept so that external code relying on the attribute keeps working.
        self._edge_mask = torch.ones(
            (1, self.E), device=self.device, requires_grad=True
        )

    def explain(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        *,
        jitter: Optional[torch.Tensor] = None,
        element_mask=None,
        target: str = 'edge',
        reduction: str = 'mean',
        model_kwargs1=None,
        model_kwargs2=None,
    ) -> pd.DataFrame:
        """Compute attributions for *f(x₁) − f(x₂)*.

        Parameters
        ----------
        x1, x2 : torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in] for batch)
            Two input feature tensors.  They must have identical shape and
            ordering of nodes.  Each pair (x1[i], x2[i]) is explained.
        target_idx : int or list[int]
            Output dimension(s) to explain.  If a list is provided the
            attributions refer to the **sum** of those outputs.
        jitter : torch.Tensor, optional
            Optional noise to perturb the mask path.  The path end point
            becomes ``clamp(1 + jitter, 0, 1)``.
        element_mask : torch.Tensor, np.ndarray or sequence of bool, optional
            One-dimensional mask of length E (edges) or N (nodes) indicating
            which elements to compute IG attributions for.  If None, all
            elements are integrated.  If provided:
            - True/nonzero elements: integrate from 0 to 1 (normal IG path)
            - False/zero elements: fixed at 1 throughout the path (no integration)
            Elements not in the mask will have None scores in the output.

            Note: When using element_mask, the completeness axiom (attributions sum
            to |Δf|) will not hold since only a subset of elements are integrated.
            The attributions measure "contribution to |Δf| while holding other
            elements fixed at full strength".
        target : str, optional (default='edge')
            Whether to return 'edge' or 'node' level attributions.
        reduction : str, optional (default='mean')
            How to aggregate attributions across batch samples:
            - 'mean': average attributions across samples (default)
            - 'sum': sum attributions across samples
            - 'none': return all per-sample attributions (adds 'sample_idx' column)
        model_kwargs1, model_kwargs2 : optional
            Extra keyword arguments for the ``x1`` side and the ``x2`` side of
            the model call.  For every sample they are sliced, repeated along
            the path and concatenated by the ``_kwargs_utils`` helpers before
            being forwarded to the model.

        Returns
        -------
        pd.DataFrame
            If target='edge': columns ['source', 'target', 'score'] for edge attributions.
            If target='node': columns ['node', 'score'] for node attributions.
            If reduction='none': additional 'sample_idx' column for batch dimension.
            Elements not in element_mask will have None scores.

        Raises
        ------
        ValueError
            If ``target`` or ``reduction`` is not one of the accepted values,
            if ``x1`` and ``x2`` differ in shape or are empty, or if
            ``element_mask`` does not have one entry per element.
        """
        if target not in ['edge', 'node']:
            raise ValueError(f"target must be 'edge' or 'node', got '{target}'")
        if reduction not in ['mean', 'sum', 'none']:
            raise ValueError(f"reduction must be 'mean', 'sum', or 'none', got '{reduction}'")

        compute = (
            self._compute_edge_attributions
            if target == 'edge'
            else self._compute_node_attributions
        )
        return compute(
            x1, x2, target_idx, jitter, element_mask, reduction,
            model_kwargs1=model_kwargs1, model_kwargs2=model_kwargs2,
        )

    def _compute_edge_attributions(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        jitter: Optional[torch.Tensor] = None,
        element_mask=None,
        reduction: str = 'mean',
        model_kwargs1=None,
        model_kwargs2=None,
    ) -> pd.DataFrame:
        """Compute edge-level attributions for *f(x₁) − f(x₂)*.

        The mask is passed to the model as ``edge_mask`` and has one column
        per edge.  The result has columns ['source', 'target', 'score']
        (plus 'sample_idx' when ``reduction='none'``).
        """
        all_ig, element_mask = self._integrate_mask_path(
            x1, x2, target_idx, jitter, element_mask,
            model_kwargs1, model_kwargs2,
            num_elements=self.E, mask_kwarg="edge_mask",
        )
        # Translate the two rows of edge_index into node names.
        src, dst = np.array(self.model.homo_names)[
            self.model.edge_index.detach().cpu().numpy()
        ]
        return self._package_scores(
            all_ig, element_mask, reduction, {"source": src, "target": dst}
        )

    def _compute_node_attributions(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        jitter: Optional[torch.Tensor] = None,
        element_mask=None,
        reduction: str = 'mean',
        model_kwargs1=None,
        model_kwargs2=None,
    ) -> pd.DataFrame:
        """Compute node-level attributions for *f(x₁) − f(x₂)*.

        The mask is passed to the model as ``node_mask`` and has one column
        per node (``model.num_nodes``).  The result has columns
        ['node', 'score'] (plus 'sample_idx' when ``reduction='none'``).
        """
        all_ig, element_mask = self._integrate_mask_path(
            x1, x2, target_idx, jitter, element_mask,
            model_kwargs1, model_kwargs2,
            num_elements=self.model.num_nodes, mask_kwarg="node_mask",
        )
        node_names = np.array(self.model.homo_names)
        return self._package_scores(
            all_ig, element_mask, reduction, {"node": node_names}
        )

    def _integrate_mask_path(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        jitter: Optional[torch.Tensor],
        element_mask,
        model_kwargs1,
        model_kwargs2,
        *,
        num_elements: int,
        mask_kwarg: str,
    ):
        """Integrate the gradient of |f(x1) - f(x2)| along the mask path.

        Shared implementation of the edge- and node-level explainers.

        Parameters
        ----------
        num_elements : int
            Width of the mask (number of edges or nodes).
        mask_kwarg : str
            Name of the model keyword argument that receives the mask
            (``"edge_mask"`` or ``"node_mask"``).

        Returns
        -------
        (torch.Tensor, torch.Tensor or None)
            Per-sample attributions of shape (B, num_elements), with NaN in
            the columns excluded by ``element_mask``, and the boolean
            ``element_mask`` on the working device (None if none was given).
        """
        mk1 = normalize_model_kwargs(model_kwargs1)
        mk2 = normalize_model_kwargs(model_kwargs2)

        x1, x2 = x1.to(self.device), x2.to(self.device)

        # Ensure batch dimension
        if x1.dim() == 1:
            x1 = x1.unsqueeze(0)
        if x2.dim() == 1:
            x2 = x2.unsqueeze(0)
        # Samples are paired by position, so a shape mismatch would otherwise
        # pair x1[i] with an empty slice of x2 and fail far from the cause.
        if x1.shape != x2.shape:
            raise ValueError(
                f"x1 and x2 must have the same shape, got "
                f"{tuple(x1.shape)} and {tuple(x2.shape)}"
            )

        B = x1.size(0)          # batch size
        if B == 0:
            raise ValueError("x1 and x2 must contain at least one sample")
        T = self.n_steps + 1    # number of points along the path (baseline included)

        # A scalar index would drop the output dimension that is summed over
        # below; NumPy integer scalars need the same wrapping as Python ints.
        if isinstance(target_idx, (int, np.integer)):
            target_idx = [int(target_idx)]

        # -------------------------------------------------------------
        # Process element_mask
        # -------------------------------------------------------------
        if element_mask is not None:
            element_mask = torch.as_tensor(element_mask).to(self.device).bool()
            if element_mask.dim() != 1 or element_mask.numel() != num_elements:
                raise ValueError(
                    f"element_mask must be 1-D with {num_elements} entries, "
                    f"got shape {tuple(element_mask.shape)}"
                )
            mask_float = element_mask.float().unsqueeze(0)  # (1, n)
        else:
            mask_float = None

        # -------------------------------------------------------------
        # Build base mask (optionally perturbed with jitter)
        # -------------------------------------------------------------
        base_mask = torch.ones((1, num_elements), device=self.device)
        if jitter is not None:
            jitter = jitter.to(self.device)
            if jitter.dim() == 1:
                jitter = jitter.unsqueeze(0)
            base_mask = torch.clamp(base_mask * (1.0 + jitter), 0.0, 1.0)

        # -------------------------------------------------------------
        # Build interpolated mask path
        # -------------------------------------------------------------
        alphas = torch.linspace(0.0, 1.0, T, device=self.device).view(-1, 1)  # (T, 1)
        interpolated = alphas * base_mask                                     # (T, n)
        if mask_float is not None:
            # Selected elements (True) follow the path 0 -> base_mask;
            # the others (False) stay fixed at 1.0 throughout.
            mask_path_template = mask_float * interpolated + (1.0 - mask_float)
        else:
            mask_path_template = interpolated

        # Path displacement (base_mask - 0) that scales the averaged gradient.
        delta = base_mask.squeeze(0)

        # ------------------------------------------------------------------
        # Process each sample pair
        # ------------------------------------------------------------------
        all_ig = []
        for sample_idx in range(B):
            x1i = x1[sample_idx:sample_idx + 1]  # (1, N_in)
            x2i = x2[sample_idx:sample_idx + 1]  # (1, N_in)

            # Fresh leaf tensor so gradients are taken w.r.t. this sample only.
            mask_path = mask_path_template.clone().requires_grad_(True)

            # Both inputs are replicated along the path and stacked so that
            # a single forward pass covers x1 (first T rows) and x2 (last T).
            x_batch = torch.cat([x1i.repeat(T, 1), x2i.repeat(T, 1)], dim=0)  # (2T, N_in)
            mask_batch = mask_path.repeat(2, 1)                               # (2T, n)

            # Per-side model_kwargs: pick this sample's row, repeat across
            # the T path points, then concat to match x_batch (2T entries).
            mk_joint = concat_pair(
                repeat_batch(slice_per_sample(mk1, sample_idx), T),
                repeat_batch(slice_per_sample(mk2, sample_idx), T),
            )

            # Forward pass; multiple target outputs are summed.
            preds = self.model(
                x_batch, **{mask_kwarg: mask_batch}, **mk_joint
            )[:, target_idx].sum(dim=1)                                       # (2T,)
            diff_abs = (preds[:T] - preds[T:]).abs()                          # (T,)

            # Back-propagate through mask_path
            grads = torch.autograd.grad(diff_abs.sum(), mask_path)[0]        # (T, n)

            # Trapezoidal rule over the uniformly spaced path
            avg_grad = ((grads[:-1] + grads[1:]) / 2.0).mean(dim=0)          # (n,)
            ig_scores = avg_grad * delta

            # Elements that were not integrated carry no attribution.
            if element_mask is not None:
                ig_scores = ig_scores.masked_fill(~element_mask, float('nan'))

            all_ig.append(ig_scores)

        return torch.stack(all_ig, dim=0), element_mask  # (B, n)

    @staticmethod
    def _reduce_scores(all_ig: torch.Tensor, element_mask, reduction: str) -> torch.Tensor:
        """Aggregate per-sample scores of shape (B, n) over the batch dimension.

        ``reduction`` is ``'sum'`` or anything else for the mean.  When an
        ``element_mask`` is given, NaN entries are ignored during aggregation
        and the excluded columns are NaN in the result.
        """
        if element_mask is None:
            return all_ig.sum(dim=0) if reduction == 'sum' else all_ig.mean(dim=0)

        if reduction == 'sum':
            aggregated = torch.nansum(all_ig, dim=0)
        else:
            aggregated = torch.nanmean(all_ig, dim=0)
        # nansum counts NaN as zero, so an excluded (all-NaN) column would sum
        # to 0.0 and be reported as a real score; restore the NaN marker.
        return aggregated.masked_fill(~element_mask, float('nan'))

    @staticmethod
    def _package_scores(all_ig: torch.Tensor, element_mask, reduction: str, label_columns) -> pd.DataFrame:
        """Build the result frame from per-sample scores of shape (B, n).

        ``label_columns`` maps column names to per-element labels and is placed
        before the 'score' column.  NaN scores are converted to None.
        """
        if reduction == 'none':
            frames = []
            for i, row in enumerate(all_ig.detach().cpu().numpy()):
                frames.append(pd.DataFrame({
                    "sample_idx": i,
                    **label_columns,
                    "score": [None if np.isnan(score) else score for score in row],
                }))
            return pd.concat(frames, ignore_index=True)

        scores = ContrastiveIGExplainer._reduce_scores(
            all_ig, element_mask, reduction
        ).detach().cpu().numpy()
        # Convert NaN to None for elements not in the mask
        if element_mask is not None:
            scores = [None if np.isnan(score) else score for score in scores]
        return pd.DataFrame({**label_columns, "score": scores})