# Source code for gsnn.simulate.graph_comparison

"""
Graph comparison utilities for evaluating shared dependencies between graphs.
"""

from typing import Dict, Tuple, Set, Any, Union
from collections import defaultdict


class GraphComparison:
    """
    A class for comparing edge index dictionaries to evaluate shared dependencies
    between input and output nodes.

    The edge_index_dict is expected to have keys of the form:
    - (input, to, function)
    - (function, to, function)
    - (function, to, output)

    Each value is read as a pair ``[source_indices, target_indices]``. It may be a
    ``(2, num_edges)`` tensor/array, a nested list, or a tuple of two 1-D
    sequences. Values whose outer length is not 2 are skipped.

    A dependency is an ``(input_index, output_index)`` pair of strings such that
    the output node can be reached from the input node by following directed
    edges. The path may pass through any number of intermediate nodes, but never
    through another output node.
    """

    def __init__(self, reference_edge_index_dict: Dict[Tuple[str, str, str], Any]):
        """
        Initialize with a reference edge index dictionary.

        Args:
            reference_edge_index_dict: The baseline graph to compare against.
                Its dependencies are extracted once, at construction time.
        """
        self.reference_edge_index_dict = reference_edge_index_dict
        self.reference_dependencies = self._extract_dependencies(reference_edge_index_dict)

    def __call__(self, comparison_edge_index_dict: Dict[Tuple[str, str, str], Any]) -> Dict[str, Union[int, float, Set[Tuple[str, str]]]]:
        """
        Compare a new edge index dictionary against the reference.

        Args:
            comparison_edge_index_dict: The graph to compare against the reference.

        Returns:
            Dictionary with confusion counts (TP, FP, FN, TN), totals, derived
            ratios (precision, recall, specificity, accuracy) and the set of
            shared dependencies. A ratio whose denominator is 0 is reported as 0.0.
        """
        reference_dependencies = self.reference_dependencies
        comparison_dependencies = self._extract_dependencies(comparison_edge_index_dict)

        # Every (input, output) pair drawn from the endpoints seen in either graph
        # is a possible dependency. Only its count is needed, so the cartesian
        # product is never materialised (it is O(|inputs| * |outputs|) memory).
        all_inputs, all_outputs = self._collect_endpoints(
            self.reference_edge_index_dict, comparison_edge_index_dict
        )
        total_possible = len(all_inputs) * len(all_outputs)

        shared_dependencies = reference_dependencies & comparison_dependencies
        true_positives = len(shared_dependencies)
        false_positives = len(comparison_dependencies - reference_dependencies)
        false_negatives = len(reference_dependencies - comparison_dependencies)

        # True negatives: possible pairs that are a dependency in neither graph,
        # i.e. |possible pairs| - |possible pairs ∩ (ref ∪ comparison)|.
        all_dependencies = reference_dependencies | comparison_dependencies
        possible_and_dependent = sum(
            1
            for input_idx, output_idx in all_dependencies
            if input_idx in all_inputs and output_idx in all_outputs
        )
        true_negatives = total_possible - possible_and_dependent

        total_reference = len(reference_dependencies)
        total_comparison = len(comparison_dependencies)
        total_negatives = true_negatives + false_positives

        return {
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives,
            'true_negatives': true_negatives,
            'total_reference_dependencies': total_reference,
            'total_comparison_dependencies': total_comparison,
            'total_possible_pairs': total_possible,
            'precision': true_positives / total_comparison if total_comparison > 0 else 0.0,
            'recall': true_positives / total_reference if total_reference > 0 else 0.0,
            'specificity': true_negatives / total_negatives if total_negatives > 0 else 0.0,
            'accuracy': (true_positives + true_negatives) / total_possible if total_possible > 0 else 0.0,
            'shared_dependencies': shared_dependencies,
        }

    @staticmethod
    def _edge_rows(edges: Any) -> Union[Tuple[Any, Any], None]:
        """
        Normalise one edge-index value into ``(source_indices, target_indices)``.

        Array-likes are converted with ``.tolist()``. This works for numpy arrays
        and for torch tensors, including CUDA tensors and tensors that require
        grad, where ``.numpy()`` raises. Objects exposing only ``.numpy()`` are
        converted through it first.

        Returns:
            The two index rows, or ``None`` when the value does not have exactly
            two rows (such values are ignored by the callers).
        """
        if not hasattr(edges, 'tolist') and hasattr(edges, 'numpy'):
            edges = edges.numpy()
        if hasattr(edges, 'tolist'):
            edges = edges.tolist()
        if len(edges) != 2:
            return None
        source_indices, target_indices = edges
        # Rows may themselves be arrays/tensors (e.g. a tuple of two 1-D arrays).
        if hasattr(source_indices, 'tolist'):
            source_indices = source_indices.tolist()
        if hasattr(target_indices, 'tolist'):
            target_indices = target_indices.tolist()
        return source_indices, target_indices

    def _collect_endpoints(self, *edge_index_dicts: Dict[Tuple[str, str, str], Any]) -> Tuple[Set[str], Set[str]]:
        """
        Collect the string indices of all input and output nodes.

        Input nodes are the sources of edges whose source type is ``'input'``.
        Output nodes are the targets of edges whose target type is ``'output'``.

        Args:
            *edge_index_dicts: Any number of edge index dictionaries.

        Returns:
            ``(input_indices, output_indices)`` gathered across all dictionaries.
        """
        all_inputs = set()
        all_outputs = set()
        for edge_dict in edge_index_dicts:
            for (source_type, _relation, target_type), edges in edge_dict.items():
                rows = self._edge_rows(edges)
                if rows is None:
                    continue
                source_indices, target_indices = rows
                if source_type == 'input':
                    all_inputs.update(str(idx) for idx in source_indices)
                if target_type == 'output':
                    all_outputs.update(str(idx) for idx in target_indices)
        return all_inputs, all_outputs

    def _extract_dependencies(self, edge_index_dict: Dict[Tuple[str, str, str], Any]) -> Set[Tuple[str, str]]:
        """
        Extract all input-to-output dependencies from the edge index dictionary.

        The method builds a directed adjacency map over type-prefixed node ids. It
        then searches from every input node for the output nodes it can reach.

        Args:
            edge_index_dict: Dictionary with edge information.

        Returns:
            Set of (input_index, output_index) string tuples representing
            dependencies.
        """
        graph = defaultdict(set)
        input_nodes = set()
        output_nodes = set()

        for (source_type, _relation, target_type), edges in edge_index_dict.items():
            rows = self._edge_rows(edges)
            if rows is None:
                continue
            source_indices, target_indices = rows
            for src, tgt in zip(source_indices, target_indices):
                # Prefix each index with its node type so that e.g. input 3 and
                # function 3 are distinct nodes in the graph.
                src_node = f"{source_type}_{src}"
                tgt_node = f"{target_type}_{tgt}"
                graph[src_node].add(tgt_node)
                if source_type == 'input':
                    input_nodes.add(src_node)
                if target_type == 'output':
                    output_nodes.add(tgt_node)

        dependencies = set()
        for input_node in input_nodes:
            # Strip the "<type>_" prefix to recover the bare index string.
            input_idx = input_node.split('_', 1)[1]
            for output_node in self._find_reachable_outputs(graph, input_node, output_nodes):
                dependencies.add((input_idx, output_node.split('_', 1)[1]))
        return dependencies

    def _find_reachable_outputs(self, graph: Dict[str, Set[str]], start_node: str, output_nodes: Set[str]) -> Set[str]:
        """
        Find all output nodes reachable from a given start node.

        This is an iterative depth-first search with an explicit stack, so very
        long function chains do not hit Python's recursion limit. Output nodes are
        recorded but not expanded further. Cycles are handled via a visited set.

        Args:
            graph: Adjacency list representation of the graph.
            start_node: Starting node for the search.
            output_nodes: Set of output nodes to look for.

        Returns:
            Set of reachable output nodes.
        """
        visited = {start_node}
        stack = [start_node]
        reachable_outputs = set()
        while stack:
            node = stack.pop()
            if node in output_nodes:
                reachable_outputs.add(node)
                continue
            for neighbor in graph.get(node, ()):
                if neighbor not in visited:
                    visited.add(neighbor)
                    stack.append(neighbor)
        return reachable_outputs

    def _get_all_possible_pairs(self, reference_edge_index_dict: Dict[Tuple[str, str, str], Any],
                                comparison_edge_index_dict: Dict[Tuple[str, str, str], Any]) -> Set[Tuple[str, str]]:
        """
        Get all possible input-output pairs from both graphs.

        Args:
            reference_edge_index_dict: Reference graph.
            comparison_edge_index_dict: Comparison graph.

        Returns:
            Set of all possible (input, output) index pairs, i.e. the cartesian
            product of the input and output indices seen in either graph.
        """
        all_inputs, all_outputs = self._collect_endpoints(
            reference_edge_index_dict, comparison_edge_index_dict
        )
        return {(input_idx, output_idx) for input_idx in all_inputs for output_idx in all_outputs}

    def get_dependency_details(self, comparison_edge_index_dict: Dict[Tuple[str, str, str], Any]) -> Dict[str, Set]:
        """
        Get detailed information about the dependencies comparison.

        Args:
            comparison_edge_index_dict: The graph to compare against the reference.

        Returns:
            Dictionary with detailed dependency sets. All sets are fresh copies,
            so mutating them does not alter this comparator's reference state.
        """
        comparison_dependencies = self._extract_dependencies(comparison_edge_index_dict)
        reference_dependencies = self.reference_dependencies
        return {
            'reference_dependencies': set(reference_dependencies),
            'comparison_dependencies': comparison_dependencies,
            'shared_dependencies': reference_dependencies & comparison_dependencies,
            'missing_dependencies': reference_dependencies - comparison_dependencies,
            'extra_dependencies': comparison_dependencies - reference_dependencies,
        }