[20]:
from matplotlib import pyplot as plt
import numpy as np
import networkx as nx
import torch
import copy
import pandas as pd
from gsnn.models.GSNN import GSNN
from gsnn.simulate.nx2pyg import nx2pyg
from gsnn.simulate.simulate import simulate
from gsnn.models.GSNN import GSNN
from sklearn.metrics import r2_score
from scipy.stats import spearmanr
from statsmodels.stats.multitest import multipletests
from gsnn.optim.OutputEdgeInferer import OutputEdgeInferer
# for reproducibility
torch.manual_seed(0)
np.random.seed(0)
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
Inferring output edges
[21]:
# Create a simple directed graph with 3 inputs, 3 outputs, and 5 function nodes
G = nx.DiGraph()
# Add input nodes, function nodes, and output nodes
input_nodes = ['in0', 'in1', 'in2']
function_nodes = ['func0', 'func1', 'func2', 'func3', 'func4']
output_nodes = ['out0', 'out1', 'out2', 'out3']
# Add edges from input nodes to function nodes
G.add_edges_from([('in0', 'func0'), ('in1', 'func1'), ('in2', 'func2'), ('in2', 'func1')])
# Add edges between function nodes
G.add_edges_from([('func0', 'func3'), ('func1', 'func4'), ('func2', 'func3'), ('func0', 'func4')])
# Add edges from function nodes to output nodes
G.add_edges_from([('func3', 'out0'), ('func4', 'out1'), ('func3', 'out2'), ('func1', 'out3')])
# Define positions for each node for plotting
pos = {
'in0': (-2, 2), 'in1': (0, 2), 'in2': (2, 2),
'func0': (-2, 1), 'func1': (0, 1), 'func2': (2, 1),
'func3': (-1, 0), 'func4': (1, 0),
'out0': (-2, -1), 'out1': (0, -1), 'out2': (2, -1),
'out3': (3, -1)
}
# Plot the graph
plt.figure(figsize=(8, 6))
nx.draw_networkx(G, pos, with_labels=True, node_color='lightblue', node_size=500, font_size=10, arrowstyle='->', arrowsize=20)
plt.title("Dummy Graph: 3 Inputs, 3 Outputs, 5 Function Nodes")
plt.show()
[22]:
special_functions = None
x_train, x_test, y_train, y_test = simulate(G, n_train=100, n_test=100, input_nodes=input_nodes, output_nodes=output_nodes,
special_functions=special_functions, noise_scale=0.15)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
x_test = torch.tensor(x_test, dtype=torch.float32).to(device)
y_train = torch.tensor(y_train, dtype=torch.float32).to(device)
y_test = torch.tensor(y_test, dtype=torch.float32).to(device)
y_mu = y_train.mean(0); y_std = y_train.std(0)
y_train = (y_train - y_mu)/(y_std + 1e-8)
y_test = (y_test - y_mu)/(y_std + 1e-8)
[23]:
data = nx2pyg(G, input_nodes, function_nodes, output_nodes)
# simulate a missing edge (in2, func1)
edge_index_dict_TRUE = copy.deepcopy(data.edge_index_dict)
# remove edges func1 to out3
#remove = ('func1', 'out3')
remove = [('func4', 'out1'), ('func3', 'out2'), ('func1', 'out3')]
for edge in remove:
src, dst = data.edge_index_dict['function', 'to', 'output'].clone()
mask = ~((src == data.node_names_dict['function'].index(edge[0])) & (dst == data.node_names_dict['output'].index(edge[1])))
data.edge_index_dict['function', 'to', 'output'] = data.edge_index_dict['function', 'to', 'output'][:, mask]
data.edge_index_dict['function', 'to', 'output']
[23]:
tensor([[3],
[0]])
[24]:
model_kwargs = {'channels': 3,
'layers': 5,
'share_layers': False,
'bias': True,
'add_function_self_edges': True,
'norm': 'layer',
'dropout': 0.,
'nonlin': torch.nn.ELU}
[25]:
# train GSNN with a missing input edge
model = GSNN(data.edge_index_dict,
data.node_names_dict,
**model_kwargs).to(device)
print('n params', sum([p.numel() for p in model.parameters()]))
optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)
crit = torch.nn.MSELoss()
losses_gsnn = []
for i in range(1000):
optim.zero_grad()
yhat = model(x_train)
loss = crit(y_train, yhat)
loss.backward()
optim.step()
print(f'iter: {i} | loss: {loss.item():.3f}',end='\r')
model = model.eval()
with torch.inference_mode():
yhat_test = model(x_test)
loss_test = crit(y_test, yhat_test)
r2_test = r2_score(y_test.detach().cpu().numpy(), yhat_test.detach().cpu().numpy())
print(f'test loss: {loss_test.item():.3f} | test r2: {r2_test:.3f}')
n params 6320
test loss: 0.802 | test r2: 0.193
[26]:
OEI = OutputEdgeInferer(data,
model.channels*model.layers,
lr=1e-2,
wd=0,
epochs=2000,
use_batchnorm=False,
bn_affine=False,
tol=1e-4,
patience=5,
agg='all')
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=25, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=10, shuffle=True)
losses = OEI.fit(train_loader, model)
plt.figure(figsize=(3,3))
plt.plot(losses)
plt.show()
Fitting OutputEdgeInferer on cpu...
# parameters: 300
epoch 0 loss: 6.592256784439087592]
epoch 1 loss: 3.7812628746032715125]
epoch 2 loss: 2.2843069434165955716]
epoch 3 loss: 1.4881707131862644565]
epoch 4 loss: 1.1564314365386963158]
epoch 5 loss: 1.0034881234169006227]
epoch 6 loss: 0.9189485013484955614]
epoch 7 loss: 0.8438188582658768258]
epoch 8 loss: 0.7607485800981522481]
epoch 9 loss: 0.689739748835563798]]
epoch 10 loss: 0.634756892919540457]
epoch 11 loss: 0.6025913581252098326]
epoch 12 loss: 0.580315738916397137]
epoch 13 loss: 0.565174296498298624]
epoch 14 loss: 0.551893420517444689]
epoch 15 loss: 0.537574887275695849]]
epoch 16 loss: 0.528227806091308661]]
epoch 17 loss: 0.521082691848278463]
epoch 18 loss: 0.5133229866623878557]
epoch 19 loss: 0.507903024554252647]
epoch 20 loss: 0.5013037398457527794]
epoch 21 loss: 0.49570398777723316]]]
epoch 22 loss: 0.49268320947885513]5]
epoch 23 loss: 0.487835124135017447]]
epoch 24 loss: 0.483576752245426203]
epoch 25 loss: 0.479814417660236362]]
epoch 26 loss: 0.4743735417723655743]
epoch 27 loss: 0.471166893839836144]
epoch 28 loss: 0.467776261270046231]
epoch 29 loss: 0.4650599583983421384]
epoch 30 loss: 0.4616268500685692983]
epoch 31 loss: 0.458733379840850832]]
epoch 32 loss: 0.4561613127589226185]
epoch 33 loss: 0.453942283987998961]]
epoch 34 loss: 0.4513300210237503884]
epoch 35 loss: 0.449594140052795436]]
epoch 36 loss: 0.4482358321547508494]
epoch 37 loss: 0.447312779724597935]]
epoch 38 loss: 0.4441472738981247483]
epoch 39 loss: 0.442113302648067536]
epoch 40 loss: 0.4426162317395217634]
epoch 41 loss: 0.441355355083942407]]
epoch 42 loss: 0.4399544075131416367]
epoch 43 loss: 0.4389619007706642616]
epoch 44 loss: 0.436575308442115806]]
epoch 45 loss: 0.434292547404766172]
epoch 46 loss: 0.433519609272489812]]
epoch 47 loss: 0.4310928732156753566]
epoch 48 loss: 0.430597767233848575]]
epoch 49 loss: 0.430324546992778853]]
epoch 50 loss: 0.4305733814835548474]
epoch 51 loss: 0.429983556270599372]
epoch 52 loss: 0.4272886067628860595]
epoch 53 loss: 0.4282709211111069436]
epoch 54 loss: 0.425594791769981455]]
epoch 55 loss: 0.426970899105072069]]
epoch 56 loss: 0.424426831305027343]]
epoch 57 loss: 0.4222258999943733786]
epoch 58 loss: 0.422559231519699125]]
epoch 59 loss: 0.4218085557222366344]
epoch 60 loss: 0.419985666871070861]]
epoch 61 loss: 0.428239732980728159]
epoch 62 loss: 0.4216015115380287756]
epoch 63 loss: 0.420719370245933536]]
epoch 64 loss: 0.419962458312511444]
epoch 65 loss: 0.4176722019910812483]
epoch 66 loss: 0.41853245347738266]]]
epoch 67 loss: 0.4181315749883652545]
epoch 68 loss: 0.418737426400184634]
epoch 69 loss: 0.4175262674689293464]
epoch 70 loss: 0.416334822773933488]]
epoch 71 loss: 0.415902137756347663]]
epoch 72 loss: 0.4145658984780311604]
epoch 73 loss: 0.4143724292516708445]
epoch 74 loss: 0.415900424122810363]
epoch 75 loss: 0.413796275854110766]]
epoch 76 loss: 0.414242066442966469]]
epoch 77 loss: 0.414279259741306393]]
epoch 78 loss: 0.4166224300861358696]
epoch 79 loss: 0.4123078212141990706]
epoch 80 loss: 0.4137431904673576434]
epoch 81 loss: 0.413618378341197977]]
epoch 82 loss: 0.412375755608081871]]
epoch 83 loss: 0.412224799394607541]]
[batch 3/4 loss: 0.3387291729450226]]
[27]:
res = OEI.evaluate(test_loader, model.eval())
res.sort_values(by='q_value', ascending=True)
Evaluating OutputEdgeInferer on cpu...
[batch 9/10]
/home/teddy/local/GSNN/gsnn/optim/OutputEdgeInferer.py:358: RuntimeWarning: invalid value encountered in divide
model_cov_xy / np.sqrt(model_var_x * model_var_y),
[27]:
| func_node | output_node | mse | r2 | r | has_edge | p_value | model_r2 | model_r | model_mse | r2_gain | r_gain | mse_gain | q_value | within_output_rank | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | func1 | out3 | 0.185279 | 0.823894 | 0.919120 | False | 2.086947e-11 | 0.00000 | 0.00000 | 1.120165 | 0.823894 | 0.919120 | -0.934886 | 4.173895e-10 | 1 |
| 1 | func3 | out2 | 0.047971 | 0.948791 | 0.974521 | False | 2.940781e-09 | 0.00000 | 0.00000 | 0.958270 | 0.948791 | 0.974521 | -0.910299 | 2.940781e-08 | 1 |
| 2 | func4 | out2 | 0.121563 | 0.870232 | 0.933362 | False | 4.913948e-09 | 0.00000 | 0.00000 | 0.958270 | 0.870232 | 0.933362 | -0.836707 | 3.275965e-08 | 2 |
| 3 | func4 | out1 | 0.369538 | 0.628792 | 0.815660 | False | 3.834745e-08 | 0.00000 | 0.00000 | 1.084873 | 0.628792 | 0.815660 | -0.715335 | 1.917373e-07 | 1 |
| 4 | func1 | out1 | 0.448407 | 0.549566 | 0.753065 | False | 1.199566e-07 | 0.00000 | 0.00000 | 1.084873 | 0.549566 | 0.753065 | -0.636466 | 4.798264e-07 | 2 |
| 5 | func3 | out1 | 0.445066 | 0.552922 | 0.767172 | False | 2.191132e-06 | 0.00000 | 0.00000 | 1.084873 | 0.552922 | 0.767172 | -0.639807 | 7.303773e-06 | 3 |
| 6 | func0 | out2 | 0.425644 | 0.545625 | 0.738807 | False | 1.920007e-05 | 0.00000 | 0.00000 | 0.958270 | 0.545625 | 0.738807 | -0.532625 | 5.485734e-05 | 3 |
| 7 | func4 | out3 | 0.678931 | 0.354680 | 0.629966 | False | 3.609376e-05 | 0.00000 | 0.00000 | 1.120165 | 0.354680 | 0.629966 | -0.441234 | 9.023440e-05 | 2 |
| 8 | func2 | out2 | 0.484007 | 0.483323 | 0.701398 | False | 1.499416e-04 | 0.00000 | 0.00000 | 0.958270 | 0.483323 | 0.701398 | -0.474262 | 2.884546e-04 | 4 |
| 9 | func1 | out2 | 0.505211 | 0.460688 | 0.685677 | False | 1.389110e-04 | 0.00000 | 0.00000 | 0.958270 | 0.460688 | 0.685677 | -0.453059 | 2.884546e-04 | 5 |
| 10 | func0 | out1 | 0.687283 | 0.309610 | 0.609056 | False | 1.586500e-04 | 0.00000 | 0.00000 | 1.084873 | 0.309610 | 0.609056 | -0.397590 | 2.884546e-04 | 4 |
| 11 | func2 | out3 | 0.746067 | 0.290868 | 0.580942 | False | 6.815447e-04 | 0.00000 | 0.00000 | 1.120165 | 0.290868 | 0.580942 | -0.374098 | 1.135908e-03 | 3 |
| 12 | func3 | out3 | 0.706217 | 0.328745 | 0.616028 | False | 1.407202e-03 | 0.00000 | 0.00000 | 1.120165 | 0.328745 | 0.616028 | -0.413948 | 2.164925e-03 | 4 |
| 13 | func2 | out1 | 0.801581 | 0.194795 | 0.520465 | False | 6.887123e-03 | 0.00000 | 0.00000 | 1.084873 | 0.194795 | 0.520465 | -0.283292 | 9.838747e-03 | 5 |
| 14 | func3 | out0 | 0.046704 | 0.949335 | 0.975087 | True | 6.918867e-01 | 0.95071 | 0.97506 | 0.045436 | -0.001375 | 0.000027 | 0.001268 | 9.225157e-01 | 1 |
| 15 | func0 | out3 | 1.154114 | 0.000000 | -0.070084 | False | 7.984631e-01 | 0.00000 | 0.00000 | 1.120165 | 0.000000 | -0.070084 | 0.033949 | 9.980789e-01 | 5 |
| 16 | func4 | out0 | 0.101463 | 0.889932 | 0.944160 | False | 9.991888e-01 | 0.95071 | 0.97506 | 0.045436 | -0.060778 | -0.030900 | 0.056026 | 1.000000e+00 | 2 |
| 17 | func0 | out0 | 0.442974 | 0.519457 | 0.720759 | False | 1.000000e+00 | 0.95071 | 0.97506 | 0.045436 | -0.431254 | -0.254301 | 0.397538 | 1.000000e+00 | 3 |
| 18 | func2 | out0 | 0.495658 | 0.462304 | 0.684795 | False | 1.000000e+00 | 0.95071 | 0.97506 | 0.045436 | -0.488406 | -0.290265 | 0.450222 | 1.000000e+00 | 4 |
| 19 | func1 | out0 | 0.507531 | 0.449424 | 0.678707 | False | 1.000000e+00 | 0.95071 | 0.97506 | 0.045436 | -0.501286 | -0.296353 | 0.462095 | 1.000000e+00 | 5 |
[ ]:
res_missing = res[lambda x: (x.within_output_rank == 1) & (x.has_edge == False)]
res_missing_edges = [(i,j) for i,j in res_missing[['func_node', 'output_node']].values]
res_missing_edges # inferred edges
[('func1', 'out3'), ('func3', 'out2'), ('func4', 'out1')]
[ ]:
remove # true removed
[('func4', 'out1'), ('func3', 'out2'), ('func1', 'out3')]