# Source code for gsnn.interpret.IGExplainer

import numpy as np
import torch
import copy
import pandas as pd
from typing import Optional

from gsnn.interpret._kwargs_utils import (
    normalize_model_kwargs,
    repeat_batch,
    slice_per_sample,
)


class IGExplainer:
    r"""Integrated-Gradients explainer for GSNN models (non-contrastive).

    Computes per-edge or per-node attributions for a prediction
    *f(x)[target_idx]*. The observation *x* is held fixed. What is integrated
    is the model's edge mask (passed as ``edge_mask=``) or node mask (passed
    as ``node_mask=``), along a straight line from a baseline mask *m′*
    (default all zeros) to the all-ones mask.

    For edge-level attributions::

        IG_e = (1 - m′_e) · \int_0^1 ∂f(x; m′ + α(1 - m′))/∂m_e dα

    For node-level attributions (the node baseline *n′* is all zeros unless
    ``jitter`` is supplied)::

        IG_n = (1 - n′_n) · \int_0^1 ∂f(x; n′ + α(1 - n′))/∂n_n dα

    The integral is approximated with the trapezoidal rule over ``n_steps``
    equal intervals. The completeness axiom
    (``sum(IG) ≈ f(x; 1) - f(x; m′)``) therefore holds up to discretisation
    error. When the baseline masks are zero this reduces to the EdgeIG/NodeIG
    variants. Node-level and edge-level attributions are computed
    independently, using the model's two separate masking mechanisms.

    Parameters
    ----------
    model : torch.nn.Module
        Trained GSNN model. A deep copy is moved to the selected device, put
        in eval mode and has its parameters frozen; the caller's model is
        left untouched. The explainer reads ``model.edge_index``,
        ``model.homo_names`` and (for node attributions) ``model.num_nodes``.
        It calls the model with an ``edge_mask=`` or ``node_mask=`` keyword.
    data : torch_geometric.data.Data
        Graph data object, stored as ``self.data``. The attribution code in
        this class does not read it; names come from ``model.homo_names``.
    ignore_cuda : bool, optional (default=False)
        Force the explainer to run on CPU even if CUDA is available.
    n_steps : int, optional (default=50)
        Number of intervals on the IG path. The model is evaluated at
        ``n_steps + 1`` points (baseline and full mask included). Must be >= 1.
    baseline : torch.Tensor or None, optional
        Custom baseline edge-mask of shape ``(1, E)`` or ``(E,)``. ``None``
        defaults to an all-zeros mask. Only used for edge attributions.

    Raises
    ------
    ValueError
        If ``n_steps < 1`` or ``baseline`` does not have ``E`` entries.

    Example
    -------
    >>> explainer = IGExplainer(model, data, n_steps=64)
    >>> # Edge-level attributions
    >>> df_edge = explainer.explain(x, target_idx=0, target='edge')
    >>> df_edge.nlargest(5, 'score')
    >>> # Node-level attributions
    >>> df_node = explainer.explain(x, target_idx=0, target='node')
    >>> df_node.nlargest(5, 'score')
    >>> # Compute IG for only a subset of edges
    >>> edge_mask = np.array([True, False, True, False, True])  # Only integrate edges 0, 2, 4
    >>> df_edge = explainer.explain(x, target_idx=0, target='edge', element_mask=edge_mask)
    >>> # Edges 1 and 3 get missing (None/NaN) scores; edges 0, 2, 4 have IG attributions
    >>> # Note: Completeness axiom won't hold when using element_mask
    """

    def __init__(self, model, data, ignore_cuda=False, n_steps=50, baseline=None):
        """Create a new IGExplainer instance (see the class docstring for parameters)."""
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        self.data = data
        self.device = 'cuda' if (torch.cuda.is_available() and not ignore_cuda) else 'cpu'

        model = copy.deepcopy(model).eval().to(self.device)
        # Only the masks need gradients; freezing the private copy keeps
        # autograd from tracking parameter gradients during the IG sweep.
        for param in model.parameters():
            param.requires_grad_(False)
        self.model = model

        self.n_steps = n_steps
        self.E = model.edge_index.size(1)

        # Baseline edge mask, normalised to shape (1, E). Default: all zeros.
        if baseline is None:
            self.baseline = torch.zeros((1, self.E), device=self.device)
        else:
            self.baseline = baseline.to(self.device).reshape(1, -1)
            if self.baseline.size(1) != self.E:
                raise ValueError(
                    f"baseline must have {self.E} entries, got {self.baseline.size(1)}"
                )

    def explain(self, x, target_idx, *, jitter: Optional[torch.Tensor] = None,
                element_mask=None, target='edge', reduction='mean', model_kwargs=None):
        '''
        Compute integrated gradients attributions for GSNN predictions.

        Parameters
        ----------
        x : torch.Tensor
            Input features of shape (N_in,), (1, N_in), (B, N_in) or (B, 1, N_in).
        target_idx : int
            Index of the target output node to explain.
        jitter : torch.Tensor, optional
            Optional noise added to the baseline mask (the result is clamped to
            [0, 1]). Shape (E,) or (1, E) for edge target, (N,) or (1, N) for
            node target.
        element_mask : torch.Tensor or np.ndarray, optional (shape: [E] or [N])
            Boolean mask indicating which elements to compute IG attributions for.
            If None, all elements are integrated. If provided:
            - True/nonzero elements: integrate from baseline to 1 (normal IG)
            - False/zero elements: fixed at 1 throughout the path (no integration)
            Elements not in the mask get missing (None) scores in the output, for
            every reduction.
            Note: When using element_mask, the completeness axiom (attributions
            sum to f(x) - f(baseline)) will not hold since only a subset of
            elements are integrated. The attributions measure "contribution while
            holding other elements fixed at full strength".
        target : str, optional (default='edge')
            Whether to return 'edge' or 'node' level attributions.
        reduction : str, optional (default='mean')
            How to aggregate attributions across batch samples:
            - 'mean': average attributions across samples (default)
            - 'sum': sum attributions across samples
            - 'none': return all per-sample attributions (adds 'sample_idx' column)
        model_kwargs : dict, optional (default=None)
            Extra keyword arguments forwarded to every ``self.model(...)`` call
            (e.g. ``{'x_fn': x_fn}`` for models trained with
            ``node_activity=True``). Tensor values must have leading dim equal to
            ``x.shape[0]`` (or 1 to broadcast); they will be sliced per sample and
            replicated to ``n_steps+1`` along the IG path. ``edge_mask`` /
            ``node_mask`` are reserved and should not be included.

        Returns
        -------
        pd.DataFrame
            If target='edge': columns ['source', 'target', 'score'] for edge attributions.
            If target='node': columns ['node', 'score'] for node attributions.
            If reduction='none': additional leading 'sample_idx' column.
            Elements not in element_mask have missing (None) scores.

        Raises
        ------
        ValueError
            For an unknown ``target`` or ``reduction``, a 3-D ``x`` whose middle
            dimension is not 1, an empty batch, or an ``element_mask`` of the
            wrong length.

        Following approach and style from:
            https://github.com/ankurtaly/Integrated-Gradients/blob/master/IntegratedGradients/integrated_gradients.py

        Reference:
            @article{DBLP:journals/corr/SundararajanTY17,
            author    = {Mukund Sundararajan and Ankur Taly and Qiqi Yan},
            title     = {Axiomatic Attribution for Deep Networks},
            journal   = {CoRR},
            volume    = {abs/1703.01365},
            year      = {2017},
            url       = {http://arxiv.org/abs/1703.01365},
            eprinttype = {arXiv},
            eprint    = {1703.01365},
            timestamp = {Mon, 13 Aug 2018 16:48:32 +0200},
            biburl    = {https://dblp.org/rec/journals/corr/SundararajanTY17.bib},
            bibsource = {dblp computer science bibliography, https://dblp.org}
            }
        '''
        if target not in ['edge', 'node']:
            raise ValueError(f"target must be 'edge' or 'node', got '{target}'")
        if reduction not in ['mean', 'sum', 'none']:
            raise ValueError(f"reduction must be 'mean', 'sum', or 'none', got '{reduction}'")

        if target == 'edge':
            return self._compute_edge_attributions(x, target_idx, jitter, element_mask,
                                                   reduction, model_kwargs=model_kwargs)
        return self._compute_node_attributions(x, target_idx, jitter, element_mask,
                                               reduction, model_kwargs=model_kwargs)

    def _compute_edge_attributions(self, x, target_idx, jitter=None, element_mask=None,
                                   reduction='mean', model_kwargs=None):
        '''
        Compute edge-level attributions using integrated gradients on edge_mask.

        Parameters
        ----------
        x : torch.Tensor
            Input features of shape (N_in,), (1, N_in), (B, N_in) or (B, 1, N_in).
        target_idx : int
            Index of the target output node to explain.
        jitter : torch.Tensor, optional
            Optional noise to add to the baseline, shape (E,) or (1, E).
        element_mask : torch.Tensor or np.ndarray, optional
            Boolean mask indicating which edges to compute IG for.
        reduction : str
            How to aggregate across batch: 'mean', 'sum', or 'none'.
        model_kwargs : dict, optional
            Extra keyword arguments forwarded to the model (see ``explain``).

        Returns
        -------
        pd.DataFrame
            Columns ['source', 'target', 'score'] for edge attributions.
            If reduction='none': additional 'sample_idx' column.
        '''
        model_kwargs = normalize_model_kwargs(model_kwargs)
        x = self._prepare_x(x)
        element_mask = self._prepare_element_mask(element_mask, self.E)
        baseline = self._jittered_baseline(self.baseline, jitter)

        all_ig = self._integrated_gradients(x, target_idx, baseline, element_mask,
                                            'edge_mask', model_kwargs)  # (B, E)

        src, dst = np.array(self.model.homo_names)[self.model.edge_index.detach().cpu().numpy()]
        return self._to_frame(all_ig, element_mask, reduction, {'source': src, 'target': dst})

    def _compute_node_attributions(self, x, target_idx, jitter=None, element_mask=None,
                                   reduction='mean', model_kwargs=None):
        '''
        Compute node-level attributions using integrated gradients on node_mask.

        The node baseline is an all-zeros mask (optionally perturbed by
        ``jitter``); ``self.baseline`` is not used here.

        Parameters
        ----------
        x : torch.Tensor
            Input features of shape (N_in,), (1, N_in), (B, N_in) or (B, 1, N_in).
        target_idx : int
            Index of the target output node to explain.
        jitter : torch.Tensor, optional
            Optional noise to add to the baseline, shape (N,) or (1, N).
        element_mask : torch.Tensor or np.ndarray, optional
            Boolean mask indicating which nodes to compute IG for.
        reduction : str
            How to aggregate across batch: 'mean', 'sum', or 'none'.
        model_kwargs : dict, optional
            Extra keyword arguments forwarded to the model (see ``explain``).

        Returns
        -------
        pd.DataFrame
            Columns ['node', 'score'] for node attributions.
            If reduction='none': additional 'sample_idx' column.
        '''
        model_kwargs = normalize_model_kwargs(model_kwargs)
        x = self._prepare_x(x)
        n_nodes = self.model.num_nodes
        element_mask = self._prepare_element_mask(element_mask, n_nodes)
        baseline = self._jittered_baseline(
            torch.zeros((1, n_nodes), device=self.device), jitter
        )

        all_ig = self._integrated_gradients(x, target_idx, baseline, element_mask,
                                            'node_mask', model_kwargs)  # (B, N)

        node_names = np.array(self.model.homo_names)
        return self._to_frame(all_ig, element_mask, reduction, {'node': node_names})

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _prepare_x(self, x):
        """Move ``x`` to the device and coerce it to shape ``(B, N_in)``."""
        x = x.to(self.device)
        if x.dim() == 1:
            x = x.unsqueeze(0)  # (1, N_in)
        elif x.dim() == 3:
            # Only squeeze the middle axis: a bare squeeze() would also drop a
            # batch dimension of size 1 and mis-read the feature axis as batch.
            if x.size(1) != 1:
                raise ValueError("Expected (1, N_in) or (B, N_in) or (B, 1, N_in), got (B, D, N_in)")
            x = x.squeeze(1)
        if x.size(0) == 0:
            raise ValueError("x must contain at least one sample")
        return x

    def _prepare_element_mask(self, element_mask, size):
        """Return ``element_mask`` as a flat bool tensor of length ``size`` (or None)."""
        if element_mask is None:
            return None
        if isinstance(element_mask, np.ndarray):
            element_mask = torch.from_numpy(element_mask)
        # Flatten so that both (size,) and (1, size) masks are accepted.
        element_mask = element_mask.to(self.device).reshape(-1).bool()
        if element_mask.numel() != size:
            raise ValueError(
                f"element_mask must have {size} entries, got {element_mask.numel()}"
            )
        return element_mask

    def _jittered_baseline(self, baseline, jitter):
        """Return ``baseline`` (1, K) plus optional ``jitter``, clamped to [0, 1]."""
        if jitter is None:
            return baseline
        jitter = jitter.to(self.device)
        if jitter.dim() == 1:
            jitter = jitter.unsqueeze(0)  # (1, K)
        return torch.clamp(baseline + jitter, 0.0, 1.0)

    def _integrated_gradients(self, x, target_idx, baseline, element_mask,
                              mask_kwarg, model_kwargs):
        """Return raw IG scores of shape ``(B, K)`` for the mask passed as ``mask_kwarg``.

        ``baseline`` is the ``(1, K)`` starting mask. Elements where
        ``element_mask`` is False are held at 1.0 along the whole path. Their
        scores are returned as computed; marking them missing is done by
        ``_to_frame``.
        """
        n_points = self.n_steps + 1
        # alphas: 0 ... 1 inclusive, so both the baseline and the full mask are evaluated.
        alphas = torch.linspace(0.0, 1.0, n_points, device=self.device).view(-1, 1)
        path = baseline + alphas * (1.0 - baseline)  # (n_points, K)
        if element_mask is not None:
            path = torch.where(element_mask, path, torch.ones_like(path))
        delta = (1.0 - baseline).squeeze(0)  # (K,) endpoint - startpoint

        per_sample = []
        for i in range(x.size(0)):
            # Fresh leaf tensor per sample to differentiate against.
            masks = path.clone().requires_grad_(True)
            x_path = x[i:i + 1].repeat(n_points, 1)  # (n_points, N_in)
            # Replicate any per-sample model_kwargs (e.g. x_fn) along the path
            # so the model sees one entry per replicated x.
            kwargs_i = repeat_batch(slice_per_sample(model_kwargs, i), n_points)
            preds = self.model(x_path, **{mask_kwarg: masks}, **kwargs_i)[:, target_idx]

            grads = torch.autograd.grad(preds.sum(), masks)[0]  # (n_points, K)
            # Trapezoidal rule: averaging interval midpoints over equal-width
            # intervals approximates the integral over alpha in [0, 1].
            avg_grads = ((grads[:-1] + grads[1:]) / 2.0).mean(dim=0)  # (K,)
            per_sample.append(avg_grads * delta)
        return torch.stack(per_sample, dim=0)

    @staticmethod
    def _to_frame(all_ig, element_mask, reduction, labels):
        """Aggregate ``(B, K)`` scores over the batch and build the output DataFrame.

        ``labels`` maps column names to length-K arrays placed before ``'score'``.
        Elements outside ``element_mask`` get missing (None) scores.
        """
        def nan_outside(scores):
            if element_mask is None:
                return scores
            return scores.masked_fill(~element_mask, float('nan'))

        def nan_to_none(values):
            return [None if np.isnan(v) else v for v in values]

        if reduction == 'none':
            per_sample = nan_outside(all_ig).detach().cpu().numpy()
            frames = [
                pd.DataFrame({'sample_idx': i, **labels, 'score': nan_to_none(row)})
                for i, row in enumerate(per_sample)
            ]
            return pd.concat(frames, ignore_index=True)

        # Aggregate first and mark excluded elements afterwards. torch.nansum
        # would turn an all-NaN (excluded) column into 0.0 instead of missing.
        agg = all_ig.sum(dim=0) if reduction == 'sum' else all_ig.mean(dim=0)
        scores = nan_outside(agg).detach().cpu().numpy()
        if element_mask is not None:
            scores = nan_to_none(scores)
        return pd.DataFrame({**labels, 'score': scores})