{ "cells": [ { "cell_type": "markdown", "id": "91e5d15f", "metadata": {}, "source": [ "# Post-hoc edge inference via shared-embedding link prediction (node2vec)\n", "\n", "This notebook demonstrates **post-hoc, shared-embedding** function -> function edge inference using `MagnitudeEdgeKGE` (now a node2vec-style link predictor; the class name is kept for backward compatibility).\n", "\n", "Workflow:\n", "1. Train a GSNN on a partial graph.\n", "2. Run `MagnitudeEdgeInferer` to accumulate activation/gradient magnitude correlations.\n", "3. Threshold MEI scores into inferred positive edges.\n", "4. Pool inferred edges with kept-graph edges into one augmented directed graph.\n", "5. Learn a single shared node embedding table by skip-gram with negative sampling on random walks.\n", "6. Score held-out edges by the dot product of node embeddings.\n", "\n", "Same converging-tier DAG setup as notebooks 13-16 (**12 inputs -> 24 function nodes -> 12 outputs**, **16 held-out** edges).\n", "\n", "No `complex2` dependency." ] }, { "cell_type": "code", "execution_count": 75, "id": "2389a1df", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. 
To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [
  "from matplotlib import pyplot as plt\n",
  "import numpy as np\n",
  "import networkx as nx\n",
  "import torch\n",
  "\n",
  "from gsnn.models.GSNN import GSNN\n",
  "from gsnn.simulate.nx2pyg import nx2pyg\n",
  "from gsnn.simulate.simulate import simulate\n",
  "from gsnn.optim.MagnitudeEdgeInferer import MagnitudeEdgeInferer\n",
  "from gsnn.optim.MagnitudeEdgeKGE import MagnitudeEdgeKGE\n",
  "\n",
  "from sklearn.metrics import roc_auc_score, roc_curve\n",
  "\n",
  "torch.manual_seed(0)\n",
  "np.random.seed(0)\n",
  "\n",
  "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
  "\n",
  "%load_ext autoreload\n",
  "%autoreload 2"
 ] },
 { "cell_type": "markdown", "id": "2aeb5514", "metadata": {}, "source": [ "## Build ground-truth graph and simulate data" ] },
 { "cell_type": "code", "execution_count": 76, "id": "c0719ba9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "functions: 24 | held-out val/test: 8/8\n" ] } ], "source": [
  "def build_convergence_graph(n_tier_a=6):\n",
  "    \"\"\"Build the converging-tier ground-truth DAG.\n",
  "\n",
  "    With n = n_tier_a the graph has:\n",
  "      * n input nodes in{i}, each feeding one tier-A function f{i};\n",
  "      * n - 1 tier-B functions f{n + k}, each merging f{k} and f{k + 1};\n",
  "      * one tier-C sink f{2n - 1} fed by every tier-B function;\n",
  "      * n output nodes: o{k} reads tier-B function k, o{n - 1} reads the sink.\n",
  "\n",
  "    Returns:\n",
  "        (G, input_nodes, function_nodes, output_nodes, func_func_edges_TRUE),\n",
  "        where func_func_edges_TRUE lists every function -> function edge in\n",
  "        insertion order.\n",
  "    \"\"\"\n",
  "    n = n_tier_a\n",
  "    G = nx.DiGraph()\n",
  "    func_func_edges_TRUE = []\n",
  "    input_nodes = [f'in{i}' for i in range(n)]\n",
  "    tier_a = [f'f{i}' for i in range(n)]\n",
  "    tier_b = [f'f{n + k}' for k in range(n - 1)]\n",
  "    tier_c = [f'f{2 * n - 1}']\n",
  "    function_nodes = tier_a + tier_b + tier_c\n",
  "    output_nodes = [f'o{k}' for k in range(n)]\n",
  "\n",
  "    # inputs -> tier A (one-to-one)\n",
  "    for u, a in zip(input_nodes, tier_a):\n",
  "        G.add_edge(u, a)\n",
  "    # tier A -> tier B (each tier-B node merges two neighbouring tier-A nodes)\n",
  "    for k, b in enumerate(tier_b):\n",
  "        for parent in (tier_a[k], tier_a[k + 1]):\n",
  "            G.add_edge(parent, b)\n",
  "            func_func_edges_TRUE.append((parent, b))\n",
  "    # tier B -> sink\n",
  "    sink = tier_c[0]\n",
  "    for b in tier_b:\n",
  "        G.add_edge(b, sink)\n",
  "        func_func_edges_TRUE.append((b, sink))\n",
  "    # function nodes -> outputs\n",
  "    for k, b in enumerate(tier_b):\n",
  "        G.add_edge(b, output_nodes[k])\n",
  "    G.add_edge(sink, output_nodes[n - 1])\n",
  "    return G, input_nodes, function_nodes, output_nodes, func_func_edges_TRUE\n",
  "\n",
  "def default_held_out_edges(n_tier_a, b2sink_stride=2):\n",
  "    \"\"\"Choose the function -> function edges that are hidden from the GSNN.\n",
  "\n",
  "    With n = n_tier_a, holds out the edge f{k} -> f{n + k} for k = 1 .. n - 2,\n",
  "    plus every b2sink_stride-th tier-B -> sink edge (k = 0, stride, 2 * stride, ...).\n",
  "\n",
  "    Returns:\n",
  "        list of (src, dst) node-name tuples.\n",
  "    \"\"\"\n",
  "    n = n_tier_a\n",
  "    sink = f'f{2 * n - 1}'\n",
  "    held = [(f'f{k}', f'f{n + k}') for k in range(1, n - 1)]\n",
  "    held += [(f'f{n + k}', sink) for k in range(0, n - 1, b2sink_stride)]\n",
  "    return held\n",
  "\n",
  "N_TIER_A = 12\n",
  "G, input_nodes, function_nodes, output_nodes, func_func_edges_TRUE = build_convergence_graph(N_TIER_A)\n",
  "N_FUNC = len(function_nodes)\n",
  "HELD_OUT_EDGES = default_held_out_edges(N_TIER_A)\n",
  "\n",
  "x_train, x_test, y_train, y_test = simulate(\n",
  "    G, n_train=2000, n_test=500,\n",
  "    input_nodes=input_nodes, output_nodes=output_nodes,\n",
  "    noise_scale=0.15, special_functions=None,\n",
  ")\n",
  "x_train = torch.tensor(x_train, dtype=torch.float32).to(device)\n",
  "x_test = torch.tensor(x_test, dtype=torch.float32).to(device)\n",
  "y_train = torch.tensor(y_train, dtype=torch.float32).to(device)\n",
  "y_test = torch.tensor(y_test, dtype=torch.float32).to(device)\n",
  "# Standardise targets with statistics from the training split only.\n",
  "y_mu, y_std = y_train.mean(0), y_train.std(0)\n",
  "y_train = (y_train - y_mu) / (y_std + 1e-8)\n",
  "y_test = (y_test - y_mu) / (y_std + 1e-8)\n",
  "\n",
  "held_out_set = set(HELD_OUT_EDGES)\n",
  "# remove_edges_from() silently skips edges that are not in the graph, so verify first.\n",
  "assert held_out_set <= set(func_func_edges_TRUE), 'held-out edges must be true function -> function edges'\n",
  "G_partial = G.copy()\n",
  "G_partial.remove_edges_from(HELD_OUT_EDGES)\n",
  "data = nx2pyg(G_partial, input_nodes, function_nodes, output_nodes)\n",
  "kept_edges = [e for e in func_func_edges_TRUE if e not in held_out_set]\n",
  "kept_ff_set = set(kept_edges)\n",
  "\n",
  "# Split each held-out edge family in half so val and test both contain both kinds.\n",
  "sink = f'f{2 * N_TIER_A - 1}'\n",
  "left_merge = [e for e in HELD_OUT_EDGES if e[1] != sink]\n",
  "b2sink = [e for e in HELD_OUT_EDGES if e[1] == sink]\n",
  "rng = np.random.default_rng(0)\n",
  "rng.shuffle(left_merge)\n",
  "rng.shuffle(b2sink)\n",
  "edges_val = left_merge[:len(left_merge)//2] + b2sink[:len(b2sink)//2]\n",
  "edges_test = left_merge[len(left_merge)//2:] + b2sink[len(b2sink)//2:]\n",
  "held_out_benchmark = set(edges_val) | set(edges_test)\n",
  "assert set(edges_val).isdisjoint(edges_test), 'val and test held-out edges must not overlap'\n",
  "assert held_out_benchmark == held_out_set, 'val + test must cover every held-out edge'\n",
  "\n",
  "print(f'functions: {N_FUNC} | held-out val/test: {len(edges_val)}/{len(edges_test)}')"
 ] },
 { "cell_type": "markdown", "id": "78c12d13", "metadata": {}, "source": [ "## Train GSNN (no auxiliary edge inference)" ] },
 { "cell_type": "code", "execution_count": 77, "id": "ea5836de", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch  1 | test MSE 0.7987\n", "epoch 10 | test MSE 0.4809\n", "epoch 20 | test MSE 0.4842\n", "epoch 30 | test MSE 0.4804\n" ] } ], "source": [
  "BATCH_SIZE = 64\n",
  "\n",
  "model_kwargs = dict(\n",
  "    channels=8, layers=6, share_layers=False, bias=True,\n",
  "    add_function_self_edges=True, norm='groupbatch', dropout=0.,\n",
  "    nonlin=torch.nn.ELU, node_mlp=False, checkpoint=False,\n",
  ")\n",
  "\n",
  "model = GSNN(data.edge_index_dict, data.node_names_dict, **model_kwargs).to(device)\n",
  "\n",
  "# One dataset backs both loaders: train_loader shuffles, infer_loader keeps a fixed order.\n",
  "# Both drop the last partial batch, so each pass sees a whole number of BATCH_SIZE batches.\n",
  "train_dataset = torch.utils.data.TensorDataset(x_train, y_train)\n",
  "train_loader = torch.utils.data.DataLoader(\n",
  "    train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,\n",
  ")\n",
  "infer_loader = torch.utils.data.DataLoader(\n",
  "    train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True,\n",
  ")\n",
  "gsnn_optim = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0)\n",
  "crit = torch.nn.MSELoss()\n",
  "\n",
  "n_epochs = 30\n",
  "for epoch in range(n_epochs):\n",
  "    model.train()\n",
  "    for x_batch, y_batch in train_loader:\n",
  "        gsnn_optim.zero_grad()\n",
  "        loss = crit(model(x_batch), y_batch)\n",
  "        loss.backward()\n",
  "        gsnn_optim.step()\n",
  "    # Report test MSE after the first epoch, every 10th epoch, and the final epoch.\n",
  "    if epoch == 0 or (epoch + 1) % 10 == 0 or epoch == n_epochs - 1:\n",
  "        model.eval()\n",
  "        with torch.no_grad():\n",
  "            mse = crit(model(x_test), y_test).item()\n",
  "        print(f'epoch {epoch+1:2d} | test MSE {mse:.4f}')"
 ] },
 { "cell_type": "markdown", "id": "addf9545", "metadata": {}, "source": [ "## Fit MagnitudeEdgeInferer and inspect inferred positives" ] },
 { "cell_type": "code", "execution_count": 81, "id": "e74ba2c4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MEI samples: 1984\n", "\n", "Top 10 inferred edges by corr 
(non-kept):\n" ] }, { "data": { "text/html": [ "
| \n", " | src_func | \n", "dst_func | \n", "corr | \n", "p_value | \n", "q_value | \n", "
|---|---|---|---|---|---|
| 0 | \n", "f4 | \n", "f16 | \n", "0.889384 | \n", "0.0 | \n", "0.0 | \n", "
| 1 | \n", "f2 | \n", "f14 | \n", "0.881459 | \n", "0.0 | \n", "0.0 | \n", "
| 2 | \n", "f2 | \n", "f3 | \n", "0.878537 | \n", "0.0 | \n", "0.0 | \n", "
| 3 | \n", "f10 | \n", "f22 | \n", "0.876956 | \n", "0.0 | \n", "0.0 | \n", "
| 4 | \n", "f1 | \n", "f13 | \n", "0.866949 | \n", "0.0 | \n", "0.0 | \n", "
| 5 | \n", "f18 | \n", "f8 | \n", "0.863582 | \n", "0.0 | \n", "0.0 | \n", "
| 6 | \n", "f3 | \n", "f15 | \n", "0.863015 | \n", "0.0 | \n", "0.0 | \n", "
| 7 | \n", "f18 | \n", "f19 | \n", "0.862077 | \n", "0.0 | \n", "0.0 | \n", "
| 8 | \n", "f10 | \n", "f11 | \n", "0.861264 | \n", "0.0 | \n", "0.0 | \n", "
| 9 | \n", "f4 | \n", "f5 | \n", "0.857633 | \n", "0.0 | \n", "0.0 | \n", "