import torch
import torch_geometric as pyg
from sklearn.metrics import r2_score
import numpy as np
import copy
from torch.utils.data import DataLoader
import gc
from gsnn.models.GSNN import GSNN
from gsnn.optim.EarlyStopper import EarlyStopper
from gsnn.models import utils
#import torch._dynamo
#torch._dynamo.config.suppress_errors = True
#torch._dynamo.config.cache_size_limit = 100000
def ema(values, alpha=2/(3+1)):
    """
    Calculate the Exponential Moving Average (EMA) of a sequence.

    The EMA is seeded with the first element and then updated left-to-right
    with ``val = alpha * values[t] + (1 - alpha) * val``.

    Args:
        values (list or array-like): Indexable, sized sequence of numerical
            values for which the EMA is to be calculated.
        alpha (float): The smoothing factor, a value between 0 and 1.
            Defaults to ``2 / (3 + 1) = 0.5``.

    Returns:
        The EMA value after the last element, or ``None`` if ``values`` is
        ``None`` or empty.
    """
    # An explicit length check is required here: the truth value of a
    # multi-element numpy array / torch tensor is ambiguous and raises.
    if values is None or len(values) == 0:
        return None

    # Initialize EMA with the first value
    val = values[0]

    # Fold in the remaining values up to the current epoch (last entry)
    for t in range(1, len(values)):
        val = alpha * values[t] + (1 - alpha) * val

    return val
