# Source code for gsnn.gsnn.proc.construct

"""
construct.py - Build pruned constraint networks for the GSNN model
=================================================================

This helper stitches together several edge lists into a directed
heterogeneous graph, removes nodes that do **not** lie on any path from
an *input* to an *output*, and returns the result as a PyTorch-Geometric
`HeteroData` object ready for use in GSNN.

Accepted edge tables (all `pd.DataFrame` with columns `src`, `dst`)
------------------------------------------------------------------
1. **input_edges**    - critical inputs  →  function nodes
2. **mediator_edges** - (optional) mediator inputs → function nodes.
   Mediators are *only* retained if they target a function node that
   survives the pruning step.
3. **function_edges** - function  →  function (latent logic)
4. **output_edges**   - function  →  output

ASCII toy graph
---------------
```
input_A      mediator_M
   │             │
   ▼             │
 func_X ◀────────┘
   │
   ▼
 out_Y
```
`mediator_M` is kept **only** if `func_X` remains on a viable
`input_A → … → out_Y` path.

Quick usage example
-------------------
```python
from gsnn.proc.construct import GSNNNetworkConstructor

builder = GSNNNetworkConstructor(depth=10, verbose=True)
data = builder.build(input_edges, function_edges,
                     output_edges, mediator_edges)
print(data.graph_summary)
```

Returned attributes
-------------------
* `data.node_names_dict` - mapping of node types to names
* `data.edge_index_dict` - edge indices (PyTorch tensors)
* `data.graph_summary`   - pruning statistics & graph metrics
* `data.build_args`      - parameters used to build the graph
"""

import pandas as pd
import numpy as np
import networkx as nx
import torch
from torch_geometric.data import HeteroData
from .subset import subset_graph




