Source code for gsnn.optim.Environment




import torch 
import torch_geometric as pyg
from sklearn.metrics import r2_score 
import numpy as np
import copy 
from torch.utils.data import DataLoader
import gc 

from gsnn.models.GSNN import GSNN
from gsnn.optim.EarlyStopper import EarlyStopper
from gsnn.models import utils

#import torch._dynamo
#torch._dynamo.config.suppress_errors = True
#torch._dynamo.config.cache_size_limit = 100000


def ema(values, alpha=2/(3+1)):
    """Return the exponential moving average (EMA) of ``values``.

    The average is seeded with the first element and then updated, element by
    element, as ``ema = alpha * value + (1 - alpha) * ema``.

    Args:
        values (iterable): Numerical values in chronological order. Any
            iterable is accepted (list, tuple, NumPy array, tensor,
            generator, ...).
        alpha (float): The smoothing factor, a value between 0 and 1.
            Defaults to ``2 / (3 + 1) = 0.5``.

    Returns:
        The EMA after the last element, or ``None`` when ``values`` is
        ``None`` or contains no elements.
    """
    if values is None:
        return None

    # Iterate instead of testing truthiness: ``not values`` is ambiguous (and
    # raises) for NumPy arrays / tensors holding more than one element.
    iterator = iter(values)
    try:
        val = next(iterator)  # seed the EMA with the first value
    except StopIteration:
        return None

    for value in iterator:
        val = alpha * value + (1 - alpha) * val
    return val
