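```python
import numpy as np

def prediction2label(pred: np.ndarray) -> np.ndarray:
    """Convert ordinal predictions to class labels, e.g.
    [0.9, 0.1, 0.1, 0.1] -> 0
    [0.9, 0.9, 0.1, 0.1] -> 1
    [0.9, 0.9, 0.9, 0.1] -> 2
    etc.
    """
    # Threshold at 0.5, keep only the leading run of 1s (cumprod zeroes
    # everything after the first 0), count them and shift to a 0-based label
    return (pred > 0.5).cumprod(axis=1).sum(axis=1) - 1
```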
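One edge case is worth checking: if the very first output is already below 0.5, the cumulative product is all zeros and `prediction2label` returns -1, which is not a valid class. A minimal sketch of a variant that clips that case to class 0 follows; the clipping and the name `prediction2label_clipped` are our assumptions, not part of the original code.

```python
import numpy as np

def prediction2label_clipped(pred: np.ndarray) -> np.ndarray:
    """Hypothetical variant of prediction2label that never returns -1."""
    # Boolean mask of outputs above the 0.5 threshold, shape [batch, num_labels]
    above = pred > 0.5
    # cumprod keeps 1s only up to the first output that falls below 0.5
    leading_ones = above.cumprod(axis=1).sum(axis=1)
    # Rows whose first output is already below 0.5 would give -1; clip them to 0
    return np.clip(leading_ones - 1, 0, None)

preds = np.array([
    [0.9, 0.1, 0.1, 0.1],
    [0.9, 0.9, 0.1, 0.1],
    [0.9, 0.9, 0.9, 0.1],
    [0.9, 0.1, 0.9, 0.9],  # non-monotone outputs: only the leading run counts
    [0.2, 0.1, 0.1, 0.1],  # would be -1 without clipping
])
print(prediction2label_clipped(preds))  # [0 1 2 0 0]
```

Note that non-monotone outputs (a later threshold passed after an earlier one failed) are silently truncated at the first failure, which is the same behaviour as the original function.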
```python
import torch

def ordinal_regression(predictions: torch.Tensor, targets: torch.Tensor):
    """Ordinal regression with encoding as in https://arxiv.org/pdf/0704.1028.pdf"""
    # Create our modified target with [batch_size, num_labels] shape
    modified_target = torch.zeros_like(predictions)
    # Fill in ordinal target function, i.e. 0 -> [1,0,0,...]
    for i, target in enumerate(targets):
        modified_target[i, 0:int(target) + 1] = 1
```
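The loop above builds the cumulative target from the cited paper: class k is encoded as k+1 leading ones. A vectorized NumPy sketch of the same encoding is shown below. Pairing it with a per-sample squared error against sigmoid-style outputs in [0, 1] is our assumption about how the loss is completed; `ordinal_encode` and `ordinal_squared_error` are hypothetical names.

```python
import numpy as np

def ordinal_encode(targets: np.ndarray, num_labels: int) -> np.ndarray:
    """Class k -> k+1 leading ones, e.g. 0 -> [1,0,0,0], 2 -> [1,1,1,0]."""
    # Broadcast label indices [1, L] against targets [B, 1]: index <= target gives 1
    return (np.arange(num_labels)[None, :] <= targets[:, None]).astype(float)

def ordinal_squared_error(predictions: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Per-sample squared error between outputs and encoded targets (assumed loss)."""
    encoded = ordinal_encode(targets, predictions.shape[1])
    return ((predictions - encoded) ** 2).sum(axis=1)

targets = np.array([0, 2])
print(ordinal_encode(targets, 4))
# [[1. 0. 0. 0.]
#  [1. 1. 1. 0.]]
preds = np.array([[0.9, 0.1, 0.1, 0.1],
                  [0.9, 0.9, 0.1, 0.1]])
print(ordinal_squared_error(preds, targets))
```

The broadcasting comparison replaces the Python loop, so the encoding cost no longer grows with a per-row interpreter overhead.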
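```python
import torch
import torch.nn as nn

def cross_entropy(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Per-sample loss (reduction='none'); targets are integer class indices
    return nn.CrossEntropyLoss(reduction='none')(predictions, targets)
```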
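For comparison with the ordinal loss: `nn.CrossEntropyLoss` treats the model outputs as unnormalized logits and, with `reduction='none'`, returns one value per sample, the negative log-softmax probability of that sample's target class. A NumPy sketch of that computation, for illustration only (it is not PyTorch's implementation, and `cross_entropy_per_sample` is a hypothetical name):

```python
import numpy as np

def cross_entropy_per_sample(logits: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """-log softmax(logits)[i, targets[i]] for each row i, like reduction='none'."""
    # Subtract the row max before exponentiating for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of each sample's target class
    return -log_probs[np.arange(len(targets)), targets]

logits = np.array([[2.0, 0.5, 0.1],
                   [0.0, 0.0, 0.0]])
targets = np.array([0, 2])
print(cross_entropy_per_sample(logits, targets))
```

Unlike the ordinal encoding, this loss ignores the ordering of the classes: predicting class 4 for a true class 0 costs the same as predicting class 1, given the same probability on the target.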
```bash
python train.py \
    --data_path data/solubility_dataset.csv \
    --dataset_type multiclass \
    --save_dir solubility_checkpoints/ \
    --ensemble 1 --num_folds 5 --epochs 50 \
    --split_type scaffold_balanced \
    --multiclass_num_classes 5 \
    --save_smiles_splits --save_preds
```
```bash
# Train a logSolubility model
python train.py \
    --data_path data/delaney.csv \
    --dataset_type regression \
    --save_dir delaney_checkpoints \
    --ensemble 3 --num_folds 10 --epochs 50
# Get bayes ensemble grad results
python interpret_local.py \
    --test_path data/delaney_subset.csv \
```
```bash
chemprop_predict \
    --test_path data/clintox_test.csv \
    --checkpoint_dir model_checkpoint \
    --features_generator rdkit_2d_normalized --no_features_scaling \
    --preds_path data/predictions.csv
```
```bash
chemprop_train \
    --data_path data/clintox_train.csv \
    --config_path data/config.json \
    --dataset_type classification \
    --save_dir model_checkpoint \
    --num_folds 5 \
    --ensemble_size 3 \
    --features_generator rdkit_2d_normalized --no_features_scaling \
    --split_type scaffold_balanced
```
```bash
chemprop_hyperopt \
    --data_path data/clintox_train.csv \
    --dataset_type classification \
    --num_iters 50 \
    --features_generator rdkit_2d_normalized --no_features_scaling \
    --config_save_path data/config.json
```
```yaml
apiVersion: v1
kind: Service
metadata:
  labels:
    app: [VAR_MODULE_NAME]
  name: [VAR_MODULE_NAME]
spec:
  ports:
  - name: predict
    port: 5000
```
```python
import os

def deploy_bentoml(model_name):
    """Deploys a BentoML image from Docker Hub into a Kubernetes cluster"""
    # Open a template for BentoML deployments
    with open('bentoml_deploy.tpl', 'r') as fi:
        # Substitute the name of the model into template & save
        yaml_file = fi.read()
```
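The comment in `deploy_bentoml` describes substituting the model name into the template and saving the result. A minimal sketch of that step, assuming the template contains the literal placeholder `[VAR_MODULE_NAME]` used in the Service manifest and that the rendered manifest is written to `<model_name>.yaml`; the helper names, the output filename, and the example model name `solubility-model` are all hypothetical. Applying the manifest to the cluster (for example with `kubectl apply -f <file>`) is left out.

```python
import os
import tempfile

PLACEHOLDER = "[VAR_MODULE_NAME]"  # placeholder used in the Service template

def render_manifest(template: str, model_name: str) -> str:
    """Replace every occurrence of the placeholder with the model name."""
    return template.replace(PLACEHOLDER, model_name)

def save_manifest(template_path: str, model_name: str, out_dir: str) -> str:
    """Hypothetical helper: render the template file and write <model_name>.yaml."""
    with open(template_path, "r") as fi:
        rendered = render_manifest(fi.read(), model_name)
    out_path = os.path.join(out_dir, f"{model_name}.yaml")
    with open(out_path, "w") as fo:
        fo.write(rendered)
    return out_path

template = "metadata:\n  labels:\n    app: [VAR_MODULE_NAME]\n  name: [VAR_MODULE_NAME]\n"

# Usage: write the template to a temporary directory, render it, read it back
with tempfile.TemporaryDirectory() as tmp:
    tpl_path = os.path.join(tmp, "bentoml_deploy.tpl")
    with open(tpl_path, "w") as f:
        f.write(template)
    out = save_manifest(tpl_path, "solubility-model", tmp)
    with open(out) as f:
        print(f.read())
```

Plain string replacement keeps the sketch dependency-free; a templating library would only be worth it if the manifest needed more than one substituted field.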