Source code for gsnn.simulate.nx2pyg

import torch 
from torch_geometric.data import HeteroData
import networkx as nx

def _first_index(nodes):
    """Map each node to the position of its first occurrence in ``nodes``.

    Mirrors ``list.index`` semantics (first match wins) while giving O(1)
    membership tests and lookups.
    """
    index = {}
    for position, node in enumerate(nodes):
        index.setdefault(node, position)
    return index


def nx2pyg(G, input_nodes, function_nodes, output_nodes, weight_attr=None):
    """Convert a NetworkX graph into a ``HeteroData`` object with three relations.

    Every edge ``(u, v)`` of ``G`` is assigned to the first relation it matches,
    checked in this order:

    1. ``('input', 'to', 'function')``    -- ``u`` in ``input_nodes``, ``v`` in ``function_nodes``
    2. ``('function', 'to', 'function')`` -- both endpoints in ``function_nodes``
    3. ``('function', 'to', 'output')``   -- ``u`` in ``function_nodes``, ``v`` in ``output_nodes``

    Edges that match none of the relations are ignored. Node indices stored in
    the edge indices are positions within the corresponding node list.

    Args:
        G: NetworkX graph whose edges are converted.
        input_nodes: Sequence of input node names.
        function_nodes: Sequence of function node names.
        output_nodes: Sequence of output node names.
        weight_attr: Optional edge-attribute name. When given, the value
            ``G[u][v][weight_attr]`` of every categorised edge is collected
            into ``data.edge_weight_dict``.

    Returns:
        A ``HeteroData`` object carrying:

        * ``edge_index_dict``: relation -> ``LongTensor`` of shape
          ``(2, num_edges)`` (``(2, 0)`` when the relation has no edges).
        * ``node_names_dict``: ``'input'`` / ``'function'`` / ``'output'`` ->
          the node lists passed in (same objects, not copies).
        * ``edge_weight_dict`` (only if ``weight_attr`` is not ``None``):
          relation -> ``float32`` tensor of edge weights, in the same order
          as the columns of the matching edge index.

    Raises:
        KeyError: If ``weight_attr`` is given and a categorised edge does not
            carry that attribute.
    """
    input_index = _first_index(input_nodes)
    function_index = _first_index(function_nodes)
    output_index = _first_index(output_nodes)

    input_to_function = ('input', 'to', 'function')
    function_to_function = ('function', 'to', 'function')
    function_to_output = ('function', 'to', 'output')
    relations = (input_to_function, function_to_function, function_to_output)

    edge_pairs = {relation: [] for relation in relations}
    edge_weights = {relation: [] for relation in relations}

    # Single pass over the edges: indices and weights are collected together
    # so both are guaranteed to be categorised identically and stay aligned.
    for u, v in G.edges:
        if u in input_index and v in function_index:
            relation = input_to_function
            pair = (input_index[u], function_index[v])
        elif u in function_index and v in function_index:
            relation = function_to_function
            pair = (function_index[u], function_index[v])
        elif u in function_index and v in output_index:
            relation = function_to_output
            pair = (function_index[u], output_index[v])
        else:
            continue

        edge_pairs[relation].append(pair)
        if weight_attr is not None:
            edge_weights[relation].append(G[u][v][weight_attr])

    data = HeteroData()

    if weight_attr is not None:
        data.edge_weight_dict = {
            relation: torch.tensor(edge_weights[relation], dtype=torch.float32)
            for relation in relations
        }

    # Reshape with an explicit row count so that an empty relation becomes a
    # well-formed (2, 0) edge index instead of a 1-D tensor of shape (0,).
    data.edge_index_dict = {
        relation: torch.tensor(edge_pairs[relation], dtype=torch.long)
        .reshape(len(edge_pairs[relation]), 2)
        .t()
        .contiguous()
        for relation in relations
    }

    data.node_names_dict = {
        'input': input_nodes,
        'function': function_nodes,
        'output': output_nodes,
    }

    return data
def pyg2nx(data):
    """Rebuild a directed NetworkX graph from a three-relation data object.

    Reads ``data.node_names_dict`` (keys ``'input'``, ``'function'``,
    ``'output'``) and ``data.edge_index_dict`` (keys
    ``('input', 'to', 'function')``, ``('function', 'to', 'function')``,
    ``('function', 'to', 'output')``) and adds one edge per edge-index column,
    translating positional indices back to node names.

    Only edges are added: nodes without any edge do not appear in the result,
    and edge weights are not restored.

    Args:
        data: Object exposing ``node_names_dict`` and ``edge_index_dict`` with
            the keys listed above; each edge index holds source indices in its
            first row and destination indices in its second row.

    Returns:
        A ``networkx.DiGraph`` whose nodes are the names of the edge endpoints.
    """
    names = data.node_names_dict
    relations = (
        ('input', 'function'),
        ('function', 'function'),
        ('function', 'output'),
    )

    G = nx.DiGraph()
    for src_type, dst_type in relations:
        src_names = names[src_type]
        dst_names = names[dst_type]
        edge_index = data.edge_index_dict[src_type, 'to', dst_type]

        # ``.t()`` (unlike the deprecated ``.T`` on non-2-D tensors) is safe for
        # 1-D empty edge indices; ``tolist()`` yields plain-int (src, dst) pairs.
        for src, dst in torch.as_tensor(edge_index).t().tolist():
            G.add_edge(src_names[src], dst_names[dst])

    return G