class Environment:
    """Train and score a model on a graph whose edges are selected by an action vector.

    An instance stores the full ``edge_index_dict`` plus a mapping from every
    optimisable edge to a position in a global binary action vector.
    :meth:`run` masks the edges according to an action, trains a fresh model
    on the resulting graph and returns a validation-based reward.
    """

    def __init__(self, action_edge_dict, train_dataset, val_dataset, model_kwargs,
                 training_kwargs, metric='spearman', reward_type='auc', verbose=True,
                 raise_error_on_fail=False, model=GSNN):
        """
        Environment for training and evaluating graph neural networks with different edge configurations.

        This environment handles the training and evaluation of graph neural networks with different edge
        configurations. It is used in conjunction with the REINFORCE algorithm to optimize graph structure.

        Args:
            action_edge_dict (dict): Dictionary that maps each edge type (key) to **a 1-D tensor/list with
                one entry per edge**. The entries need **not** be unique across different keys – a fresh,
                global index is allocated internally so that each edge receives its own independent
                Bernoulli action. Use the special value ``-1`` for edges that should always be kept (these
                share a single, constant action that is fixed to 1). The total number of learnable
                actions is inferred automatically.
            train_dataset (Dataset): Training dataset
            val_dataset (Dataset): Validation dataset
            model_kwargs (dict): Model configuration parameters. Must contain ``'edge_index_dict'``
                (with a ``('function', 'to', 'function')`` entry) and ``'node_names_dict'`` (with a
                ``'function'`` entry). The dict is shallow-copied; the caller's dict is not modified.
            training_kwargs (dict): Training configuration parameters. The keys ``'batch'``,
                ``'workers'``, ``'lr'`` and ``'max_epochs'`` are read during training; ``'clip_grad'``
                is optional.
            metric (str, optional): Metric for model evaluation. One of ['spearman', 'mse', 'pearson', 'r2']. Default: 'spearman'
            reward_type (str, optional): Type of reward signal. One of ['auc', 'best', 'last']. Default: 'auc'
            verbose (bool, optional): Whether to print training progress. Default: True
            raise_error_on_fail (bool, optional): Whether to raise errors on training failures. Default: False
            model (class, optional): Model class to use. Default: gsnn.models.GSNN

        Example:
            >>> action_edge_dict = {('input', 'to', 'function'): torch.arange(n_edges)}
            >>> env = Environment(action_edge_dict, train_dataset, val_dataset,
            ...                   model_kwargs={'edge_index_dict': edge_index_dict,
            ...                                 'node_names_dict': node_names_dict,
            ...                                 'channels': 64},
            ...                   training_kwargs={'lr': 0.01, 'batch': 32,
            ...                                    'workers': 0, 'max_epochs': 10})
            >>> action = torch.ones(n_edges)
            >>> reward = env.run(action)
        """
        self.action_edge_dict = action_edge_dict
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.training_kwargs = training_kwargs
        self.metric = metric
        self.reward_type = reward_type
        self.verbose = verbose
        self.raise_error_on_fail = raise_error_on_fail
        self._model = model

        # The edge index is passed to the model separately for every run, so
        # it is split off from the remaining kwargs. A copy is taken so that
        # the caller's dict stays intact and can be reused.
        self.edge_index_dict = model_kwargs['edge_index_dict']
        self.model_kwargs = {
            name: value
            for name, value in model_kwargs.items()
            if name != 'edge_index_dict'
        }

        self.N_func = len(model_kwargs['node_names_dict']['function'])
        self.E_func = self.edge_index_dict['function', 'to', 'function'].size(1)

        # ------------------------------------------------------------------
        # Build a *unique* global index for every optimisable edge.
        # Each key in ``action_edge_dict`` lists one entry per edge. Entries
        # need not be unique across keys – a fresh global index is assigned
        # sequentially so that every edge has its own independent Bernoulli
        # action. Entries equal to ``-1`` mark *fixed* (always-on) edges;
        # they all share a single global index that is forced to 1.
        # ------------------------------------------------------------------
        self.key_action_index = {}
        offset = 0
        for key, idxs in self.action_edge_dict.items():
            # Make sure we can iterate regardless of the container type
            entries = idxs.tolist() if torch.is_tensor(idxs) else list(idxs)
            mapped = []
            for entry in entries:
                if entry == -1:
                    # placeholder, replaced by the fixed index below
                    mapped.append(-1)
                else:
                    mapped.append(offset)
                    offset += 1
            self.key_action_index[key] = torch.tensor(mapped, dtype=torch.long)

        # Index of the constant "always-on" action that backs all -1 entries.
        # It sits one past the last learnable action.
        self.fixed_action_index = offset
        for index_tensor in self.key_action_index.values():
            index_tensor[index_tensor == -1] = self.fixed_action_index

        # Number of learnable actions (excludes the always-on sentinel)
        self.n_actions = offset

    def augment_edge_index(self, action):
        """Augment ``edge_index_dict`` according to the sampled *action*.

        Parameters
        ----------
        action : torch.Tensor, shape (n_actions,)
            Binary vector sampled from the policy that decides which edges to
            keep (1) or drop (0).

        Returns
        -------
        dict
            A new dict with the same keys as ``self.edge_index_dict``. Edge
            types listed in ``action_edge_dict`` have their columns filtered
            by the action; all other edge types are passed through unchanged.
        """
        action = action.squeeze()
        if action.dim() == 0:
            # a single learnable action squeezes down to a scalar
            action = action.unsqueeze(0)

        # Append the constant *always-on* action that backs all fixed edges.
        always_on = torch.tensor([1.0], dtype=action.dtype, device=action.device)
        action = torch.cat((action, always_on), dim=0)
        action_bool = action == 1  # convert to boolean mask once

        edge_index_dict_ = {}
        for key, edge_index in self.edge_index_dict.items():
            if key in self.key_action_index:
                # Map the global action vector onto this edge type
                keep = action_bool[self.key_action_index[key]]
                edge_index_dict_[key] = edge_index[:, keep]
            else:
                edge_index_dict_[key] = edge_index
        return edge_index_dict_

    def train(self, edge_index_dict):
        """Train a fresh model on ``edge_index_dict`` and return the reward.

        Parameters
        ----------
        edge_index_dict : dict
            Edge index per edge type, as produced by :meth:`augment_edge_index`.

        Returns
        -------
        The reward selected by ``self.reward_type``: the best validation score
        (``'best'``), the last validation score (``'last'``) or the sum of the
        per-epoch validation scores (``'auc'``). Returns ``-1`` if the training
        loss becomes NaN.

        Raises
        ------
        ValueError
            If ``self.metric`` is not a recognised metric.
        NotImplementedError
            If ``self.reward_type`` is not a recognised reward type.
        """
        # cuda mem accumulation problem
        gc.collect()
        torch.cuda.empty_cache()

        device = 'cpu'  # 'cuda' if torch.cuda.is_available() else 'cpu'
        cfg = self.training_kwargs
        workers = cfg['workers']
        loader_kwargs = dict(
            batch_size=cfg['batch'],
            num_workers=workers,
            # DataLoader rejects persistent workers when there are no workers
            persistent_workers=workers > 0,
        )
        train_loader = DataLoader(self.train_dataset, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(self.val_dataset, shuffle=False, **loader_kwargs)

        model = self._model(edge_index_dict=edge_index_dict, **self.model_kwargs).to(device)
        optim = torch.optim.Adam(model.parameters(), lr=cfg['lr'])
        crit = torch.nn.MSELoss()

        best_mean_val = -np.inf
        best_val_score = None
        scores = []
        loss = None  # stays None if every batch is skipped

        for epoch in range(cfg['max_epochs']):
            model.train()
            for i, (x, y, *_) in enumerate(train_loader):
                if x.size(0) == 1:
                    continue  # BUG workaround: if batch only has 1 obs it fails

                optim.zero_grad()
                x = x.to(device)
                y = y.to(device)
                yhat = model(x)
                loss = crit(yhat, y)

                if torch.isnan(loss):
                    # diverged run: release everything and signal failure
                    del model, optim, crit, train_loader, val_loader
                    return -1

                loss.backward()
                if cfg.get('clip_grad'):
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optim.step()

                if self.verbose:
                    print(f'[batch: {i+1}/{len(train_loader)}]', end='\r')

            # validation perf
            # NOTE(review): y / yhat are presumably numpy arrays of shape
            # (n_obs, n_outputs) – confirm against utils.predict_gsnn.
            y, yhat, _ = utils.predict_gsnn(val_loader, model, device=device, verbose=False)

            # since we have a multioutput prediction problem, we need to return multioutput performances
            if self.metric == 'mse':
                # negated so that larger is better, like the other metrics
                val_score = -np.mean((y - yhat) ** 2, axis=0)
            elif self.metric == 'pearson':
                val_score = utils.corr_score(y, yhat, method='pearson', multioutput='uniform_weighted')
            elif self.metric == 'spearman':
                val_score = utils.corr_score(y, yhat, method='spearman', multioutput='uniform_weighted')
            elif self.metric == 'r2':
                val_score = np.clip(utils.corr_score(y, yhat, method='r2', multioutput='uniform_weighted'), -1, 1)
            else:
                raise ValueError('unrecognized `metric` type')

            scores.append(val_score)
            mean_val = val_score.mean()

            if self.verbose:
                train_loss = loss.item() if loss is not None else float('nan')
                print(f'\t\trun progress: {epoch}/{cfg["max_epochs"]} | train loss: {train_loss:.1f} || mean val perf: {mean_val:.3f}', end='\r')

            # Track the best epoch; the running best must be updated, otherwise
            # every epoch beats -inf and "best" degenerates to "last".
            if mean_val > best_mean_val:
                best_mean_val = mean_val
                best_val_score = val_score

        # trying to find cuda mem accumulation
        del model, optim, crit, train_loader, val_loader

        if self.reward_type == 'best':
            reward = best_val_score
        elif self.reward_type == 'last':
            reward = val_score
        elif self.reward_type == 'auc':
            # should reward longer training runs as well as high perf
            reward = np.sum(np.stack(scores, axis=0), axis=0)
        else:
            raise NotImplementedError('unrecognized reward type')

        return reward

    def validate(self, edge_index_dict):
        """Check that the candidate graph keeps at least one function->output edge.

        Returns ``True`` if ``edge_index_dict['function', 'to', 'output']``
        has at least one edge (column) and ``False`` otherwise. Graphs that
        fail this check are not trained; :meth:`run` returns ``-1`` for them.
        """
        return edge_index_dict['function', 'to', 'output'].size(1) > 0

    def run(self, action):
        """Evaluate one action: mask the edges, train a model, return the reward.

        Parameters
        ----------
        action : torch.Tensor
            Binary action vector; it is moved to the CPU before use.

        Returns
        -------
        The reward from :meth:`train`, or ``-1`` if the graph fails
        :meth:`validate` or if training raises while
        ``raise_error_on_fail`` is ``False``.
        """
        # augment edge_index_dict appropriately
        edge_index_dict = self.augment_edge_index(action=action.cpu())

        if not self.validate(edge_index_dict):
            return -1

        # train model
        try:
            return self.train(edge_index_dict)
        except Exception:
            # failed trials will result in low reward; e.g., nan divergences.
            # Only ``Exception`` is caught so that KeyboardInterrupt/SystemExit
            # still stop the surrounding optimisation loop.
            if self.raise_error_on_fail:
                raise
            return -1