gsnn.models.utils

Functions

bootstrap_r(y, yhat[, multioutput, n, ...])

Compute a 95% confidence interval of the average Pearson correlation to get a better estimate of validation performance.

compute_sample_weights(sig_ids[, ...])

Calculate the sample weights based on the joint frequency of cell lines and perturbation IDs.

corr_score(y, yhat[, multioutput, method, eps])

Calculate the average Pearson correlation score.

get_activation(act)

get_crit(crit)

get_optim(optim)

get_regressed_r(y, yhat, sig_ids, vars[, ...])

get_scheduler(optim, args, loader)

get_sigid_attrs(sig_ids)

Extract cell line names (cell_inames) and perturbation IDs (pert_id) from the given signature IDs (sig_ids).

next_divisor(N, X)

Return the smallest divisor of N that is greater than or equal to X.

predict_gnn(loader, model, device[, verbose])

predict_gsnn(loader, model, device[, verbose])

predict_nn(loader, model, device[, verbose])

randomize(data)

regress_out(y, df, vars)

Regress out variance from certain variables.

Classes

Counter([iterable])

Dict subclass for counting hashable items.

TBLogger(root)

class gsnn.models.utils.TBLogger(root)[source]

Bases: object

add_hparam_results(args, model, data, device, test_loader, val_loader, siginfo, time_elapsed, epoch)[source]
log(epoch, train_metrics, val_metrics)[source]

train_loss = train_metrics.get('loss', None)
val_r2 = val_metrics.get('r2', None)
val_r_flat = val_metrics.get('r_flat', None)
val_mse = val_metrics.get('mse', None)

if train_loss is not None:
    self.writer.add_scalar('train-loss', train_loss, epoch)

if val_r2 is not None:
    self.writer.add_scalar('val-r2', val_r2, epoch)

if val_r_flat is not None:
    self.writer.add_scalar('val-corr-flat', val_r_flat, epoch)

if val_mse is not None:
    self.writer.add_scalar('val-mse', val_mse, epoch)

gsnn.models.utils.bootstrap_r(y, yhat, multioutput='uniform_weighted', n=100, q_lower=0.025, q_upper=0.975)[source]

Compute a bootstrapped confidence interval (95% with the default quantiles) of the average Pearson correlation, to get a better estimate of validation performance.

Parameters:
  • y (np.array) – observed values

  • yhat (np.array) – predicted values

  • multioutput (str) – method to handle multioutput prediction [uniform_weighted, raw_values]

  • n (int) – number of bootstrapped samples to compute

  • q_lower (float) – lower bound quantile

  • q_upper (float) – upper bound quantile

Returns:

r_low, r_up – the lower and upper quantiles of the (average) Pearson correlation between y and yhat
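The resampling idea described above can be illustrated with a minimal sketch. This is not the library's implementation: the name `bootstrap_r_sketch`, the `seed` argument, and the use of `np.corrcoef` per output column are assumptions made for illustration, and only the uniformly weighted average is shown.

```python
import numpy as np

def bootstrap_r_sketch(y, yhat, n=100, q_lower=0.025, q_upper=0.975, seed=0):
    """Hypothetical sketch: bootstrap quantiles of the mean column-wise Pearson r."""
    rng = np.random.default_rng(seed)  # `seed` is an assumption, not in the documented signature
    n_samples, n_outputs = y.shape
    scores = np.empty(n)
    for i in range(n):
        # Resample rows with replacement, keeping y and yhat paired.
        idx = rng.integers(0, n_samples, size=n_samples)
        per_output = [np.corrcoef(y[idx, j], yhat[idx, j])[0, 1] for j in range(n_outputs)]
        scores[i] = np.mean(per_output)
    return float(np.quantile(scores, q_lower)), float(np.quantile(scores, q_upper))

# Synthetic example: predictions equal the targets plus a little noise.
rng = np.random.default_rng(0)
y = rng.normal(size=(200, 3))
yhat = y + 0.1 * rng.normal(size=(200, 3))
r_low, r_up = bootstrap_r_sketch(y, yhat)
```

With the default quantiles of 0.025 and 0.975, the returned pair brackets the central 95% of the bootstrapped scores.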

gsnn.models.utils.compute_sample_weights(sig_ids, max_prob_fold_diff=100)[source]

Calculate the sample weights based on the joint frequency of cell lines and perturbation IDs.

Parameters:
  • sig_ids (array-like) – Signature IDs, one per example. The cell line and perturbation ID of each example are derived from these.

Returns:

A PyTorch tensor containing the sample weights for each example.

Return type:

torch.Tensor
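One way to weight examples by joint frequency is sketched below. Everything here is an assumption for illustration: the inverse-frequency rule, the interpretation of the cap as a maximum ratio between the largest and smallest weight, the normalization to sum to one, and the name `sample_weights_sketch`. It takes cell lines and perturbation IDs directly and returns a NumPy array rather than a `torch.Tensor`.

```python
from collections import Counter

import numpy as np

def sample_weights_sketch(cell_lines, pert_ids, max_fold=100):
    """Hypothetical sketch: inverse joint-frequency weights, capped and normalized."""
    pairs = list(zip(cell_lines, pert_ids))
    counts = Counter(pairs)  # joint frequency of each (cell line, perturbation) pair
    w = np.array([1.0 / counts[p] for p in pairs])
    # Assumed role of the cap: no weight may exceed `max_fold` times the smallest one.
    w = np.minimum(w, w.min() * max_fold)
    return w / w.sum()  # normalization is an assumption

cell_lines = ["A", "A", "A", "B"]
pert_ids = ["p1", "p1", "p2", "p1"]
w = sample_weights_sketch(cell_lines, pert_ids)
```

In this example the pair ("A", "p1") occurs twice, so each of its examples receives half the weight of the examples whose pair occurs once.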

gsnn.models.utils.corr_score(y, yhat, multioutput='uniform_weighted', method='pearson', eps=1e-06)[source]

Calculate the average Pearson correlation score.

Parameters:
  • y – observed values, shape (n_samples, n_outputs)

  • yhat – predicted values, shape (n_samples, n_outputs)
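A minimal sketch of a column-wise Pearson score with the two `multioutput` modes follows. It is illustrative only: the name `corr_score_sketch` is hypothetical, the `method` argument is omitted, and placing `eps` in the denominator to guard against zero variance is an assumption about how that parameter is used.

```python
import numpy as np

def corr_score_sketch(y, yhat, multioutput="uniform_weighted", eps=1e-6):
    """Hypothetical sketch: Pearson r per output column, optionally averaged."""
    y = np.asarray(y, dtype=float)
    yhat = np.asarray(yhat, dtype=float)
    yc = y - y.mean(axis=0)           # center each output column
    pc = yhat - yhat.mean(axis=0)
    cov = (yc * pc).sum(axis=0)
    # Assumption: eps keeps the ratio finite when a column has zero variance.
    denom = np.sqrt((yc ** 2).sum(axis=0) * (pc ** 2).sum(axis=0)) + eps
    r = cov / denom
    if multioutput == "raw_values":
        return r                      # one score per output
    if multioutput == "uniform_weighted":
        return float(r.mean())        # unweighted mean across outputs
    raise ValueError(f"unknown multioutput option: {multioutput}")

# First output is perfectly correlated, second is perfectly anti-correlated.
y = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
yhat = np.array([[2.0, 3.0], [4.0, 2.0], [6.0, 1.0]])
raw = corr_score_sketch(y, yhat, multioutput="raw_values")
avg = corr_score_sketch(y, yhat)
```

Here the per-output scores are close to +1 and -1, so the uniformly weighted average is close to zero.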

gsnn.models.utils.get_activation(act)[source]
gsnn.models.utils.get_crit(crit)[source]
gsnn.models.utils.get_optim(optim)[source]
gsnn.models.utils.get_regressed_r(y, yhat, sig_ids, vars, data='../../data/', multioutput='uniform_weighted', siginfo=None)[source]
gsnn.models.utils.get_scheduler(optim, args, loader)[source]
gsnn.models.utils.get_sigid_attrs(sig_ids)[source]

Extract cell line names (cell_inames) and perturbation IDs (pert_id) from the given signature IDs (sig_ids).

Parameters:

sig_ids (list or array-like) – A list or array of signature IDs to be parsed.

Returns:

A tuple containing two lists:
  • cell_inames (list): Cell line names corresponding to the input signature IDs.

  • pert_ids (list): Perturbation IDs corresponding to the input signature IDs.

Return type:

tuple

gsnn.models.utils.next_divisor(N, X)[source]

Return the smallest divisor of N that is greater than or equal to X.
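The behavior described above can be sketched with a linear scan. The name `next_divisor_sketch` and the error handling for invalid inputs are assumptions; how the library treats `X > N` is not documented here.

```python
def next_divisor_sketch(N, X):
    """Hypothetical sketch: smallest divisor d of N with d >= X."""
    if N <= 0 or X <= 0:
        raise ValueError("N and X must be positive integers")
    for d in range(X, N + 1):
        if N % d == 0:
            return d
    # Only reachable when X > N; raising here is an assumption of this sketch.
    raise ValueError("no divisor of N is greater than or equal to X")
```

For example, `next_divisor_sketch(12, 5)` returns 6, because 5 does not divide 12 and 6 is the next integer that does.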

gsnn.models.utils.predict_gnn(loader, model, device, verbose=True)[source]
gsnn.models.utils.predict_gsnn(loader, model, device, verbose=True)[source]
gsnn.models.utils.predict_nn(loader, model, device, verbose=True)[source]
gsnn.models.utils.randomize(data)[source]
gsnn.models.utils.regress_out(y, df, vars)[source]

Regress out variance from certain variables.

Parameters:
  • y (numpy array) – signal to modify

  • df (dataframe) – covariate options

  • vars (list<str>) – variables to regress out; must be columns in the dataframe

Returns:

numpy array – the adjusted y signal
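Regressing out covariates is commonly done by fitting a linear model and keeping the residuals. The sketch below assumes ordinary least squares with an intercept, which may differ from what the library does. The name `regress_out_sketch` is hypothetical, and it takes a numeric covariate array directly, standing in for the columns that would be selected as `df[vars]`.

```python
import numpy as np

def regress_out_sketch(y, covariates):
    """Hypothetical sketch: residuals of y after an OLS fit on the covariates."""
    y = np.asarray(y, dtype=float)
    # Design matrix: an intercept column followed by the covariate column(s).
    X = np.column_stack([np.ones(len(y)), np.asarray(covariates, dtype=float)])
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    return y - X @ beta  # what the covariates cannot explain linearly

# Synthetic example: y depends linearly on x, plus noise.
rng = np.random.default_rng(0)
x = rng.normal(size=50)
y = 3.0 * x + 2.0 + 0.5 * rng.normal(size=50)
resid = regress_out_sketch(y, x)
```

By construction, the residuals of a least-squares fit are uncorrelated with the covariates and have zero mean, which is what "regressed out" means in this sketch.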