Source code for gsnn.gsnn.interpret.OcclusionExplainer

import numpy as np
import torch
import copy
import pandas as pd


class OcclusionExplainer:
    r"""Edge/node occlusion explainer.

    Computes per-edge or per-node attributions for a prediction
    *f(x)[target_idx]* by removing one element at a time and measuring the
    change in the prediction.

    For edge-level attributions::

        Occ_e = f(x; mask_baseline) - f(x; mask_e_removed)

    For node-level attributions::

        Occ_n = f(x; mask_baseline) - f(x; mask_n_removed)

    where *mask_baseline* is the model's default forward pass (no mask
    passed) and *mask_element_removed* is an all-ones mask with only the
    specified element set to 0.

    Interpretation of scores
    ------------------------
    * ``Occ > 0``  removing the element lowers the output (the element
      contributes positively to the prediction)
    * ``Occ < 0``  removing the element raises the output (the element
      inhibits the prediction)
    * ``Occ ≈ 0``  removing the element leaves the output unchanged

    The model is called as ``model(x)`` for the baseline and as
    ``model(x, edge_mask=...)`` / ``model(x, node_mask=...)`` for the
    occluded passes, so its ``forward`` must accept those keyword arguments.
    It must also expose ``edge_index``, ``num_nodes`` and ``homo_names``.

    Parameters
    ----------
    model : torch.nn.Module
        Trained GSNN model. It is deep-copied, put in eval mode, moved to the
        selected device, and its parameters are frozen; the caller's model is
        left untouched.
    data : torch_geometric.data.Data
        Graph data object. Stored as ``self.data``; this class does not read
        it otherwise (element names are taken from ``model.homo_names``).
    ignore_cuda : bool, optional (default=False)
        Force the explainer to run on CPU even if CUDA is available.
    batch_size : int, optional (default=32)
        Number of element occlusions to process in a single forward pass.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.

    Example
    -------
    >>> explainer = OcclusionExplainer(model, data, batch_size=64)
    >>> # Edge-level attributions
    >>> edge_df = explainer.explain(x, target_idx=0, target='edge')
    >>> edge_df.nlargest(5, 'score')
      source target  score
         in0  func0   0.42
       func0  func3   0.40
       func3   out0   0.38
    >>> # Node-level attributions
    >>> node_df = explainer.explain(x, target_idx=0, target='node')
    >>> node_df.nlargest(5, 'score')
    >>> # Occlude only a subset of edges (here: edges 0, 2 and 4)
    >>> edge_mask = np.array([True, False, True, False, True])
    >>> edge_df = explainer.explain(x, target_idx=0, target='edge', element_mask=edge_mask)
    >>> # Edges 1 and 3 will have missing scores
    """

    def __init__(self, model, data, ignore_cuda=False, batch_size=32):
        """Create a new OcclusionExplainer instance."""
        # A zero step would make ``range`` raise later on and a negative one
        # would silently skip every occlusion, so reject both up front.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        self.data = data
        self.device = 'cuda' if (torch.cuda.is_available() and not ignore_cuda) else 'cpu'

        # Work on a private copy so the caller's model is never mutated.
        model = copy.deepcopy(model)
        model = model.eval()
        model = model.to(self.device)
        # Occlusion only needs forward passes; freeze the copy.
        for param in model.parameters():
            param.requires_grad_(False)

        self.model = model
        self.batch_size = batch_size
        self.E = model.edge_index.size(1)
        self.N = model.num_nodes

    def explain(self, x, target_idx, element_mask=None, target='edge', reduction='mean'):
        """Compute edge or node occlusion attributions for *f(x)[target_idx]*.

        Parameters
        ----------
        x : torch.Tensor (shape: [N_in], [1, N_in], or [B, N_in] for batch)
            Input feature tensor. It is moved to the explainer's device; a
            1-D input gets a leading batch dimension.
        target_idx : int
            Output dimension to explain.
        element_mask : torch.Tensor, np.ndarray or sequence of bool, optional
            1-D boolean mask of length E (``target='edge'``) or N
            (``target='node'``) selecting the elements to occlude. If None,
            every element is occluded in turn.
        target : str, optional (default='edge')
            Whether to return 'edge' or 'node' level attributions.
        reduction : str, optional (default='mean')
            How to aggregate attributions across batch samples:

            - 'mean': average attributions across samples
            - 'sum': sum attributions across samples
            - 'none': return all per-sample attributions (adds a
              'sample_idx' column)

        Returns
        -------
        pd.DataFrame
            If target='edge': columns ['source', 'target', 'score'].
            If target='node': columns ['node', 'score'].
            If reduction='none': an additional leading 'sample_idx' column.
            Elements not selected by ``element_mask`` get a missing score
            (passed to pandas as ``None``; pandas may show it as NaN) for
            every reduction.

        Raises
        ------
        ValueError
            If ``target`` or ``reduction`` is not one of the accepted values,
            or if ``element_mask`` is not 1-D with one entry per element.
        """
        if target not in ['edge', 'node']:
            raise ValueError(f"target must be 'edge' or 'node', got '{target}'")
        if reduction not in ['mean', 'sum', 'none']:
            raise ValueError(f"reduction must be 'mean', 'sum', or 'none', got '{reduction}'")

        if target == 'edge':
            return self._explain_edges(x, target_idx, element_mask, reduction)
        return self._explain_nodes(x, target_idx, element_mask, reduction)

    def _explain_edges(self, x, target_idx, element_mask=None, reduction='mean'):
        """Compute edge-level attributions using occlusion.

        Parameters
        ----------
        x : torch.Tensor
            Input features of shape (N_in,), (1, N_in), or (B, N_in).
        target_idx : int
            Index of the output dimension to explain.
        element_mask : torch.Tensor, np.ndarray or sequence of bool, optional
            Boolean mask of length E indicating which edges to occlude.
        reduction : str
            How to aggregate across batch: 'mean', 'sum', or 'none'.

        Returns
        -------
        pd.DataFrame
            Columns ['source', 'target', 'score'] for edge attributions.
            If reduction='none': additional 'sample_idx' column.
        """
        scores = self._occlusion_scores(x, target_idx, element_mask, self.E, 'edge_mask')

        # Indexing the name array with the (2, E) edge index yields one row of
        # source names and one row of target names.
        edge_index = self.model.edge_index.detach().cpu().numpy()
        src, dst = np.array(self.model.homo_names)[edge_index]
        return self._package_scores(scores, {'source': src, 'target': dst}, reduction)

    def _explain_nodes(self, x, target_idx, element_mask=None, reduction='mean'):
        """Compute node-level attributions using occlusion.

        Parameters
        ----------
        x : torch.Tensor
            Input features of shape (N_in,), (1, N_in), or (B, N_in).
        target_idx : int
            Index of the output dimension to explain.
        element_mask : torch.Tensor, np.ndarray or sequence of bool, optional
            Boolean mask of length N indicating which nodes to occlude.
        reduction : str
            How to aggregate across batch: 'mean', 'sum', or 'none'.

        Returns
        -------
        pd.DataFrame
            Columns ['node', 'score'] for node attributions.
            If reduction='none': additional 'sample_idx' column.
        """
        scores = self._occlusion_scores(x, target_idx, element_mask, self.N, 'node_mask')
        node_names = np.array(self.model.homo_names)
        return self._package_scores(scores, {'node': node_names}, reduction)

    def _indices_to_occlude(self, element_mask, n_elements):
        """Return the 1-D index tensor of elements selected by ``element_mask``."""
        if element_mask is None:
            return torch.arange(n_elements, device=self.device)

        element_mask = torch.as_tensor(element_mask).to(self.device).bool()
        if element_mask.dim() != 1 or element_mask.numel() != n_elements:
            raise ValueError(
                f"element_mask must be 1-D with {n_elements} entries, "
                f"got shape {tuple(element_mask.shape)}"
            )
        return torch.where(element_mask)[0]

    def _occlusion_scores(self, x, target_idx, element_mask, n_elements, mask_kwarg):
        """Return a (B, n_elements) tensor of ``baseline - occluded`` scores.

        ``mask_kwarg`` is the name of the model keyword argument that receives
        the occlusion masks ('edge_mask' or 'node_mask'). Columns of elements
        that are not occluded are left as NaN.
        """
        x = x.to(self.device)
        if x.dim() == 1:
            x = x.unsqueeze(0)  # ensure batch dimension
        n_samples = x.size(0)

        indices = self._indices_to_occlude(element_mask, n_elements)

        # Only forward passes are needed, so skip autograd bookkeeping.
        with torch.no_grad():
            # Baseline prediction: model called without any mask.
            baseline = self.model(x)[:, target_idx]  # (B,)

            # NaN marks "not occluded". Matching the prediction dtype keeps the
            # indexed assignment below valid for non-default-dtype models.
            scores = torch.full(
                (n_samples, n_elements),
                float('nan'),
                dtype=baseline.dtype,
                device=self.device,
            )

            for start in range(0, indices.numel(), self.batch_size):
                chunk = indices[start:start + self.batch_size]
                n_chunk = chunk.numel()

                # One mask per occluded element: all ones except that element.
                masks = torch.ones((n_chunk, n_elements), device=self.device)
                masks[torch.arange(n_chunk, device=self.device), chunk] = 0.0

                # Pair every mask with every sample, mask-major:
                # row k * B + b holds (mask k, sample b).
                x_rep = x.unsqueeze(0).repeat(n_chunk, 1, 1).reshape(-1, x.size(1))
                masks_rep = masks.unsqueeze(1).repeat(1, n_samples, 1).reshape(-1, n_elements)

                preds = self.model(x_rep, **{mask_kwarg: masks_rep})[:, target_idx]
                preds = preds.reshape(n_chunk, n_samples)  # (chunk, B)

                # (B,) baseline broadcasts over the chunk axis; transpose so
                # samples are rows and elements are columns.
                scores[:, chunk] = (baseline.unsqueeze(0) - preds).T

        return scores

    @staticmethod
    def _nan_to_none(values):
        """Return ``values`` as a list with NaN entries replaced by None."""
        return [None if np.isnan(value) else value for value in values]

    def _package_scores(self, scores, name_columns, reduction):
        """Aggregate a (B, n) score tensor over samples and build the DataFrame.

        ``name_columns`` maps column name -> per-element labels and is placed
        before the 'score' column.
        """
        if reduction == 'none':
            # One frame per sample, stacked in sample order.
            per_sample = scores.detach().cpu().numpy()
            frames = [
                pd.DataFrame({
                    'sample_idx': sample_idx,
                    **name_columns,
                    'score': self._nan_to_none(row),
                })
                for sample_idx, row in enumerate(per_sample)
            ]
            return pd.concat(frames, ignore_index=True)

        if reduction == 'sum':
            aggregated = torch.nansum(scores, dim=0)
            # nansum treats NaN as zero, so a column that is NaN for every
            # sample (element not occluded) would read 0.0; restore NaN so it
            # is reported as missing, consistent with 'mean' and 'none'.
            aggregated[torch.isnan(scores).all(dim=0)] = float('nan')
        else:  # mean
            aggregated = torch.nanmean(scores, dim=0)

        return pd.DataFrame({
            **name_columns,
            'score': self._nan_to_none(aggregated.detach().cpu().numpy()),
        })