gsnn.optim.Environment
Functions
ema(values[, alpha]) | Calculate the Exponential Moving Average (EMA).
Classes
Environment
- class gsnn.optim.Environment.Environment(action_edge_dict, train_dataset, val_dataset, model_kwargs, training_kwargs, metric='spearman', reward_type='auc', verbose=True, raise_error_on_fail=False, model=<class 'gsnn.models.GSNN.GSNN'>)
Bases: object

- augment_edge_index(action)
Augment edge_index_dict according to the sampled action.
Parameters:
action (torch.Tensor, shape (n_actions,)) – Binary vector sampled from the policy that decides which edges to keep (1) or drop (0).
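The keep/drop semantics of the action vector can be illustrated with a minimal sketch. It assumes each entry of action lines up with one candidate edge stored in a 2 x E source/destination list; mask_edges is a hypothetical helper, not part of gsnn, and this page does not document how augment_edge_index maps action entries onto the entries of edge_index_dict. For a (2, E) torch tensor the same filter is usually written as edge_index[:, action.bool()].

```python
from typing import List, Sequence


def mask_edges(edge_index: List[List[int]], action: Sequence[int]) -> List[List[int]]:
    """Hypothetical helper: keep edge j only where action[j] == 1.

    Assumes edge_index is [sources, destinations] and that action has
    exactly one 0/1 entry per candidate edge, in the same order.
    """
    src, dst = edge_index
    if len(action) != len(src):
        raise ValueError("action must have one entry per candidate edge")
    # Pair each edge with its action bit and keep only the edges marked 1.
    kept = [(s, d) for s, d, a in zip(src, dst, action) if a == 1]
    # Re-split the kept pairs into the (2, E') source/destination layout.
    return [[s for s, _ in kept], [d for _, d in kept]]


edge_index = [[0, 0, 1, 2], [1, 2, 2, 3]]
action = [1, 0, 1, 1]  # drop the second edge (0 -> 2)
print(mask_edges(edge_index, action))  # [[0, 1, 2], [1, 2, 3]]
```

Dropping an edge here simply removes its column; no renumbering of nodes is needed because node indices are untouched.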
- gsnn.optim.Environment.ema(values, alpha=0.5)
Calculate the Exponential Moving Average (EMA).
Args:
values (list or array-like): A list of numerical values for which the EMA is to be calculated.
alpha (float): The smoothing factor, a value between 0 and 1.
Returns:
float: The EMA value for the current (most recent) epoch.
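The docstring does not say how the average is seeded or which end of values is weighted more heavily. The sketch below assumes the standard recursion e_0 = v_0, e_t = alpha * v_t + (1 - alpha) * e_{t-1}, so a larger alpha favours the most recent epochs. The name ema_sketch is hypothetical and chosen to make clear this is an illustration, not the gsnn implementation.

```python
from typing import Iterable


def ema_sketch(values: Iterable[float], alpha: float = 0.5) -> float:
    """Assumed EMA recursion, seeded with the first value.

    e_0 = v_0
    e_t = alpha * v_t + (1 - alpha) * e_{t-1}
    Returns the last e_t, i.e. the EMA at the most recent epoch.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    it = iter(values)
    try:
        e = float(next(it))  # seed the average with the first value
    except StopIteration:
        raise ValueError("values must be non-empty") from None
    for v in it:
        # Blend the newest value with the running average.
        e = alpha * float(v) + (1.0 - alpha) * e
    return e


print(ema_sketch([1.0, 2.0, 3.0]))  # 2.25
```

With alpha = 0.5 the trace is 1.0, then 1.5, then 2.25; alpha = 1.0 would return only the last value, and alpha = 0.0 would return only the first.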