Source code for gsnn.optim.REINFORCE

import numpy as np
import torch
from torch.distributions.bernoulli import Bernoulli
from sklearn.metrics import roc_auc_score

class REINFORCE(torch.nn.Module):
    """REINFORCE policy-gradient optimizer over independent binary actions.

    Each action (e.g. a candidate graph edge) is modelled as an independent
    Bernoulli variable parameterised by one learnable logit. Every ``step``
    samples an action vector, scores it with ``env.run`` and updates the logits
    so that actions with above-average (normalised) reward become more likely.
    """

    def __init__(self, env, n_actions, action_labels=None, clip=10, eps=1e-5,
                 warmup=3, verbose=True, entropy=0., entropy_decay=0.99,
                 min_entropy=0.01, window=10, init_prob=0.9, lr=1e-2,
                 policy_decay=0.):
        """
        REINFORCE algorithm for optimizing graph structure.

        The REINFORCE algorithm is used to learn an optimal graph structure by
        treating edge selection as a reinforcement learning problem. Each edge is
        treated as a binary action (include/exclude) and the model is trained to
        maximize expected reward.

        Args:
            env (Environment): Environment object that handles model training and
                evaluation; must expose ``run(action)`` returning a reward.
            n_actions (int): Number of binary actions (edges) to optimize.
            action_labels (array-like, optional): Ground truth binary labels for
                actions. Used only for progress reporting.
            clip (float, optional): Normalised rewards are clipped to
                ``[-clip, clip]``. Default: 10
            eps (float, optional): Added to the reward std for numerical
                stability. Default: 1e-5
            warmup (int, optional): Number of recorded rewards required before
                policy updates start. Default: 3
            verbose (bool, optional): Whether to print progress. Default: True
            entropy (float, optional): Initial entropy coefficient. Note that
                after every update it becomes
                ``max(entropy * entropy_decay, min_entropy)``. Default: 0.
            entropy_decay (float, optional): Decay rate for entropy coefficient.
                Default: 0.99
            min_entropy (float, optional): Floor for the entropy coefficient.
                Default: 0.01
            window (int, optional): Number of most recent rewards used for reward
                normalization. Default: 10
            init_prob (float, optional): Initial probability for edge selection;
                must be strictly between 0 and 1. Default: 0.9
            lr (float, optional): Learning rate for policy optimization (Adam).
                Default: 1e-2
            policy_decay (float, optional): Coefficient of a penalty on the mean
                edge-inclusion probability (an L1 penalty on the probabilities,
                which are non-negative). Default: 0.

        Raises:
            ValueError: If ``init_prob`` is not strictly between 0 and 1.

        Example:
            >>> env = Environment(action_edge_dict, train_dataset, test_dataset, model_kwargs, training_kwargs)
            >>> reinforce = REINFORCE(env, n_actions=10, clip=10, entropy=0.1)
            >>> for i in range(100):
            ...     reinforce.step()
            >>> best_action = reinforce.best_action
        """
        super().__init__()
        # 0 or 1 would yield an infinite logit (or a ZeroDivisionError for 1).
        if not 0 < init_prob < 1:
            raise ValueError(f'init_prob must be strictly between 0 and 1, got {init_prob}')

        self.action_labels = action_labels
        self.env = env
        self.n_actions = n_actions
        self.entropy = entropy
        self.entropy_decay = entropy_decay
        self.min_entropy = min_entropy
        self.clip = clip
        self.eps = eps
        self.warmup = warmup
        self.rewards = []      # history of rewards returned by env.run
        self.actions = []      # history of sampled actions (NumPy arrays)
        self.iteration = 0
        self.verbose = verbose
        self.window = window
        self.policy_decay = policy_decay

        # need to convert init prob to logit value: logit(p) = log(p / (1 - p))
        init_logit = float(np.log(init_prob / (1 - init_prob)))
        self.logits = torch.nn.Parameter(
            torch.ones((1, self.n_actions), dtype=torch.float32) * init_logit)
        self.optim = torch.optim.Adam([self.logits], lr=lr)

        self.best_reward = None
        self.best_action = None

    def sample(self):
        """Draw one action vector from the current policy.

        Returns:
            torch.Tensor: 0/1 float tensor shaped like ``self.logits``
            (``(1, n_actions)``).
        """
        return Bernoulli(logits=self.logits).sample()

    def update(self, rewards, actions=None):
        """Record an observed reward and decay the entropy coefficient.

        Args:
            rewards: Reward returned by ``env.run`` for the latest action.
            actions (torch.Tensor, optional): The sampled action; stored as a
                NumPy array in ``self.actions`` for post-hoc analyses.
        """
        self.rewards.append(rewards)
        # Keep a history of sampled actions for post-hoc analyses.
        if actions is not None:
            self.actions.append(actions.detach().cpu().numpy())
        # Geometric decay floored at min_entropy (an initial 0 is therefore
        # raised to min_entropy after the first update).
        self.entropy = float(max(self.entropy * self.entropy_decay, self.min_entropy))
        if self.verbose:
            print(f'entropy value -> {self.entropy:.3f}', end='\r')

    def get_reward_params(self):
        """Return the ``(mean, std)`` used to normalise rewards.

        Returns the identity parameters ``(0, 1)`` until at least ``warmup``
        rewards (and at least one reward) have been recorded; afterwards the
        statistics of the last ``window`` rewards, computed along axis 0.
        """
        # Guard on emptiness too: with warmup=0, np.stack([]) would raise.
        if not self.rewards or len(self.rewards) < self.warmup:
            return 0, 1
        rewards_ = np.stack(self.rewards[-self.window:], axis=0)
        return rewards_.mean(0), rewards_.std(0)

    def scale(self, rewards):
        """Normalise ``rewards`` with the running statistics and clip them.

        Args:
            rewards: Reward value(s); presumably shape ``(n_outputs,)`` or scalar.

        Returns:
            ``(rewards - mean) / (std + eps)`` clipped to ``[-clip, clip]``.
        """
        mu, std = self.get_reward_params()
        rewards = (rewards - mu) / (std + self.eps)
        return np.clip(rewards, -self.clip, self.clip)

    def prob_of(self, action):
        """Return the probability the current policy assigns to ``action``.

        Args:
            action: Binary action vector (tensor, NumPy array or sequence);
                integer/bool/float64 inputs are cast to the logits' dtype so
                ``Bernoulli.log_prob`` accepts them.

        Returns:
            torch.Tensor: 0-dim tensor, the product of per-action probabilities.
        """
        action = torch.as_tensor(action, dtype=self.logits.dtype, device=self.logits.device)
        policy = Bernoulli(logits=self.logits)
        return torch.exp(policy.log_prob(action).sum())

    def get_edge_probs(self):
        """Return the current per-action inclusion probabilities as a NumPy array
        of shape ``(1, n_actions)``."""
        # .cpu() so this also works when the module lives on an accelerator.
        return torch.sigmoid(self.logits).detach().cpu().numpy()

    def print_progress_(self):
        """Print a one-line progress summary for the current iteration.

        If ``action_labels`` is set, also report AUROC and accuracy of the
        current edge probabilities against the labels, and the probability of
        the ground-truth action. Requires at least one recorded reward.
        """
        if self.action_labels is not None:
            edge_probs = self.get_edge_probs().ravel()
            true_action = np.asarray(self.action_labels).ravel()
            auroc = roc_auc_score(true_action, edge_probs)
            acc = ((edge_probs > 0.5) == true_action).mean()
            prob_true = self.prob_of(true_action)
            print(f'\t --> iter: {self.iteration} || auroc {auroc:0.3f} || acc: {acc:.3f} || prob(true_action): {prob_true:.3E} || last reward: {self.rewards[-1]:.3f}')
        else:
            print(f'\t --> iter: {self.iteration} || last reward: {self.rewards[-1]:.3f}')

    def step(self):
        """Run one REINFORCE iteration: sample, evaluate, update, and log."""
        self.optim.zero_grad()
        policy = Bernoulli(logits=self.logits)
        action = policy.sample()
        rewards = self.env.run(action)
        # Normalised with statistics of *previous* rewards only; the current
        # reward is appended afterwards in ``update``.
        advantages = self.scale(rewards).mean()

        if len(self.rewards) >= self.warmup:
            # Policy-gradient loss minus an entropy bonus plus a penalty on the
            # mean inclusion probability.
            loss = (-(policy.log_prob(action) * advantages).sum()
                    - self.entropy * policy.entropy().sum()
                    + self.policy_decay * self.logits.sigmoid().mean())
            loss.backward()
            self.optim.step()

        self.update(rewards, action)

        # log best reward
        if (self.best_reward is None) or (rewards > self.best_reward):
            self.best_reward = rewards
            self.best_action = action

        self.iteration += 1
        # Honour the documented ``verbose`` flag for progress printing.
        if self.verbose:
            self.print_progress_()