Source code for gsnn.gsnn.interpret.ContrastiveOcclusionExplainer

import copy
from typing import Union, List

import numpy as np
import pandas as pd
import torch


class ContrastiveOcclusionExplainer:
    r"""Simple batched occlusion explainer for *contrastive* questions.

    This module attributes the prediction **difference**::

        Δf = f(x1)[target_idx] - f(x2)[target_idx]

    to every edge (or node) *e* by systematically removing each element and
    measuring the change in the absolute prediction difference:

    .. math::

        \mathrm{Occ}_e = |Δf_{\text{baseline}}| - |Δf_{\text{without } e}|

    where the baseline uses all elements present and the occluded version
    removes element *e* completely (mask value = 0).

    Interpretation
    --------------
    * ``Occ_e > 0``  removing *e* decreases |Δf| (contributes to difference).
    * ``Occ_e < 0``  removing *e* increases |Δf| (reduces difference).
    * ``Occ_e ≈ 0``  *e* has no impact on the prediction difference.

    Parameters
    ----------
    model : torch.nn.Module
        Trained model. It is deep-copied, switched to evaluation mode, moved
        to the selected device and frozen, so the caller's model is never
        modified. The model must expose ``edge_index`` (used for the edge
        count and the edge endpoints), ``homo_names`` (used for the
        human-readable names) and, for node-level attributions,
        ``num_nodes``. Its forward pass must accept ``edge_mask=`` /
        ``node_mask=`` keyword arguments.
    data : object
        Graph data object. It is only stored on the instance (``self.data``);
        names in the returned tables are read from ``model.homo_names``.
    batch_size : int, optional (default=32)
        Number of occlusions to process in parallel.
    ignore_cuda : bool, optional (default=False)
        Force the explainer to run on CPU even if CUDA is available.
    verbose : bool, optional (default=False)
        Print progress information during explanation computation.

    Example
    -------
    >>> explainer = ContrastiveOcclusionExplainer(model, data, batch_size=64)
    >>> df = explainer.explain(x1, x2, target_idx=0)
    >>> df.sort_values('score', ascending=False).head()
      source target  score
         in0  func0   0.42
       func0  func3   0.40
       func3   out0   0.38
    """

    def __init__(
        self,
        model: torch.nn.Module,
        data,
        batch_size: int = 32,
        ignore_cuda: bool = False,
        verbose: bool = False,
    ) -> None:
        """Freeze a private copy of *model* and record the run configuration.

        Raises
        ------
        ValueError
            If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.data = data
        self.batch_size = batch_size
        self.verbose = verbose
        self.device = (
            "cuda" if (torch.cuda.is_available() and not ignore_cuda) else "cpu"
        )

        # Work on a frozen copy of the model to avoid side-effects.
        model = copy.deepcopy(model).eval().to(self.device)
        for p in model.parameters():
            p.requires_grad = False
        self.model = model

        # Number of edges, i.e. the width of an edge mask.
        self.E = model.edge_index.size(1)

        if self.verbose:
            print("ContrastiveOcclusionExplainer initialized:")
            print(f" Device: {self.device}")
            print(f" Batch size: {self.batch_size}")
            print(f" Total edges: {self.E}")

    def explain(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        *,
        element_mask=None,
        target: str = 'edge',
        reduction: str = 'mean',
    ) -> pd.DataFrame:
        """Compute occlusion attributions for *f(x₁) − f(x₂)*.

        Parameters
        ----------
        x1, x2 : torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in])
            Two input feature tensors. They must have identical batch size
            and ordering of nodes. Each pair (x1[i], x2[i]) is explained.
        target_idx : int or list[int]
            Output dimension(s) to explain. If a list is provided the
            attributions refer to the **sum** of those outputs.
        element_mask : torch.Tensor or np.ndarray, optional (shape: [E] or [N])
            Boolean mask indicating which elements to compute occlusion for.
            If None, all elements are considered. If provided, only elements
            where ``element_mask[i]`` is True have occlusion scores computed.
        target : str, optional (default='edge')
            Whether to return 'edge' or 'node' level attributions.
        reduction : str, optional (default='mean')
            How to aggregate attributions across batch samples:

            - 'mean': average attributions across samples (default)
            - 'sum': sum attributions across samples
            - 'none': return all per-sample attributions (adds a
              'sample_idx' column)

        Returns
        -------
        pd.DataFrame
            If target='edge': columns ['source', 'target', 'score'].
            If target='node': columns ['node', 'score'].
            If reduction='none': additional 'sample_idx' column.
            Elements not selected by ``element_mask`` have missing scores
            for every reduction mode.

        Raises
        ------
        ValueError
            If ``target`` or ``reduction`` is not one of the accepted values,
            if ``x1`` and ``x2`` have different batch sizes, or if
            ``element_mask`` does not have one entry per element.
        """
        if target not in ('edge', 'node'):
            raise ValueError(f"target must be 'edge' or 'node', got '{target}'")
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                f"reduction must be 'mean', 'sum', or 'none', got '{reduction}'"
            )

        if target == 'edge':
            return self._compute_edge_attributions(
                x1, x2, target_idx, element_mask, reduction
            )
        return self._compute_node_attributions(
            x1, x2, target_idx, element_mask, reduction
        )

    def _compute_edge_attributions(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        element_mask=None,
        reduction: str = 'mean',
    ) -> pd.DataFrame:
        """Compute edge-level occlusion attributions for *f(x₁) − f(x₂)*."""
        # Row 0 of edge_index holds source ids, row 1 holds destination ids;
        # indexing the name array with it yields the two name columns.
        edge_index = self.model.edge_index.detach().cpu().numpy()
        src, dst = np.array(self.model.homo_names)[edge_index]
        return self._compute_attributions(
            x1, x2, target_idx, element_mask, reduction,
            kind='edge',
            num_elements=self.E,
            mask_kwarg='edge_mask',
            labels={"source": src, "target": dst},
        )

    def _compute_node_attributions(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        element_mask=None,
        reduction: str = 'mean',
    ) -> pd.DataFrame:
        """Compute node-level occlusion attributions for *f(x₁) − f(x₂)*."""
        return self._compute_attributions(
            x1, x2, target_idx, element_mask, reduction,
            kind='node',
            num_elements=self.model.num_nodes,
            mask_kwarg='node_mask',
            labels={"node": np.array(self.model.homo_names)},
        )

    def _compute_attributions(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: Union[int, List[int]],
        element_mask,
        reduction: str,
        *,
        kind: str,
        num_elements: int,
        mask_kwarg: str,
        labels: dict,
    ) -> pd.DataFrame:
        """Shared occlusion loop behind the edge- and node-level entry points.

        Parameters
        ----------
        kind : str
            'edge' or 'node'; only used in progress messages.
        num_elements : int
            Width of the mask passed to the model (E or N).
        mask_kwarg : str
            Name of the model keyword that receives the mask.
        labels : dict
            Name columns placed in front of 'score' in the result.
        """
        x1, x2 = x1.to(self.device), x2.to(self.device)

        # Ensure batch dimension.
        if x1.dim() == 1:
            x1 = x1.unsqueeze(0)
        if x2.dim() == 1:
            x2 = x2.unsqueeze(0)
        if x1.size(0) != x2.size(0):
            raise ValueError(
                "x1 and x2 must have the same batch size, got "
                f"{x1.size(0)} and {x2.size(0)}"
            )
        num_samples = x1.size(0)

        # A scalar index is wrapped so that `[:, target_idx]` keeps a column
        # axis for the subsequent sum over outputs.
        if isinstance(target_idx, (int, np.integer)):
            target_idx = [int(target_idx)]

        selected = self._resolve_element_indices(element_mask, num_elements)
        num_selected = selected.numel()
        num_chunks = -(-num_selected // self.batch_size)  # ceil division
        title = kind.capitalize()

        if self.verbose:
            print(f"\nStarting contrastive {kind} occlusion explanation:")
            print(f" Input shapes: x1={x1.shape}, x2={x2.shape}")
            print(f" Batch size: {num_samples}")
            print(f" Target indices: {target_idx}")
            print(f" {title}s to occlude: {num_selected} / {num_elements}")
            print(f" Total {kind} batches to process: {num_chunks}")

        per_sample = []
        # Only forward passes are needed; skip autograd bookkeeping entirely.
        with torch.no_grad():
            for sample_idx in range(num_samples):
                x1i = x1[sample_idx:sample_idx + 1]  # (1, N_in)
                x2i = x2[sample_idx:sample_idx + 1]  # (1, N_in)

                if self.verbose and num_samples > 1:
                    print(f"Processing sample {sample_idx + 1}/{num_samples}")

                # Baseline difference with every element present.
                full_mask = torch.ones((1, num_elements), device=self.device)
                baseline = self._masked_abs_diff(
                    x1i, x2i, target_idx, full_mask, mask_kwarg
                )
                if self.verbose:
                    print(f" Baseline |Δf| = {baseline:.6f}")

                # NaN marks "not computed"; it is turned into a missing value
                # when the table is built.
                scores = torch.full(
                    (num_elements,), float('nan'), device=self.device
                )
                for chunk_no, start in enumerate(
                    range(0, num_selected, self.batch_size), start=1
                ):
                    chunk = selected[start:start + self.batch_size]
                    k = chunk.numel()

                    if self.verbose:
                        print(
                            f"Processing {kind} batch {chunk_no}/{num_chunks}",
                            end='\r',
                        )

                    # Row i has exactly one zero, at column chunk[i].
                    masks = torch.ones((k, num_elements), device=self.device)
                    masks[torch.arange(k, device=self.device), chunk] = 0.0

                    # One forward pass evaluates x1 (first k rows) and x2
                    # (last k rows) under the same k masks.
                    x_batch = torch.cat(
                        [x1i.repeat(k, 1), x2i.repeat(k, 1)], dim=0
                    )
                    preds = self.model(
                        x_batch, **{mask_kwarg: masks.repeat(2, 1)}
                    )[:, target_idx].sum(dim=1)

                    occluded = (preds[:k] - preds[k:]).abs()
                    scores[chunk] = baseline - occluded

                per_sample.append(scores)

        if self.verbose:
            print(f"\n{title} occlusion explanation complete.")

        stacked = torch.stack(per_sample, dim=0)  # (B, num_elements)

        if reduction == 'none':
            frames = [
                pd.DataFrame({
                    "sample_idx": i,
                    **labels,
                    "score": self._nan_to_none(row),
                })
                for i, row in enumerate(stacked)
            ]
            return pd.concat(frames, ignore_index=True)

        if reduction == 'sum':
            aggregated = torch.nansum(stacked, dim=0)
            # nansum treats NaN as 0, so a column that was never computed
            # would come out as 0.0; restore NaN so it is reported as missing,
            # matching what nanmean does for the 'mean' reduction.
            aggregated[torch.isnan(stacked).all(dim=0)] = float('nan')
        else:  # mean
            aggregated = torch.nanmean(stacked, dim=0)

        return pd.DataFrame({**labels, "score": self._nan_to_none(aggregated)})

    def _resolve_element_indices(
        self, element_mask, num_elements: int
    ) -> torch.Tensor:
        """Return the indices of the elements to occlude as a 1-D tensor.

        Raises
        ------
        ValueError
            If ``element_mask`` is not one-dimensional with ``num_elements``
            entries.
        """
        if element_mask is None:
            return torch.arange(num_elements, device=self.device)

        # as_tensor accepts tensors, NumPy arrays and plain sequences alike.
        mask = torch.as_tensor(element_mask, device=self.device).bool()
        if mask.dim() != 1 or mask.numel() != num_elements:
            raise ValueError(
                f"element_mask must have shape ({num_elements},), "
                f"got {tuple(mask.shape)}"
            )
        return torch.where(mask)[0]

    @staticmethod
    def _nan_to_none(values: torch.Tensor) -> list:
        """Convert a 1-D score tensor to a list, replacing NaN with None."""
        return [
            None if np.isnan(v) else v
            for v in values.detach().cpu().numpy()
        ]

    def _masked_abs_diff(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: List[int],
        mask: torch.Tensor,
        mask_kwarg: str,
    ) -> float:
        """Return |f(x1) − f(x2)| summed over *target_idx* under *mask*.

        Expects single-sample inputs, because the result is read with
        ``.item()``.
        """
        pred1 = self.model(x1, **{mask_kwarg: mask})[:, target_idx].sum(dim=1)
        pred2 = self.model(x2, **{mask_kwarg: mask})[:, target_idx].sum(dim=1)
        return (pred1 - pred2).abs().item()

    def _compute_diff_edge(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: List[int],
        mask: torch.Tensor,
    ) -> float:
        """Compute absolute prediction difference with given edge mask."""
        return self._masked_abs_diff(x1, x2, target_idx, mask, 'edge_mask')

    def _compute_diff_node(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        target_idx: List[int],
        mask: torch.Tensor,
    ) -> float:
        """Compute absolute prediction difference with given node mask."""
        return self._masked_abs_diff(x1, x2, target_idx, mask, 'node_mask')