class GSNNNetworkConstructor:
    """Build pruned GSNN constraint networks as PyTorch-Geometric ``HeteroData``.

    Args:
        depth: Depth limit forwarded to ``subset_graph`` when pruning the
            candidate network.
        verbose: If True, print progress messages and a construction summary.
    """

    # Maps the type tag of a networkx node label ("<tag>_<name>") to the
    # HeteroData node type. Mediators share the 'input' node type.
    _TAG_TO_TYPE = {'input': 'input', 'med': 'input', 'func': 'function', 'out': 'output'}

    # Edge types emitted in the returned ``edge_index_dict`` (in this order).
    _EDGE_TYPES = (
        ('input', 'to', 'function'),
        ('function', 'to', 'function'),
        ('function', 'to', 'output'),
    )

    def __init__(self, depth=10, verbose=True):
        self.depth = depth
        self.verbose = verbose

    @staticmethod
    def _name(value):
        """Return the canonical string form of a node name.

        Graph nodes are labelled with f-strings such as ``f"func_{value}"``, so
        every name read back from the graph equals ``f"{value}"``. Fixed name
        lists and the counting sets go through the same conversion so that,
        e.g., ``1`` and ``"1"`` denote one node instead of two.
        """
        return f"{value}"

    @classmethod
    def _fixed_names(cls, names):
        """Normalise an optional iterable of forced node names to a list of strings."""
        return [] if names is None else [cls._name(n) for n in names]

    @staticmethod
    def _edge_pairs(edges):
        """Return an iterator of ``(src, dst)`` pairs from an edge table.

        Zipping the two columns keeps each column's own dtype, whereas
        ``DataFrame.iterrows`` builds one Series per row and can upcast integer
        ids to floats (turning node ``1`` into ``"1.0"``) when the table also
        carries a float column.
        """
        return zip(edges['src'], edges['dst'])

    @staticmethod
    def _dedupe_preserve_order(sequence):
        """Return the unique items of *sequence* in first-seen order."""
        return list(dict.fromkeys(sequence))

    @staticmethod
    def _names_with_tag(graph, *prefixes):
        """Return node names (type tag stripped) whose label starts with any of *prefixes*, in graph order."""
        return [node.split('_', 1)[1] for node in graph.nodes() if node.startswith(prefixes)]

    def build(self, input_edges, function_edges, output_edges, mediator_edges=None,
              input_names=None, output_names=None, function_names=None,
              mediator_names=None):
        """Construct and filter the constraint network for the GSNN model.

        Behavior of fixed node name arguments:
        - input_names, mediator_names, function_names and output_names specify
          nodes that MUST be included in the final node lists, even if they have
          no incident edges.
        - Additional nodes of each type inferred from the provided edge tables
          are also included, so the final node sets are supersets of the fixed
          names.
        - Inputs and mediators are combined into a single 'input' type in the
          order input_names + mediator_names, followed by any additional
          discovered inputs/mediators.
        - All node names are returned as strings (``f"{value}"``), which is how
          names taken from the edge tables are represented; fixed names are
          converted the same way so they match edge-derived names.

        Args:
            input_edges: DataFrame with columns 'src', 'dst' for input→function edges.
            function_edges: DataFrame with columns 'src', 'dst' for function→function edges.
            output_edges: DataFrame with columns 'src', 'dst' for function→output edges.
            mediator_edges: Optional DataFrame with columns 'src', 'dst' for
                mediator→function edges. An edge is only retained if it targets
                a function node that survives pruning.
            input_names: Optional iterable of input node names to force-include.
            output_names: Optional iterable of output node names to force-include.
            function_names: Optional iterable of function node names to force-include.
            mediator_names: Optional iterable of mediator node names to
                force-include (combined with inputs).

        Returns:
            HeteroData with attributes ``node_names_dict``, ``edge_index_dict``
            (each value a ``(2, E)`` long tensor), ``graph_summary`` and
            ``build_args``. In ``graph_summary``, ``inputs_included`` is the size
            of the shared 'input' node type (inputs plus mediators).
        """
        name = self._name
        if self.verbose:
            print('building candidate network...')

        # 1. Candidate network. Names are collected as strings so the counting
        #    sets agree with the names later read back from the graph.
        graph = nx.DiGraph()
        inputs, mediators, functions, outputs = set(), set(), set(), set()
        for src, dst in self._edge_pairs(input_edges):
            graph.add_edge(f"input_{src}", f"func_{dst}")
            inputs.add(name(src))
            functions.add(name(dst))
        if mediator_edges is not None:
            # Mediators only feed the bookkeeping sets here; their edges are
            # attached after pruning, so they never influence which function
            # nodes survive.
            for src, dst in self._edge_pairs(mediator_edges):
                mediators.add(name(src))
                functions.add(name(dst))
        for src, dst in self._edge_pairs(function_edges):
            graph.add_edge(f"func_{src}", f"func_{dst}")
            functions.add(name(src))
            functions.add(name(dst))
        for src, dst in self._edge_pairs(output_edges):
            graph.add_edge(f"func_{src}", f"out_{dst}")
            functions.add(name(src))
            outputs.add(name(dst))

        fixed_inputs = self._fixed_names(input_names)
        fixed_mediators = self._fixed_names(mediator_names)
        fixed_functions = self._fixed_names(function_names)
        fixed_outputs = self._fixed_names(output_names)
        inputs.update(fixed_inputs)
        mediators.update(fixed_mediators)
        functions.update(fixed_functions)
        outputs.update(fixed_outputs)

        # 2. Keep only nodes on input -> output paths. Roots/leaves are taken
        #    from the graph itself, so only existing nodes are passed and their
        #    order (graph insertion order) is deterministic across runs.
        roots = [node for node in graph if node.startswith('input_')]
        leaves = [node for node in graph if node.startswith('out_')]
        pruned = subset_graph(graph, self.depth, roots, leaves, verbose=self.verbose)
        if self.verbose:
            print()

        # 3. Attach mediators, but only to function nodes that survived pruning.
        if mediator_edges is not None:
            if self.verbose:
                print('adding mediator edges...')
            for src, dst in self._edge_pairs(mediator_edges):
                func_node = f"func_{dst}"
                if func_node in pruned:
                    pruned.add_edge(f"med_{src}", func_node)

        # 4. Final node lists: forced names first (inputs, then mediators, for
        #    the shared 'input' type), followed by other surviving nodes.
        dedupe = self._dedupe_preserve_order
        node_names = {
            'input': dedupe(fixed_inputs + fixed_mediators
                            + self._names_with_tag(pruned, 'input_', 'med_')),
            'function': dedupe(fixed_functions + self._names_with_tag(pruned, 'func_')),
            'output': dedupe(fixed_outputs + self._names_with_tag(pruned, 'out_')),
        }
        node_to_idx = {ntype: {n: i for i, n in enumerate(names)}
                       for ntype, names in node_names.items()}

        # 5. Edge index lists. Each value is a dict used as an ordered set: a
        #    name that is both an input and a mediator of the same function maps
        #    both graph edges onto one (src, dst) pair, which is kept only once.
        edges = {edge_type: {} for edge_type in self._EDGE_TYPES}
        for src, dst in pruned.edges():
            src_tag, _, src_name = src.partition('_')
            dst_tag, _, dst_name = dst.partition('_')
            src_type = self._TAG_TO_TYPE.get(src_tag)
            dst_type = self._TAG_TO_TYPE.get(dst_tag)
            edge_type = (src_type, 'to', dst_type)
            if edge_type in edges:
                pair = (node_to_idx[src_type][src_name], node_to_idx[dst_type][dst_name])
                edges[edge_type][pair] = None

        edge_index = {
            edge_type: (torch.tensor(list(pairs), dtype=torch.long).T if pairs
                        else torch.empty((2, 0), dtype=torch.long))
            for edge_type, pairs in edges.items()
        }

        data = HeteroData()
        data.node_names_dict = node_names
        data.edge_index_dict = edge_index

        # 6. Summary statistics.
        n_nodes = sum(len(names) for names in node_names.values())
        n_edges = sum(len(pairs) for pairs in edges.values())
        density = n_edges / (n_nodes * (n_nodes - 1)) if n_nodes > 1 else 0
        avg_degree = 2 * n_edges / n_nodes if n_nodes > 0 else 0
        clustering = (nx.average_clustering(pruned.to_undirected())
                      if pruned.number_of_nodes() > 0 else 0)

        # The 'input' node type also holds mediators, so this is its full size.
        final_inputs = len(node_names['input'])
        kept_mediators = set(fixed_mediators).union(self._names_with_tag(pruned, 'med_'))
        final_mediators = len(kept_mediators)
        final_functions = len(node_names['function'])
        final_outputs = len(node_names['output'])

        # Isolated: inputs without outgoing edges, outputs without incoming ones.
        connected_inputs = {src_idx for src_idx, _ in edges[('input', 'to', 'function')]}
        connected_outputs = {dst_idx for _, dst_idx in edges[('function', 'to', 'output')]}
        isolated_inputs = final_inputs - len(connected_inputs)
        isolated_outputs = final_outputs - len(connected_outputs)

        # "Removed" counts names seen in the edge tables / fixed lists that are
        # absent from the final graph. Set differences are used because the
        # 'input' list mixes inputs and mediators, so count differences would
        # under-report removed inputs. An input that survives only as a mediator
        # is still present in the 'input' list and is therefore not "removed".
        summary = {
            'inputs_included': final_inputs,
            'inputs_removed': len(inputs - set(node_names['input'])),
            'mediators_included': final_mediators,
            'mediators_removed': len(mediators - kept_mediators),
            'functions_included': final_functions,
            'functions_removed': len(functions - set(node_names['function'])),
            'outputs_included': final_outputs,
            'outputs_removed': len(outputs - set(node_names['output'])),
            'isolated_inputs': isolated_inputs,
            'isolated_outputs': isolated_outputs,
            'total_nodes': n_nodes,
            'total_edges': n_edges,
            'density': density,
            'avg_degree': avg_degree,
            'avg_clustering': clustering,
        }

        data.graph_summary = summary
        data.build_args = {
            'depth': self.depth,
            'has_mediators': (mediator_edges is not None)
            or (mediator_names is not None and len(mediator_names) > 0),
        }

        if self.verbose:
            print(f"\n{'='*50}")
            print("Network Construction Summary:")
            print(f"{'='*50}")
            print(f"Inputs: {final_inputs} included, {summary['inputs_removed']} removed")
            print(f"Mediators: {final_mediators} included, {summary['mediators_removed']} removed")
            print(f"Functions: {final_functions} included, {summary['functions_removed']} removed")
            print(f"Outputs: {final_outputs} included, {summary['outputs_removed']} removed")
            print(f"Isolated nodes: {isolated_inputs} inputs, {isolated_outputs} outputs")
            print("\nGraph Statistics:")
            print(f"  Nodes: {n_nodes}")
            print(f"  Edges: {n_edges}")
            print(f"  Density: {density:.4f}")
            print(f"  Avg Degree: {avg_degree:.2f}")
            print(f"  Avg Clustering: {clustering:.4f}")
            print(f"{'='*50}\n")

        return data