gsnn.models.utils
Functions
- bootstrap_r: To get a better estimate of validation performance, we compute the validation 95% confidence interval of average Pearson correlation.
- compute_sample_weights: Calculate the sample weights based on the joint frequency of cell lines and perturbation IDs.
- corr_score: Calculate the average Pearson correlation score.
- get_sigid_attrs: Extract cell line names (cell_inames) and perturbation IDs (pert_id) from the given signature IDs (sig_ids).
- Return the smallest divisor of N that is greater than or equal to X.
- Regress out variance from certain variables.
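The "smallest divisor of N which is larger than or equal to X" summary above can be illustrated with a short sketch. The function name `smallest_divisor_at_least` is hypothetical (the real function name is not shown in this listing), and the linear scan is just one simple way to do it, not necessarily what the library does.

```python
def smallest_divisor_at_least(n: int, x: int) -> int:
    """Return the smallest divisor of n that is >= x.

    Hypothetical helper name; illustrates the summary line only.
    """
    if n < 1 or x > n:
        raise ValueError("expected n >= 1 and x <= n")
    # Scan candidates upward from x; n always divides itself,
    # so the loop is guaranteed to return.
    for d in range(max(1, x), n + 1):
        if n % d == 0:
            return d


# Example: the divisors of 12 are 1, 2, 3, 4, 6, 12, so the smallest one >= 5 is 6.
result = smallest_divisor_at_least(12, 5)
```

A helper like this is typically useful when a quantity (for example a channel count) must be split into equal groups of at least a given size.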
Classes
Dict subclass for counting hashable items.
- class gsnn.models.utils.TBLogger(root)
Bases: object
- add_hparam_results(args, model, data, device, test_loader, val_loader, siginfo, time_elapsed, epoch)
- log(epoch, train_metrics, val_metrics)
    train_loss = train_metrics.get('loss', None)
    val_r2 = val_metrics.get('r2', None)
    val_r_flat = val_metrics.get('r_flat', None)
    val_mse = val_metrics.get('mse', None)
    if train_loss is not None:
        self.writer.add_scalar('train-loss', train_loss, epoch)
    if val_r2 is not None:
        self.writer.add_scalar('val-r2', val_r2, epoch)
    if val_r_flat is not None:
        self.writer.add_scalar('val-corr-flat', val_r_flat, epoch)
    if val_mse is not None:
        self.writer.add_scalar('val-mse', val_mse, epoch)
- gsnn.models.utils.bootstrap_r(y, yhat, multioutput='uniform_weighted', n=100, q_lower=0.025, q_upper=0.975)
Compute a bootstrapped 95% confidence interval of the average Pearson correlation to get a better estimate of validation performance.
- Parameters:
y (np.array) – true values
yhat (np.array) – predicted values
multioutput (str) – method to handle multioutput prediction [uniform_weighted, raw_values]
n (int) – number of bootstrapped samples to compute
q_lower (float) – lower bound quantile
q_upper (float) – upper bound quantile
- Returns:
r_low, r_up: the lower and upper quantiles of the (average) Pearson correlation between y and yhat
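The bootstrap procedure described for `bootstrap_r` can be sketched as follows. This is a minimal, single-output, pure-Python illustration under stated assumptions, not the library's implementation: `pearson`, `bootstrap_r_sketch`, and the `seed` argument are hypothetical names, and the quantiles are taken as plain order statistics without interpolation.

```python
import math
import random


def pearson(a, b, eps=1e-6):
    """Pearson correlation of two equal-length sequences (eps guards a zero denominator)."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / (math.sqrt(var_a * var_b) + eps)


def bootstrap_r_sketch(y, yhat, n=100, q_lower=0.025, q_upper=0.975, seed=0):
    """Resample (y, yhat) pairs with replacement n times and return the
    lower/upper quantiles of the resulting Pearson correlations."""
    rng = random.Random(seed)  # seed is an assumption added for reproducibility
    m = len(y)
    scores = []
    for _ in range(n):
        idx = [rng.randrange(m) for _ in range(m)]  # one bootstrap sample of indices
        scores.append(pearson([y[i] for i in idx], [yhat[i] for i in idx]))
    scores.sort()
    r_low = scores[int(q_lower * (n - 1))]
    r_up = scores[int(q_upper * (n - 1))]
    return r_low, r_up


# Strongly correlated toy data: predictions are the targets plus alternating +/-1 noise.
y_true = [float(i) for i in range(30)]
y_pred = [v + (1.0 if i % 2 == 0 else -1.0) for i, v in enumerate(y_true)]
r_low, r_up = bootstrap_r_sketch(y_true, y_pred)
```

With the default `q_lower=0.025` and `q_upper=0.975`, the returned pair brackets the central 95% of the bootstrapped correlations.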
- gsnn.models.utils.compute_sample_weights(sig_ids, max_prob_fold_diff=100)
Calculate the sample weights based on the joint frequency of cell lines and perturbation IDs.
- Parameters:
cell_lines (numpy.ndarray) – A numpy array containing the cell line for each example.
pert_ids (numpy.ndarray) – A numpy array containing the perturbation ID for each example.
- Returns:
A PyTorch tensor containing the sample weights for each example.
- Return type:
torch.Tensor
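Weighting by joint frequency, as described for `compute_sample_weights`, might look like the sketch below. Several things here are assumptions rather than the library's behavior: the function name `sample_weights_sketch` is hypothetical, it takes cell lines and perturbation IDs directly instead of `sig_ids`, it returns a plain list rather than a PyTorch tensor, and `max_prob_fold_diff` is interpreted as a cap on the ratio between the largest and smallest weight.

```python
from collections import Counter


def sample_weights_sketch(cell_lines, pert_ids, max_prob_fold_diff=100):
    """Inverse joint-frequency weights, normalized to sum to 1.

    Assumption: max_prob_fold_diff caps max(weight) / min(weight).
    """
    pairs = list(zip(cell_lines, pert_ids))
    counts = Counter(pairs)  # joint frequency of (cell line, perturbation ID)
    raw = [1.0 / counts[p] for p in pairs]  # rare pairs get larger weights
    floor = max(raw) / max_prob_fold_diff  # smallest weight allowed by the cap
    clipped = [max(w, floor) for w in raw]
    total = sum(clipped)
    return [w / total for w in clipped]


# Three examples share one (cell line, perturbation) pair; one example is unique.
cells = ["A", "A", "A", "B"]
perts = ["p", "p", "p", "q"]
weights = sample_weights_sketch(cells, perts)
capped = sample_weights_sketch(cells, perts, max_prob_fold_diff=2)
```

In this toy input the unique pair receives three times the weight of each repeated example; with `max_prob_fold_diff=2` that ratio is limited to two.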
- gsnn.models.utils.corr_score(y, yhat, multioutput='uniform_weighted', method='pearson', eps=1e-06)
Calculate the average Pearson correlation score.
y: shape (n_samples, n_outputs)
yhat: shape (n_samples, n_outputs)
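An averaged per-output Pearson correlation over `(n_samples, n_outputs)` inputs, as described for `corr_score`, can be sketched in pure Python. The names `pearson` and `corr_score_sketch` are hypothetical, the placement of `eps` in the denominator is an assumption, and the `raw_values` option is taken from the `multioutput` choices listed for `bootstrap_r`; the library's own handling may differ.

```python
import math


def pearson(a, b, eps=1e-6):
    """Pearson correlation of two equal-length sequences (eps guards a zero denominator)."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / (math.sqrt(var_a * var_b) + eps)


def corr_score_sketch(y, yhat, multioutput="uniform_weighted", eps=1e-6):
    """Correlate each output column separately, then average (or return all)."""
    n_outputs = len(y[0])
    per_output = []
    for j in range(n_outputs):
        col_y = [row[j] for row in y]
        col_yhat = [row[j] for row in yhat]
        per_output.append(pearson(col_y, col_yhat, eps))
    if multioutput == "raw_values":
        return per_output  # one correlation per output column
    return sum(per_output) / n_outputs  # uniform average across outputs


# Column 0 is perfectly correlated, column 1 perfectly anti-correlated.
y_true = [[1.0, 3.0], [2.0, 2.0], [3.0, 1.0]]
y_pred = [[2.0, 1.0], [4.0, 2.0], [6.0, 3.0]]
avg_r = corr_score_sketch(y_true, y_pred)
raw_r = corr_score_sketch(y_true, y_pred, multioutput="raw_values")
```

Averaging per-column correlations treats every output equally, which is why one strongly positive and one strongly negative column roughly cancel in this example.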
- gsnn.models.utils.get_regressed_r(y, yhat, sig_ids, vars, data='../../data/', multioutput='uniform_weighted', siginfo=None)
- gsnn.models.utils.get_sigid_attrs(sig_ids)
Extract cell line names (cell_inames) and perturbation IDs (pert_id) from the given signature IDs (sig_ids).
- Parameters:
sig_ids (list or array-like) – A list or array of signature IDs to be parsed.
- Returns:
- A tuple containing two lists:
cell_inames (list): Cell line names corresponding to the input signature IDs.
pert_ids (list): Perturbation IDs corresponding to the input signature IDs.
- Return type:
tuple