gsnn.optim.Environment

Functions

ema(values[, alpha])

Calculate the Exponential Moving Average (EMA)

Classes

EarlyStopper([patience, min_delta])

Environment(action_edge_dict, train_dataset, ...)

GSNN(*args, **kwargs)
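This page lists only the signature of EarlyStopper. The sketch below shows what a `patience`/`min_delta` stopper commonly does. It assumes the monitored quantity is a validation loss (lower is better). The class name `EarlyStopperSketch`, its `early_stop` method, and the default values are hypothetical and are not taken from gsnn.

```python
class EarlyStopperSketch:
    """Hypothetical early stopper; assumes lower monitored values are better."""

    def __init__(self, patience=1, min_delta=0.0):
        self.patience = patience      # epochs to tolerate without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.counter = 0
        self.best = float("inf")

    def early_stop(self, val_loss):
        """Record one epoch's loss; return True once patience is exhausted."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss      # improvement: remember it and reset the counter
            self.counter = 0
        else:
            self.counter += 1         # no sufficient improvement this epoch
        return self.counter >= self.patience


stopper = EarlyStopperSketch(patience=2)
for epoch, loss in enumerate([1.0, 0.9, 0.95, 0.91, 0.92]):
    if stopper.early_stop(loss):
        print(f"stopping at epoch {epoch}")
        break
```

With `patience=2`, two consecutive epochs without a new best loss (0.95, then 0.91) trigger the stop.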

class gsnn.optim.Environment.Environment(action_edge_dict, train_dataset, val_dataset, model_kwargs, training_kwargs, metric='spearman', reward_type='auc', verbose=True, raise_error_on_fail=False, model=<class 'gsnn.models.GSNN.GSNN'>)

Bases: object

augment_edge_index(action)

Augment edge_index_dict according to the sampled action.

Parameters:

action (torch.Tensor, shape (n_actions,)) – Binary vector sampled from the policy that decides which edges to keep (1) or drop (0).
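The sketch below illustrates the keep (1) / drop (0) semantics of `action` under several assumptions. It operates on a single NumPy `edge_index` of shape (2, n_edges) rather than on `edge_index_dict`. It also uses a hypothetical mapping `action_to_edges` from each action index to the edge columns that action controls, because the exact structure of `action_edge_dict` is not documented here. Edges that no action controls are always kept.

```python
import numpy as np


def augment_edge_index_sketch(edge_index, action, action_to_edges):
    """Drop every edge whose controlling action entry is 0 (hypothetical helper).

    edge_index: int array of shape (2, n_edges), one column per edge.
    action: binary array of shape (n_actions,).
    action_to_edges: hypothetical dict {action_idx: [edge column indices]}.
    """
    keep = np.ones(edge_index.shape[1], dtype=bool)  # edges not tied to an action stay
    for a, edge_ids in action_to_edges.items():
        if action[a] == 0:
            keep[edge_ids] = False                    # action says "drop" these edges
    return edge_index[:, keep]


edge_index = np.array([[0, 0, 1, 2],
                       [1, 2, 3, 3]])
action = np.array([1, 0])                 # keep action 0's edges, drop action 1's
action_to_edges = {0: [0], 1: [1, 2]}
print(augment_edge_index_sketch(edge_index, action, action_to_edges).tolist())
```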

run(action)

train(edge_index_dict)

validate(edge_index_dict)

If critical edges, such as func->output edges, are not included, include them; this should speed up convergence.
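One way to apply the note above is to force critical edges back on after sampling. The sketch below is a minimal version under two assumptions. It treats every edge whose target is an output node as critical, and it works on a flat boolean keep-mask over a single NumPy `edge_index`. The helper `force_critical_edges` and its arguments are hypothetical and do not come from gsnn.

```python
import numpy as np


def force_critical_edges(keep_mask, edge_index, output_nodes):
    """Return a copy of keep_mask with every edge into an output node switched on."""
    keep = np.asarray(keep_mask, dtype=bool).copy()
    targets = edge_index[1]                              # destination node of each edge
    is_critical = np.isin(targets, list(output_nodes))   # edges that feed an output
    keep[is_critical] = True
    return keep


edge_index = np.array([[0, 0, 1, 2],
                       [1, 2, 3, 3]])
print(force_critical_edges([0, 1, 0, 0], edge_index, output_nodes={3}).tolist())
```

Applying the override after sampling means the policy cannot cut the output nodes off from the rest of the graph, which is the reason the note gives for faster convergence.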

gsnn.optim.Environment.ema(values, alpha=0.5)

Calculate the Exponential Moving Average (EMA)

Parameters:

values (list or array-like) – A list of numerical values for which the EMA is to be calculated.

alpha (float) – The smoothing factor, a value between 0 and 1.

Returns:

float – The EMA value for the current epoch.
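A common EMA definition uses the recurrence ema_t = alpha * x_t + (1 - alpha) * ema_(t-1), seeded with the first value. Under this definition, a larger `alpha` weights recent epochs more heavily. The sketch below implements that recurrence. This page does not confirm whether `gsnn.optim.Environment.ema` uses the same seeding or weighting direction, and the name `ema_sketch` is hypothetical.

```python
def ema_sketch(values, alpha=0.5):
    """Exponential moving average; returns the smoothed value after the last element."""
    if not 0 < alpha <= 1:
        raise ValueError("alpha must be in (0, 1]")
    values = list(values)
    if not values:
        raise ValueError("values must be non-empty")
    smoothed = values[0]                      # seed with the first observation
    for x in values[1:]:
        smoothed = alpha * x + (1 - alpha) * smoothed
    return smoothed


print(ema_sketch([1.0, 2.0, 3.0], alpha=0.5))  # 2.25
```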