class Environment():
    """Train-and-score environment for evaluating candidate edge configurations."""

    def __init__(self, action_edge_dict, train_dataset, val_dataset, model_kwargs, training_kwargs,
                 metric='spearman', reward_type='auc', verbose=True, raise_error_on_fail=False, model=GSNN):
        """
        Environment for training and evaluating graph neural networks with different edge configurations.

        This environment handles the training and evaluation of graph neural networks
        with different edge configurations. It is used in conjunction with the REINFORCE
        algorithm to optimize graph structure.

        Args:
            action_edge_dict (dict): Dictionary that maps each edge type (key) to
                **a 1-D tensor/list with one entry per edge**. The entries need
                **not** be unique across different keys - a fresh, global index is
                allocated internally so that each edge receives its own independent
                Bernoulli action. Use the special value ``-1`` for edges that should
                always be kept (these share a single, constant action that is fixed
                to 1). The total number of learnable actions is inferred
                automatically.
            train_dataset (Dataset): Training dataset
            val_dataset (Dataset): Validation dataset
            model_kwargs (dict): Model configuration parameters. Must contain
                ``'edge_index_dict'`` (with a ``('function', 'to', 'function')``
                entry) and ``'node_names_dict'`` (with a ``'function'`` entry).
                NOTE: ``'edge_index_dict'`` is removed from this dict *in place*;
                the remaining entries are forwarded to the model constructor.
            training_kwargs (dict): Training configuration parameters. The keys
                ``'batch'``, ``'workers'``, ``'lr'`` and ``'max_epochs'`` are
                required; ``'clip_grad'`` is optional.
            metric (str, optional): Metric for model evaluation. One of
                ['spearman', 'mse', 'pearson', 'r2']. Default: 'spearman'
            reward_type (str, optional): Type of reward signal. One of
                ['auc', 'best', 'last']. Default: 'auc'
            verbose (bool, optional): Whether to print training progress. Default: True
            raise_error_on_fail (bool, optional): Whether to raise errors on training
                failures. Default: False
            model (class, optional): Model class to use. Default: gsnn.models.GSNN

        Example:
            >>> action_edge_dict = {('input', 'to', 'function'): torch.arange(n_edges)}
            >>> env = Environment(action_edge_dict, train_dataset, val_dataset,
            ...                   model_kwargs={'edge_index_dict': edge_index_dict,
            ...                                 'node_names_dict': node_names_dict,
            ...                                 'channels': 64},
            ...                   training_kwargs={'lr': 0.01, 'batch': 32,
            ...                                    'workers': 0, 'max_epochs': 10})
            >>> action = torch.ones(env.n_actions)
            >>> reward = env.run(action)
        """
        self.action_edge_dict = action_edge_dict
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.model_kwargs = model_kwargs
        self.training_kwargs = training_kwargs
        self.metric = metric
        self.reward_type = reward_type
        self.verbose = verbose
        self.raise_error_on_fail = raise_error_on_fail
        self._model = model

        # The full edge set is kept on the environment; each trial passes its own
        # (masked) ``edge_index_dict`` to the model, so the key is dropped from the
        # kwargs that are forwarded to the constructor.
        self.edge_index_dict = model_kwargs['edge_index_dict']
        self.model_kwargs.pop('edge_index_dict', None)

        self.N_func = len(model_kwargs['node_names_dict']['function'])
        self.E_func = self.edge_index_dict['function', 'to', 'function'].size(1)

        # ------------------------------------------------------------------
        # Build a *unique* global index for every optimisable edge.
        # Each key in ``action_edge_dict`` can list indices for its edges. These
        # indices do not need to be unique across keys - a fresh global index is
        # assigned sequentially so that every edge has its own independent
        # Bernoulli action. Entries equal to ``-1`` are interpreted as *fixed*
        # (always-on) edges and share a single global index that is forced to 1.
        # ------------------------------------------------------------------
        self.key_action_index = {}
        next_index = 0
        for key, idxs in self.action_edge_dict.items():
            # Make sure we can iterate regardless of the container type
            idx_list = idxs.tolist() if torch.is_tensor(idxs) else list(idxs)
            mapped = []
            for idx in idx_list:
                if idx == -1:
                    # placeholder, replaced by the fixed index below
                    mapped.append(-1)
                else:
                    mapped.append(next_index)
                    next_index += 1
            self.key_action_index[key] = torch.tensor(mapped, dtype=torch.long)

        # Index of the constant "always-on" action that backs all -1 entries
        self.fixed_action_index = next_index

        # Replace the -1 placeholders by the fixed index (in-place on each tensor)
        for mapped in self.key_action_index.values():
            mapped[mapped == -1] = self.fixed_action_index

        # Number of learnable actions (excludes the always-on sentinel)
        self.n_actions = next_index

    def augment_edge_index(self, action):
        """Augment ``edge_index_dict`` according to the sampled *action*.

        Parameters
        ----------
        action : torch.Tensor, shape (n_actions,)
            Binary vector sampled from the policy that decides which edges to
            keep (1) or drop (0). Singleton dimensions are squeezed away.

        Returns
        -------
        dict
            A new dict with the same keys as ``self.edge_index_dict``. Edge types
            listed in ``action_edge_dict`` keep only the columns whose action
            equals 1; all other edge types are passed through unchanged.

        Raises
        ------
        ValueError
            If *action* does not contain exactly ``n_actions`` elements.
        """
        action = action.squeeze()
        if action.dim() == 0:
            action = action.unsqueeze(0)

        # The always-on sentinel is appended at position ``n_actions``; a vector of
        # any other length would silently misalign it with ``fixed_action_index``.
        if action.dim() != 1 or action.numel() != self.n_actions:
            raise ValueError(
                f'`action` must contain exactly {self.n_actions} elements, '
                f'got shape {tuple(action.shape)}'
            )

        # Append the constant *always-on* action that backs all fixed edges.
        always_on = torch.tensor([1.0], dtype=action.dtype, device=action.device)
        action = torch.cat((action, always_on), dim=0)
        action_bool = action == 1  # convert to boolean mask once

        # Map the global action vector to every edge type
        edge_mask_dict = {
            key: action_bool[idxs] for key, idxs in self.key_action_index.items()
        }

        edge_index_dict_ = {}
        for key, edge_index in self.edge_index_dict.items():
            if key in edge_mask_dict:
                edge_index_dict_[key] = edge_index[:, edge_mask_dict[key]]
            else:
                edge_index_dict_[key] = edge_index
        return edge_index_dict_

    def _validation_score(self, y, yhat):
        """Score validation predictions with ``self.metric`` (higher is better).

        ``y`` and ``yhat`` are whatever ``utils.predict_gsnn`` returns
        (presumably NumPy arrays of shape (n_obs, n_outputs) - TODO confirm).

        Raises:
            Exception: If ``self.metric`` is not a recognized metric name.
        """
        if self.metric == 'mse':
            # Negated so that larger is better, like the correlation metrics.
            return -np.mean((y - yhat) ** 2, axis=0)
        if self.metric in ('pearson', 'spearman'):
            return utils.corr_score(y, yhat, method=self.metric, multioutput='uniform_weighted')
        if self.metric == 'r2':
            return np.clip(utils.corr_score(y, yhat, method='r2', multioutput='uniform_weighted'), -1, 1)
        raise Exception('unrecognized `metric` type')

    def train(self, edge_index_dict):
        """Train a fresh model on ``edge_index_dict`` and return the trial's reward.

        A new model is built for every call, trained with Adam and MSE loss for
        ``training_kwargs['max_epochs']`` epochs, and scored on the validation
        set after each epoch.

        Args:
            edge_index_dict (dict): Edge configuration passed to the model
                constructor as ``edge_index_dict``.

        Returns:
            The reward selected by ``self.reward_type``:

            * ``'best'`` - validation score of the epoch with the highest mean score
            * ``'last'`` - validation score of the final epoch
            * ``'auc'``  - sum of the per-epoch validation scores

            or ``-1`` if the training loss becomes NaN.

        Raises:
            Exception: If ``self.metric`` is not recognized.
            NotImplementedError: If ``self.reward_type`` is not recognized.
        """
        # cuda mem accumulation problem
        gc.collect()
        torch.cuda.empty_cache()

        device = 'cpu'  # 'cuda' if torch.cuda.is_available() else 'cpu'

        workers = self.training_kwargs['workers']
        loader_kwargs = dict(
            batch_size=self.training_kwargs['batch'],
            num_workers=workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=workers > 0,
        )
        train_loader = DataLoader(self.train_dataset, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(self.val_dataset, shuffle=False, **loader_kwargs)

        model = self._model(edge_index_dict=edge_index_dict, **self.model_kwargs).to(device)
        optim = torch.optim.Adam(model.parameters(), lr=self.training_kwargs['lr'])
        crit = torch.nn.MSELoss()

        max_epochs = self.training_kwargs['max_epochs']
        n_batches = len(train_loader)
        clip_grad = self.training_kwargs.get('clip_grad')

        best_mean_val = -np.inf
        best_val_score = None
        scores = []
        loss = None
        for epoch in range(max_epochs):
            model.train()
            for i, (x, y, *_) in enumerate(train_loader):
                if x.size(0) == 1:
                    continue  # BUG workaround: if batch only has 1 obs it fails

                optim.zero_grad()
                x = x.to(device)
                y = y.to(device)
                loss = crit(model(x), y)

                # Diverged run: abort before back-propagating NaNs; the trial is
                # reported as failed via the -1 reward.
                if torch.isnan(loss):
                    del model, optim, crit, train_loader, val_loader
                    return -1

                loss.backward()
                if clip_grad:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optim.step()

                if self.verbose:
                    print(f'[batch: {i+1}/{n_batches}]', end='\r')

            # validation perf
            y_val, yhat_val, _ = utils.predict_gsnn(val_loader, model, device=device, verbose=False)
            val_score = self._validation_score(y_val, yhat_val)
            scores.append(val_score)
            mean_val = np.mean(val_score)

            if self.verbose:
                # ``loss`` is still None if every batch so far was skipped.
                train_loss = float('nan') if loss is None else loss.item()
                print(f'\t\trun progress: {epoch}/{max_epochs} | train loss: {train_loss:.1f} || mean val perf: {mean_val:.3f}', end='\r')

            # Track the best epoch; the running maximum must be updated, otherwise
            # every epoch would overwrite the "best" score.
            if mean_val > best_mean_val:
                best_mean_val = mean_val
                best_val_score = val_score

        # trying to find cuda mem accumulation
        del model, optim, crit, train_loader, val_loader

        if self.reward_type == 'best':
            reward = best_val_score
        elif self.reward_type == 'last':
            reward = val_score
        elif self.reward_type == 'auc':
            # should reward longer training runs as well as high perf
            reward = np.sum(np.stack(scores, axis=0), axis=0)
        else:
            raise NotImplementedError('unrecognized reward type')

        return reward

    def validate(self, edge_index_dict):
        """Return whether ``edge_index_dict`` is a usable configuration.

        A configuration is rejected (``False``) when it contains no
        ``('function', 'to', 'output')`` edges; the missing edges are not
        added back here.
        """
        return edge_index_dict['function', 'to', 'output'].size(1) != 0

    def run(self, action):
        """Evaluate one sampled *action* and return its reward.

        The action is moved to CPU, turned into an edge configuration, checked
        with :meth:`validate` and, if valid, used to train a model.

        Args:
            action (torch.Tensor): Binary action vector of shape (n_actions,).

        Returns:
            The reward from :meth:`train`, or ``-1`` when the configuration is
            invalid or training fails.

        Raises:
            Exception: Any training error, re-raised only when
                ``raise_error_on_fail`` is True.
        """
        # augment edge_index_dict appropriately
        edge_index_dict = self.augment_edge_index(action=action.cpu())

        if not self.validate(edge_index_dict):
            return -1

        # train model
        try:
            return self.train(edge_index_dict)
        except Exception:
            # ``Exception`` (not a bare ``except``) so KeyboardInterrupt/SystemExit
            # still propagate instead of being recorded as a failed trial.
            if self.raise_error_on_fail:
                raise
            # failed trials will result in low reward; e.g., nan divergences
            return -1