-
-
Save VladVin/4121720ff984693f9b8c34c29e47f652 to your computer and use it in GitHub Desktop.
ConPLex evaluation for protein-ligand and ligand-ligand matching
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from functools import partial\n", | |
"from multiprocessing import Pool\n", | |
"import multiprocessing as mp\n", | |
"from pathlib import Path\n", | |
"import os\n", | |
"import sys\n", | |
"\n", | |
"from Bio import Seq, SeqIO\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"from rdkit.Chem import AllChem\n", | |
"from rdkit import Chem\n", | |
"from rdkit.Chem.rdmolfiles import MolFromMol2File, MolToSmiles, MolFromSmiles\n", | |
"from sklearn.metrics import roc_auc_score\n", | |
"import torch\n", | |
"from torch.utils.data import DataLoader, Dataset\n", | |
"from tqdm.notebook import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"sys.path.append(\"./ConPlex/\")\n", | |
"from conplex_dti.model.architectures import SimpleCoembedding\n", | |
"from conplex_dti.featurizer.molecule import MorganFeaturizer\n", | |
"from conplex_dti.featurizer.protein import ProtBertFeaturizer\n", | |
"\n", | |
"from conplex_dataset import ConPlexDrugDataset\n", | |
"from metrics import ef_score\n", | |
"from utils import read_actives_decoys, read_ligands, calc_euclidean_distance, calc_cosine_distance, run_benchmark" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = SimpleCoembedding()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"DEVICE = torch.device(\"mps\")\n", | |
"# DEVICE = torch.device(\"cuda:0\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ckpt = torch.load(\"./ConPlex/weights/BindingDB_ExperimentalValidModel.pt\", map_location=DEVICE)\n", | |
"model.load_state_dict(ckpt)\n", | |
"model.eval()\n", | |
"model = model.to(DEVICE)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"DATA_SOURCE = \"litpcba\"\n", | |
"# DATA_SOURCE = \"dude\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"if DATA_SOURCE == \"litpcba\":\n", | |
" DATA_DIR = Path(\"./data/LIT-PCBA/full_data\")\n", | |
" OUTPUT_DIR = Path('./output/LIT-PCBA/full_conplex_feats')\n", | |
" BENCH_OUTPUT_PATH = Path('./output/LIT-PCBA/benchmark_full_conplex_prot2lig.csv')\n", | |
"elif DATA_SOURCE == \"dude\":\n", | |
" DATA_DIR = Path(\"./data/dud-e\")\n", | |
" OUTPUT_DIR = Path('./output/dud-e/conplex_feats')\n", | |
" BENCH_OUTPUT_PATH = Path('./output/dud-e/benchmark_conplex_prot2lig.csv')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"if DATA_SOURCE == \"litpcba\":\n", | |
" target_dir = DATA_DIR / \"ADRB2\"\n", | |
"elif DATA_SOURCE == \"dude\":\n", | |
" target_dir = DATA_DIR / \"abl1\"\n", | |
"\n", | |
"target_output_dir = OUTPUT_DIR / target_dir.stem\n", | |
"actives_smiles, decoys_smiles = read_actives_decoys(DATA_SOURCE, target_dir)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mol_featurizer = MorganFeaturizer(save_dir=target_output_dir)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([0., 1., 0., ..., 0., 0., 0.])" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mol_featurizer.transform(actives_smiles[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']\n", | |
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" | |
] | |
} | |
], | |
"source": [ | |
"prot_featurizer = ProtBertFeaturizer(save_dir=target_output_dir)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mol_feats = mol_featurizer.transform(actives_smiles[0])\n", | |
"mol_embed = model.project_drug(mol_feats.unsqueeze(0).to(DEVICE))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def extract_fingerprints_for_ligands(proteins_all_sequences: list[list[Seq]], prot_featurizer, model, save_dir):\n", | |
" proteins_sequences = []\n", | |
" for protein_sequences in proteins_all_sequences:\n", | |
" joint_sequence = ''.join([str(seq) for seq in protein_sequences])\n", | |
" proteins_sequences.append(joint_sequence)\n", | |
" \n", | |
" proteins_features = []\n", | |
" for protein_sequence in tqdm(proteins_sequences):\n", | |
" protein_features = prot_featurizer.transform(protein_sequence)\n", | |
" proteins_features.append(protein_features)\n", | |
" proteins_features = torch.stack(proteins_features)\n", | |
" \n", | |
" batch_size = 8\n", | |
" fps = []\n", | |
" for i in range(0, len(proteins_features), batch_size):\n", | |
" batch = proteins_features[i:i+batch_size]\n", | |
" batch = batch.to(DEVICE)\n", | |
" with torch.no_grad():\n", | |
" prot_embed = model.project_target(batch)\n", | |
" fps.append(prot_embed.cpu().numpy())\n", | |
"\n", | |
" return np.concatenate(fps)\n", | |
"\n", | |
"def extract_fingerprints(smiles, model, save_dir):\n", | |
" dataset = ConPlexDrugDataset(smiles, save_dir=save_dir)\n", | |
" dataloader = DataLoader(dataset,\n", | |
" batch_size=512,\n", | |
" shuffle=False,\n", | |
" num_workers=os.cpu_count())\n", | |
" \n", | |
" fps = []\n", | |
" for batch in tqdm(dataloader, total=len(dataloader)):\n", | |
" batch = batch.to(DEVICE)\n", | |
" with torch.no_grad():\n", | |
" mol_embed = model.project_drug(batch)\n", | |
" fps.append(mol_embed.cpu().numpy())\n", | |
" \n", | |
" return np.concatenate(fps)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "fd7b4c5cea834a0aae3071036adcad49", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/1 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "c98a81430c6145e4bb4875518c0a7f1c", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/611 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"actives_fps = extract_fingerprints(actives_smiles, model, target_output_dir)\n", | |
"decoys_fps = extract_fingerprints(decoys_smiles, model, target_output_dir)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def read_proteins(target_dir):\n", | |
" proteins_all_sequences = []\n", | |
" for entry_path in sorted(target_dir.glob('*_protein.fasta')):\n", | |
" proteins_all_sequences.append(SeqIO.parse(str(entry_path), 'fasta'))\n", | |
" \n", | |
" return proteins_all_sequences" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "d5c30b774ad140729cf16d9fd21ab045", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/8 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(8, 1024)" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# ligands_fps = extract_fingerprints([MolToSmiles(l) for l in ligands], model, target_output_dir)\n", | |
"# ligands_fps = np.array(ligands_fps)\n", | |
"\n", | |
"proteins_all_sequences = read_proteins(target_dir)\n", | |
"ligands_fps = extract_fingerprints_for_ligands(proteins_all_sequences, prot_featurizer, model, target_output_dir)\n", | |
"\n", | |
"ligands_fps.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"labels = np.concatenate([np.ones(len(actives_fps)), np.zeros(len(decoys_fps))]).astype(bool)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"lig2act = calc_cosine_distance(ligands_fps, actives_fps)\n", | |
"lig2dec = calc_cosine_distance(ligands_fps, decoys_fps)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "183ef4357bb8446fb361b6282473b6d4", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/8 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"roc_aucs = []\n", | |
"for i in tqdm(range(lig2act.shape[0])):\n", | |
" roc_auc = roc_auc_score(labels, np.concatenate([lig2act[i], lig2dec[i]]))\n", | |
" roc_aucs.append(roc_auc)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"ROC AUC min: 0.467, max: 0.581, mean: 0.494, std: 0.034\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f'ROC AUC min: {np.min(roc_aucs):.3f}, max: {np.max(roc_aucs):.3f}, mean: {np.mean(roc_aucs):.3f}, std: {np.std(roc_aucs):.3f}')" | |
] | |
}, | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Protein-ligand matching" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"benchmark = {\n", | |
" 'target': [],\n", | |
" 'auroc_min': [],\n", | |
" 'auroc_max': [],\n", | |
" 'auroc_mean': [],\n", | |
" 'auroc_std': [],\n", | |
" 'auroc_fused': [],\n", | |
" 'ef1_min': [],\n", | |
" 'ef1_max': [],\n", | |
" 'ef1_mean': [],\n", | |
" 'ef1_std': [],\n", | |
" 'ef1_fused': []\n", | |
"}\n", | |
"\n", | |
"start_idx = None\n", | |
"similarity_fn = calc_cosine_distance\n", | |
"target_dirs = list(sorted(DATA_DIR.iterdir()))\n", | |
"for t, target_dir in enumerate(target_dirs):\n", | |
" print(f'{target_dir.stem} ({t+1}/{len(target_dirs)})')\n", | |
" target_output_dir = OUTPUT_DIR / target_dir.stem\n", | |
" \n", | |
" if start_idx is not None and t < start_idx:\n", | |
" if not target_output_dir.exists():\n", | |
" print(f'[{target_dir.stem}] No output dir found')\n", | |
" continue\n", | |
" \n", | |
" ligands_fps = np.load(target_output_dir / 'ligands_feats.npy')\n", | |
" actives_fps = np.load(target_output_dir / 'actives_feats.npy')\n", | |
" decoys_fps = np.load(target_output_dir / 'decoys_feats.npy')\n", | |
" else:\n", | |
" proteins_all_sequences = read_proteins(target_dir)\n", | |
"\n", | |
" actives_smiles, decoys_smiles = read_actives_decoys(DATA_SOURCE, target_dir)\n", | |
" actives_fps = extract_fingerprints(actives_smiles, model=model, save_dir=target_output_dir)\n", | |
" decoys_fps = extract_fingerprints(decoys_smiles, model=model, save_dir=target_output_dir)\n", | |
" target_output_dir.mkdir(exist_ok=True, parents=True)\n", | |
" np.save(target_output_dir / 'actives_feats.npy', actives_fps)\n", | |
" np.save(target_output_dir / 'decoys_feats.npy', decoys_fps)\n", | |
"\n", | |
" ligands_fps = extract_fingerprints_for_ligands(proteins_all_sequences, prot_featurizer, model, target_output_dir)\n", | |
" np.save(target_output_dir / 'ligands_feats.npy', ligands_fps)\n", | |
"\n", | |
" labels = np.concatenate([np.ones(len(actives_fps), dtype=bool), np.zeros(len(decoys_fps), dtype=bool)])\n", | |
" actives_fps = actives_fps.astype(np.float32)\n", | |
" decoys_fps = decoys_fps.astype(np.float32)\n", | |
" ligands_fps = ligands_fps.astype(np.float32)\n", | |
" lig2act = similarity_fn(ligands_fps, actives_fps)\n", | |
" lig2dec = similarity_fn(ligands_fps, decoys_fps)\n", | |
"\n", | |
" roc_aucs, ef1s = [], []\n", | |
" for i in tqdm(range(lig2act.shape[0])):\n", | |
" roc_auc = roc_auc_score(labels, np.concatenate([lig2act[i], lig2dec[i]]))\n", | |
" roc_aucs.append(roc_auc)\n", | |
" ef1 = ef_score(labels, np.concatenate([lig2act[i], lig2dec[i]]), 0.01)\n", | |
" ef1s.append(ef1)\n", | |
" roc_auc_fused = roc_auc_score(labels, np.concatenate([lig2act.max(axis=0), lig2dec.max(axis=0)]))\n", | |
" ef1_fused = ef_score(labels, np.concatenate([lig2act.max(axis=0), lig2dec.max(axis=0)]), 0.01)\n", | |
"\n", | |
" benchmark['target'].append(target_dir.stem)\n", | |
" benchmark['auroc_min'].append(np.min(roc_aucs))\n", | |
" benchmark['auroc_max'].append(np.max(roc_aucs))\n", | |
" benchmark['auroc_mean'].append(np.mean(roc_aucs))\n", | |
" benchmark['auroc_std'].append(np.std(roc_aucs))\n", | |
" benchmark['auroc_fused'].append(roc_auc_fused)\n", | |
" benchmark['ef1_min'].append(np.min(ef1s))\n", | |
" benchmark['ef1_max'].append(np.max(ef1s))\n", | |
" benchmark['ef1_mean'].append(np.mean(ef1s))\n", | |
" benchmark['ef1_std'].append(np.std(ef1s))\n", | |
" benchmark['ef1_fused'].append(ef1_fused)\n", | |
"\n", | |
" print(f'[{target_dir.stem}] ROC AUC min: {np.min(roc_aucs):.3f}, max: {np.max(roc_aucs):.3f}, '\n", | |
" f'mean: {np.mean(roc_aucs):.3f}, std: {np.std(roc_aucs):.3f} fused: {roc_auc_fused:.3f}')\n", | |
" print(f'[{target_dir.stem}] EF1 min: {np.min(ef1s):.3f}, max: {np.max(ef1s):.3f}, '\n", | |
" f'mean: {np.mean(ef1s):.3f}, std: {np.std(ef1s):.3f} fused: {ef1_fused:.3f}')\n", | |
" \n", | |
"benchmark = pd.DataFrame(benchmark)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>target</th>\n", | |
" <th>auroc_min</th>\n", | |
" <th>auroc_max</th>\n", | |
" <th>auroc_mean</th>\n", | |
" <th>auroc_std</th>\n", | |
" <th>auroc_fused</th>\n", | |
" <th>ef1_min</th>\n", | |
" <th>ef1_max</th>\n", | |
" <th>ef1_mean</th>\n", | |
" <th>ef1_std</th>\n", | |
" <th>ef1_fused</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>ADRB2</td>\n", | |
" <td>0.467131</td>\n", | |
" <td>0.580804</td>\n", | |
" <td>0.494027</td>\n", | |
" <td>0.034033</td>\n", | |
" <td>0.497650</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000e+00</td>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>ALDH1</td>\n", | |
" <td>0.525737</td>\n", | |
" <td>0.525737</td>\n", | |
" <td>0.525737</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.525737</td>\n", | |
" <td>1.632254</td>\n", | |
" <td>1.632254</td>\n", | |
" <td>1.632254</td>\n", | |
" <td>0.000000e+00</td>\n", | |
" <td>1.632254</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>ESR1_ago</td>\n", | |
" <td>0.431165</td>\n", | |
" <td>0.566769</td>\n", | |
" <td>0.455523</td>\n", | |
" <td>0.031804</td>\n", | |
" <td>0.560775</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>15.384615</td>\n", | |
" <td>1.025641</td>\n", | |
" <td>3.837597e+00</td>\n", | |
" <td>7.692308</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>ESR1_ant</td>\n", | |
" <td>0.463507</td>\n", | |
" <td>0.487013</td>\n", | |
" <td>0.473227</td>\n", | |
" <td>0.007324</td>\n", | |
" <td>0.471104</td>\n", | |
" <td>0.980392</td>\n", | |
" <td>4.901961</td>\n", | |
" <td>4.248366</td>\n", | |
" <td>9.912255e-01</td>\n", | |
" <td>3.921569</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>FEN1</td>\n", | |
" <td>0.578497</td>\n", | |
" <td>0.578497</td>\n", | |
" <td>0.578497</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.578497</td>\n", | |
" <td>2.439024</td>\n", | |
" <td>2.439024</td>\n", | |
" <td>2.439024</td>\n", | |
" <td>0.000000e+00</td>\n", | |
" <td>2.439024</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>GBA</td>\n", | |
" <td>0.484791</td>\n", | |
" <td>0.493707</td>\n", | |
" <td>0.487763</td>\n", | |
" <td>0.004203</td>\n", | |
" <td>0.488888</td>\n", | |
" <td>0.602410</td>\n", | |
" <td>1.204819</td>\n", | |
" <td>1.004016</td>\n", | |
" <td>2.839786e-01</td>\n", | |
" <td>0.602410</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>IDH1</td>\n", | |
" <td>0.529484</td>\n", | |
" <td>0.553243</td>\n", | |
" <td>0.545412</td>\n", | |
" <td>0.008663</td>\n", | |
" <td>0.545899</td>\n", | |
" <td>5.128205</td>\n", | |
" <td>5.128205</td>\n", | |
" <td>5.128205</td>\n", | |
" <td>1.776357e-15</td>\n", | |
" <td>5.128205</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>KAT2A</td>\n", | |
" <td>0.538885</td>\n", | |
" <td>0.543235</td>\n", | |
" <td>0.540335</td>\n", | |
" <td>0.002050</td>\n", | |
" <td>0.541592</td>\n", | |
" <td>0.515464</td>\n", | |
" <td>1.030928</td>\n", | |
" <td>0.687285</td>\n", | |
" <td>2.429920e-01</td>\n", | |
" <td>0.515464</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>MAPK1</td>\n", | |
" <td>0.482031</td>\n", | |
" <td>0.549280</td>\n", | |
" <td>0.518390</td>\n", | |
" <td>0.018859</td>\n", | |
" <td>0.548797</td>\n", | |
" <td>0.324675</td>\n", | |
" <td>7.142857</td>\n", | |
" <td>2.337662</td>\n", | |
" <td>2.104804e+00</td>\n", | |
" <td>7.142857</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>MTORC1</td>\n", | |
" <td>0.529677</td>\n", | |
" <td>0.573706</td>\n", | |
" <td>0.545634</td>\n", | |
" <td>0.017579</td>\n", | |
" <td>0.567938</td>\n", | |
" <td>1.030928</td>\n", | |
" <td>3.092784</td>\n", | |
" <td>1.499531</td>\n", | |
" <td>6.758297e-01</td>\n", | |
" <td>2.061856</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>10</th>\n", | |
" <td>OPRK1</td>\n", | |
" <td>0.423600</td>\n", | |
" <td>0.423600</td>\n", | |
" <td>0.423600</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.423600</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000e+00</td>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>11</th>\n", | |
" <td>PKM2</td>\n", | |
" <td>0.517113</td>\n", | |
" <td>0.550826</td>\n", | |
" <td>0.540974</td>\n", | |
" <td>0.011080</td>\n", | |
" <td>0.537580</td>\n", | |
" <td>0.732601</td>\n", | |
" <td>0.915751</td>\n", | |
" <td>0.895401</td>\n", | |
" <td>5.755855e-02</td>\n", | |
" <td>0.915751</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>12</th>\n", | |
" <td>PPARG</td>\n", | |
" <td>0.330807</td>\n", | |
" <td>0.381156</td>\n", | |
" <td>0.353604</td>\n", | |
" <td>0.017436</td>\n", | |
" <td>0.362762</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000e+00</td>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>13</th>\n", | |
" <td>TP53</td>\n", | |
" <td>0.429233</td>\n", | |
" <td>0.429233</td>\n", | |
" <td>0.429233</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.429233</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.000000e+00</td>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>14</th>\n", | |
" <td>VDR</td>\n", | |
" <td>0.551515</td>\n", | |
" <td>0.551515</td>\n", | |
" <td>0.551515</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.551515</td>\n", | |
" <td>2.036199</td>\n", | |
" <td>2.036199</td>\n", | |
" <td>2.036199</td>\n", | |
" <td>0.000000e+00</td>\n", | |
" <td>2.036199</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" target auroc_min auroc_max auroc_mean auroc_std auroc_fused \\\n", | |
"0 ADRB2 0.467131 0.580804 0.494027 0.034033 0.497650 \n", | |
"1 ALDH1 0.525737 0.525737 0.525737 0.000000 0.525737 \n", | |
"2 ESR1_ago 0.431165 0.566769 0.455523 0.031804 0.560775 \n", | |
"3 ESR1_ant 0.463507 0.487013 0.473227 0.007324 0.471104 \n", | |
"4 FEN1 0.578497 0.578497 0.578497 0.000000 0.578497 \n", | |
"5 GBA 0.484791 0.493707 0.487763 0.004203 0.488888 \n", | |
"6 IDH1 0.529484 0.553243 0.545412 0.008663 0.545899 \n", | |
"7 KAT2A 0.538885 0.543235 0.540335 0.002050 0.541592 \n", | |
"8 MAPK1 0.482031 0.549280 0.518390 0.018859 0.548797 \n", | |
"9 MTORC1 0.529677 0.573706 0.545634 0.017579 0.567938 \n", | |
"10 OPRK1 0.423600 0.423600 0.423600 0.000000 0.423600 \n", | |
"11 PKM2 0.517113 0.550826 0.540974 0.011080 0.537580 \n", | |
"12 PPARG 0.330807 0.381156 0.353604 0.017436 0.362762 \n", | |
"13 TP53 0.429233 0.429233 0.429233 0.000000 0.429233 \n", | |
"14 VDR 0.551515 0.551515 0.551515 0.000000 0.551515 \n", | |
"\n", | |
" ef1_min ef1_max ef1_mean ef1_std ef1_fused \n", | |
"0 0.000000 0.000000 0.000000 0.000000e+00 0.000000 \n", | |
"1 1.632254 1.632254 1.632254 0.000000e+00 1.632254 \n", | |
"2 0.000000 15.384615 1.025641 3.837597e+00 7.692308 \n", | |
"3 0.980392 4.901961 4.248366 9.912255e-01 3.921569 \n", | |
"4 2.439024 2.439024 2.439024 0.000000e+00 2.439024 \n", | |
"5 0.602410 1.204819 1.004016 2.839786e-01 0.602410 \n", | |
"6 5.128205 5.128205 5.128205 1.776357e-15 5.128205 \n", | |
"7 0.515464 1.030928 0.687285 2.429920e-01 0.515464 \n", | |
"8 0.324675 7.142857 2.337662 2.104804e+00 7.142857 \n", | |
"9 1.030928 3.092784 1.499531 6.758297e-01 2.061856 \n", | |
"10 0.000000 0.000000 0.000000 0.000000e+00 0.000000 \n", | |
"11 0.732601 0.915751 0.895401 5.755855e-02 0.915751 \n", | |
"12 0.000000 0.000000 0.000000 0.000000e+00 0.000000 \n", | |
"13 0.000000 0.000000 0.000000 0.000000e+00 0.000000 \n", | |
"14 2.036199 2.036199 2.036199 0.000000e+00 2.036199 " | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"benchmark" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"benchmark.to_csv(BENCH_OUTPUT_PATH, index=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.5087711290605641, 2.2725264454871454)" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"benchmark['auroc_fused'].mean(), benchmark['ef1_fused'].mean()" | |
] | |
}, | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Ligand-ligand matching" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# benchmark = run_benchmark(DATA_SOURCE, DATA_DIR, OUTPUT_DIR, \\\n", | |
"# partial(extract_fingerprints, model=model, save_dir=OUTPUT_DIR),\n", | |
"# calc_cosine_distance, start_idx=5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>target</th>\n", | |
" <th>auroc_min</th>\n", | |
" <th>auroc_max</th>\n", | |
" <th>auroc_mean</th>\n", | |
" <th>auroc_std</th>\n", | |
" <th>auroc_fused</th>\n", | |
" <th>ef1_min</th>\n", | |
" <th>ef1_max</th>\n", | |
" <th>ef1_mean</th>\n", | |
" <th>ef1_std</th>\n", | |
" <th>ef1_fused</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>ADRB2</td>\n", | |
" <td>0.507079</td>\n", | |
" <td>0.643049</td>\n", | |
" <td>0.595389</td>\n", | |
" <td>0.037026</td>\n", | |
" <td>0.586201</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>17.647059</td>\n", | |
" <td>9.558824</td>\n", | |
" <td>7.166761</td>\n", | |
" <td>5.882353</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>ALDH1</td>\n", | |
" <td>0.496826</td>\n", | |
" <td>0.527855</td>\n", | |
" <td>0.510553</td>\n", | |
" <td>0.011711</td>\n", | |
" <td>0.497737</td>\n", | |
" <td>0.725446</td>\n", | |
" <td>1.185826</td>\n", | |
" <td>1.018415</td>\n", | |
" <td>0.180015</td>\n", | |
" <td>0.739397</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>ESR1_ago</td>\n", | |
" <td>0.531634</td>\n", | |
" <td>0.765724</td>\n", | |
" <td>0.647954</td>\n", | |
" <td>0.054484</td>\n", | |
" <td>0.636906</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>23.076923</td>\n", | |
" <td>12.820513</td>\n", | |
" <td>8.268982</td>\n", | |
" <td>15.384615</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>ESR1_ant</td>\n", | |
" <td>0.414653</td>\n", | |
" <td>0.522583</td>\n", | |
" <td>0.480341</td>\n", | |
" <td>0.026942</td>\n", | |
" <td>0.471374</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>2.941176</td>\n", | |
" <td>1.330532</td>\n", | |
" <td>0.952484</td>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>FEN1</td>\n", | |
" <td>0.415627</td>\n", | |
" <td>0.415627</td>\n", | |
" <td>0.415627</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.415627</td>\n", | |
" <td>0.813008</td>\n", | |
" <td>0.813008</td>\n", | |
" <td>0.813008</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.813008</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>GBA</td>\n", | |
" <td>0.407294</td>\n", | |
" <td>0.496951</td>\n", | |
" <td>0.444367</td>\n", | |
" <td>0.035076</td>\n", | |
" <td>0.451171</td>\n", | |
" <td>1.204819</td>\n", | |
" <td>3.012048</td>\n", | |
" <td>1.957831</td>\n", | |
" <td>0.656461</td>\n", | |
" <td>2.409639</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>IDH1</td>\n", | |
" <td>0.460212</td>\n", | |
" <td>0.603055</td>\n", | |
" <td>0.520860</td>\n", | |
" <td>0.042039</td>\n", | |
" <td>0.570258</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>5.128205</td>\n", | |
" <td>1.282051</td>\n", | |
" <td>1.720052</td>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>KAT2A</td>\n", | |
" <td>0.516656</td>\n", | |
" <td>0.532710</td>\n", | |
" <td>0.524683</td>\n", | |
" <td>0.008027</td>\n", | |
" <td>0.538545</td>\n", | |
" <td>0.515464</td>\n", | |
" <td>1.030928</td>\n", | |
" <td>0.773196</td>\n", | |
" <td>0.257732</td>\n", | |
" <td>1.030928</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>MAPK1</td>\n", | |
" <td>0.450370</td>\n", | |
" <td>0.580198</td>\n", | |
" <td>0.510623</td>\n", | |
" <td>0.038181</td>\n", | |
" <td>0.552040</td>\n", | |
" <td>0.324675</td>\n", | |
" <td>8.116883</td>\n", | |
" <td>2.479339</td>\n", | |
" <td>2.579902</td>\n", | |
" <td>2.922078</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>MTORC1</td>\n", | |
" <td>0.490265</td>\n", | |
" <td>0.506024</td>\n", | |
" <td>0.496381</td>\n", | |
" <td>0.005229</td>\n", | |
" <td>0.506741</td>\n", | |
" <td>1.030928</td>\n", | |
" <td>5.154639</td>\n", | |
" <td>4.123711</td>\n", | |
" <td>1.304032</td>\n", | |
" <td>3.092784</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>10</th>\n", | |
" <td>OPRK1</td>\n", | |
" <td>0.624212</td>\n", | |
" <td>0.624212</td>\n", | |
" <td>0.624212</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.624212</td>\n", | |
" <td>4.166667</td>\n", | |
" <td>4.166667</td>\n", | |
" <td>4.166667</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>4.166667</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>11</th>\n", | |
" <td>PKM2</td>\n", | |
" <td>0.509071</td>\n", | |
" <td>0.607490</td>\n", | |
" <td>0.543089</td>\n", | |
" <td>0.036738</td>\n", | |
" <td>0.585025</td>\n", | |
" <td>0.366300</td>\n", | |
" <td>6.959707</td>\n", | |
" <td>3.044872</td>\n", | |
" <td>2.568902</td>\n", | |
" <td>6.593407</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>12</th>\n", | |
" <td>PPARG</td>\n", | |
" <td>0.282177</td>\n", | |
" <td>0.678952</td>\n", | |
" <td>0.465009</td>\n", | |
" <td>0.118151</td>\n", | |
" <td>0.676571</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>14.814815</td>\n", | |
" <td>2.160494</td>\n", | |
" <td>4.397471</td>\n", | |
" <td>11.111111</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>13</th>\n", | |
" <td>TP53</td>\n", | |
" <td>0.362562</td>\n", | |
" <td>0.547525</td>\n", | |
" <td>0.467482</td>\n", | |
" <td>0.066332</td>\n", | |
" <td>0.510018</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>2.531646</td>\n", | |
" <td>1.265823</td>\n", | |
" <td>1.132186</td>\n", | |
" <td>1.265823</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>14</th>\n", | |
" <td>VDR</td>\n", | |
" <td>0.582428</td>\n", | |
" <td>0.582428</td>\n", | |
" <td>0.582428</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.582428</td>\n", | |
" <td>4.185520</td>\n", | |
" <td>4.185520</td>\n", | |
" <td>4.185520</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>4.185520</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" target auroc_min auroc_max auroc_mean auroc_std auroc_fused \\\n", | |
"0 ADRB2 0.507079 0.643049 0.595389 0.037026 0.586201 \n", | |
"1 ALDH1 0.496826 0.527855 0.510553 0.011711 0.497737 \n", | |
"2 ESR1_ago 0.531634 0.765724 0.647954 0.054484 0.636906 \n", | |
"3 ESR1_ant 0.414653 0.522583 0.480341 0.026942 0.471374 \n", | |
"4 FEN1 0.415627 0.415627 0.415627 0.000000 0.415627 \n", | |
"5 GBA 0.407294 0.496951 0.444367 0.035076 0.451171 \n", | |
"6 IDH1 0.460212 0.603055 0.520860 0.042039 0.570258 \n", | |
"7 KAT2A 0.516656 0.532710 0.524683 0.008027 0.538545 \n", | |
"8 MAPK1 0.450370 0.580198 0.510623 0.038181 0.552040 \n", | |
"9 MTORC1 0.490265 0.506024 0.496381 0.005229 0.506741 \n", | |
"10 OPRK1 0.624212 0.624212 0.624212 0.000000 0.624212 \n", | |
"11 PKM2 0.509071 0.607490 0.543089 0.036738 0.585025 \n", | |
"12 PPARG 0.282177 0.678952 0.465009 0.118151 0.676571 \n", | |
"13 TP53 0.362562 0.547525 0.467482 0.066332 0.510018 \n", | |
"14 VDR 0.582428 0.582428 0.582428 0.000000 0.582428 \n", | |
"\n", | |
" ef1_min ef1_max ef1_mean ef1_std ef1_fused \n", | |
"0 0.000000 17.647059 9.558824 7.166761 5.882353 \n", | |
"1 0.725446 1.185826 1.018415 0.180015 0.739397 \n", | |
"2 0.000000 23.076923 12.820513 8.268982 15.384615 \n", | |
"3 0.000000 2.941176 1.330532 0.952484 0.000000 \n", | |
"4 0.813008 0.813008 0.813008 0.000000 0.813008 \n", | |
"5 1.204819 3.012048 1.957831 0.656461 2.409639 \n", | |
"6 0.000000 5.128205 1.282051 1.720052 0.000000 \n", | |
"7 0.515464 1.030928 0.773196 0.257732 1.030928 \n", | |
"8 0.324675 8.116883 2.479339 2.579902 2.922078 \n", | |
"9 1.030928 5.154639 4.123711 1.304032 3.092784 \n", | |
"10 4.166667 4.166667 4.166667 0.000000 4.166667 \n", | |
"11 0.366300 6.959707 3.044872 2.568902 6.593407 \n", | |
"12 0.000000 14.814815 2.160494 4.397471 11.111111 \n", | |
"13 0.000000 2.531646 1.265823 1.132186 1.265823 \n", | |
"14 4.185520 4.185520 4.185520 0.000000 4.185520 " | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"benchmark" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"benchmark.to_csv(BENCH_OUTPUT_PATH, index=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.5469901404486206, 3.9731552741192093)" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"benchmark['auroc_fused'].mean(), benchmark['ef1_fused'].mean()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "chem", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.4" | |
}, | |
"orig_nbformat": 4 | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
from sklearn.metrics import roc_auc_score | |
from tqdm.notebook import tqdm | |
from rdkit.Chem.rdmolfiles import MolFromMol2File, MolToSmiles | |
from metrics import ef_score | |
def read_litpcba_smiles(file_path):
    """Read a LIT-PCBA ``.smi`` file and return its SMILES strings in file order.

    The file has no header row; each line holds a SMILES string followed by a
    compound code, separated by a single space. Only the SMILES field is kept.
    """
    table = pd.read_csv(file_path, sep=' ', header=None, names=['smiles', 'code'])
    smiles_column = table['smiles']
    return smiles_column.tolist()
def read_dude_smiles(file_path):
    """Read a DUD-E ``.ism`` file and return its SMILES strings in file order.

    The file has no header row; each line is ``<smiles> <code> <chembl_code>``
    separated by single spaces. Only the first field is kept.
    """
    column_names = ['smiles', 'code', 'chembl_code']
    table = pd.read_csv(file_path, sep=' ', header=None, names=column_names)
    smiles_column = table['smiles']
    return smiles_column.tolist()
def read_actives_decoys(data_source, target_dir):
    """Read the active and decoy SMILES lists for one benchmark target.

    Args:
        data_source: ``"litpcba"`` (reads ``actives.smi`` / ``inactives.smi``)
            or ``"dude"`` (reads ``actives_final.ism`` / ``decoys_final.ism``).
        target_dir: the target's directory (``pathlib.Path``).

    Returns:
        ``(actives_smiles, decoys_smiles)``: two lists of SMILES strings.

    Raises:
        ValueError: if ``data_source`` is not one of the supported names.
    """
    if data_source == "litpcba":
        actives_fpath = target_dir / "actives.smi"
        decoys_fpath = target_dir / "inactives.smi"
        read_smiles_fn = read_litpcba_smiles
    elif data_source == "dude":
        actives_fpath = target_dir / "actives_final.ism"
        decoys_fpath = target_dir / "decoys_final.ism"
        read_smiles_fn = read_dude_smiles
    else:
        # Fail fast with a clear message instead of an UnboundLocalError on
        # the variables that neither branch above assigned.
        raise ValueError(
            f"Unknown data_source {data_source!r}; expected 'litpcba' or 'dude'"
        )
    actives_smiles = read_smiles_fn(actives_fpath)
    decoys_smiles = read_smiles_fn(decoys_fpath)
    return actives_smiles, decoys_smiles
def read_ligands(data_source, target_dir):
    """Load one target's reference ligands from ``.mol2`` files as RDKit molecules.

    Args:
        data_source: ``"litpcba"`` (every ``*_ligand.mol2`` in ``target_dir``)
            or ``"dude"`` (the single ``crystal_ligand.mol2``).
        target_dir: the target's directory (``pathlib.Path``).

    Returns:
        List of molecules that RDKit parsed successfully. Files for which
        ``MolFromMol2File`` returned ``None`` are dropped, and a
        "Loaded X out of Y" summary is printed.

    Raises:
        ValueError: if ``data_source`` is not one of the supported names.
    """
    if data_source == "litpcba":
        # sorted() makes ligand order independent of filesystem listing order.
        ligands_paths = sorted(target_dir.glob("*_ligand.mol2"))
    elif data_source == "dude":
        ligands_paths = [target_dir / "crystal_ligand.mol2"]
    else:
        # Fail fast instead of an UnboundLocalError on ``ligands_paths`` below.
        raise ValueError(
            f"Unknown data_source {data_source!r}; expected 'litpcba' or 'dude'"
        )
    ligands = [MolFromMol2File(str(p)) for p in ligands_paths]
    total_ligands = len(ligands)
    ligands = [lig for lig in ligands if lig is not None]
    print(f"Loaded {len(ligands)} out of {total_ligands} ligands")
    return ligands
def calc_euclidean_distance(x, y):
    """Return pairwise negative squared Euclidean distances between rows of ``x`` and ``y``.

    Despite the name, the result is a similarity score: higher means closer,
    with 0 for identical rows.

    Args:
        x: array of shape ``(n, d)``.
        y: array of shape ``(m, d)``.

    Returns:
        ``(n, m)`` array whose ``[i, j]`` entry is ``-||x[i] - y[j]||**2``.
    """
    # Expand ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y so that the cross term is
    # a single matrix product.
    sq_dist = (x ** 2).sum(axis=1)[:, None] + (y ** 2).sum(axis=1)[None, :] - 2 * x @ y.T
    # Floating-point cancellation in the expansion can leave small negative
    # values for (near-)identical rows; a squared distance is never negative.
    sq_dist = np.maximum(sq_dist, 0)
    return -sq_dist
def calc_cosine_distance(x, y):
    """Return pairwise cosine similarity between rows of ``x`` and ``y``, rescaled to [0, 1].

    The raw cosine in [-1, 1] is mapped through ``(c + 1) / 2``. A ``1e-6``
    term in the denominator keeps zero-norm rows finite; they score 0.5.

    Args:
        x: array of shape ``(n, d)``.
        y: array of shape ``(m, d)``.

    Returns:
        ``(n, m)`` array of similarities (higher means more similar).
    """
    norm_products = np.linalg.norm(x, axis=1)[:, None] * np.linalg.norm(y, axis=1)[None, :]
    cosine = x @ y.T / (norm_products + 1e-6)
    return (cosine + 1.) / 2.
def _stack_features(feats):
    """Return a fingerprint collection as one ndarray with one row per molecule.

    ``extract_fingerprints_fn`` may return either an already-stacked array or a
    sequence of per-molecule vectors. ``.astype`` only works on the former, so
    sequences are stacked here. An empty sequence becomes a ``(0, 0)`` array
    because ``np.stack`` rejects empty input.
    """
    if isinstance(feats, np.ndarray):
        return feats
    if len(feats) == 0:
        return np.empty((0, 0), dtype=np.float32)
    return np.stack(feats)


def _load_cached_features(target_output_dir):
    """Load the ligand, active and decoy features cached as ``.npy`` files in ``target_output_dir``."""
    ligands_fps = np.load(target_output_dir / 'ligands_feats.npy')
    actives_fps = np.load(target_output_dir / 'actives_feats.npy')
    decoys_fps = np.load(target_output_dir / 'decoys_feats.npy')
    return ligands_fps, actives_fps, decoys_fps


def _compute_features(data_source, target_dir, target_output_dir, extract_fingerprints_fn):
    """Featurize one target's ligands, actives and decoys and cache them as ``.npy`` files.

    Returns ``(ligands_fps, actives_fps, decoys_fps)``. Returns ``None`` (after
    printing a message) when no reference ligand could be read.
    """
    ligands = read_ligands(data_source, target_dir)
    if len(ligands) == 0:
        print(f'[{target_dir.stem}] No ligands found or they may be corrupted')
        return None
    actives_smiles, decoys_smiles = read_actives_decoys(data_source, target_dir)
    actives_fps = _stack_features(extract_fingerprints_fn(actives_smiles))
    decoys_fps = _stack_features(extract_fingerprints_fn(decoys_smiles))
    target_output_dir.mkdir(exist_ok=True, parents=True)
    np.save(target_output_dir / 'actives_feats.npy', actives_fps)
    np.save(target_output_dir / 'decoys_feats.npy', decoys_fps)
    # Ligands are read from .mol2; convert to SMILES so they go through the
    # same featurizer input path as the actives/decoys.
    ligands_smiles = [MolToSmiles(ligand) for ligand in ligands]
    ligands_fps = _stack_features(extract_fingerprints_fn(ligands_smiles))
    np.save(target_output_dir / 'ligands_feats.npy', ligands_fps)
    return ligands_fps, actives_fps, decoys_fps


def _score_target(labels, lig2act, lig2dec):
    """Compute per-ligand and max-fused ROC AUC / ``ef_score`` (fraction 0.01) for one target.

    Row ``i`` of ``lig2act`` / ``lig2dec`` holds reference ligand ``i``'s
    scores against every active / decoy. The fused ranking scores each
    compound by its maximum score over all reference ligands.

    Returns ``(roc_aucs, ef1s, roc_auc_fused, ef1_fused)``.
    """
    roc_aucs, ef1s = [], []
    for act_scores, dec_scores in tqdm(zip(lig2act, lig2dec), total=len(lig2act)):
        # Build the score vector once and reuse it for both metrics.
        scores = np.concatenate([act_scores, dec_scores])
        roc_aucs.append(roc_auc_score(labels, scores))
        ef1s.append(ef_score(labels, scores, 0.01))
    fused_scores = np.concatenate([lig2act.max(axis=0), lig2dec.max(axis=0)])
    roc_auc_fused = roc_auc_score(labels, fused_scores)
    ef1_fused = ef_score(labels, fused_scores, 0.01)
    return roc_aucs, ef1s, roc_auc_fused, ef1_fused


def run_benchmark(data_source, data_dir, output_dir, extract_fingerprints_fn, similarity_fn, start_idx=None):
    """Run the ligand-based screening benchmark over every target in ``data_dir``.

    For each target:

    1. Featurize the reference ligands, actives and decoys with
       ``extract_fingerprints_fn``, or reuse the features cached in
       ``output_dir/<target>``.
    2. Score every ligand against every active/decoy with ``similarity_fn``.
    3. Compute ROC AUC and ``ef_score`` at 0.01, both per ligand and for a
       max-fused ranking.

    Args:
        data_source: ``"litpcba"`` or ``"dude"``. Selects the file layout used by
            ``read_ligands`` / ``read_actives_decoys``.
        data_dir: directory (``pathlib.Path``) whose entries are the targets,
            processed in sorted order.
        output_dir: directory where per-target feature caches are written/read.
        extract_fingerprints_fn: callable mapping a list of SMILES to
            per-molecule feature vectors (an array or a sequence of vectors).
        similarity_fn: callable ``(a, b)`` returning a ``(len(a), len(b))``
            score matrix where higher means more similar.
        start_idx: if given, targets whose index is ``< start_idx`` load cached
            features instead of recomputing them. Such targets without a cache
            directory are skipped.

    Returns:
        DataFrame with one row per scored target. Columns are ``target`` plus
        min/max/mean/std/fused of ``auroc`` and ``ef1``.
    """
    columns = [
        'target',
        'auroc_min', 'auroc_max', 'auroc_mean', 'auroc_std', 'auroc_fused',
        'ef1_min', 'ef1_max', 'ef1_mean', 'ef1_std', 'ef1_fused',
    ]
    rows = []
    target_dirs = sorted(data_dir.iterdir())
    for t, target_dir in enumerate(target_dirs):
        print(f'{target_dir.stem} ({t+1}/{len(target_dirs)})')
        target_output_dir = output_dir / target_dir.stem
        if start_idx is not None and t < start_idx:
            if not target_output_dir.exists():
                print(f'[{target_dir.stem}] No output dir found')
                continue
            feats = _load_cached_features(target_output_dir)
        else:
            feats = _compute_features(data_source, target_dir, target_output_dir, extract_fingerprints_fn)
            if feats is None:
                continue

        ligands_fps, actives_fps, decoys_fps = (f.astype(np.float32) for f in feats)
        # roc_auc_score needs both classes, and the per-ligand stats / fused max
        # need at least one ligand; skip the target instead of crashing.
        if len(ligands_fps) == 0 or len(actives_fps) == 0 or len(decoys_fps) == 0:
            print(f'[{target_dir.stem}] Not enough ligands, actives or decoys to score')
            continue

        labels = np.concatenate([np.ones(len(actives_fps), dtype=bool), np.zeros(len(decoys_fps), dtype=bool)])
        lig2act = similarity_fn(ligands_fps, actives_fps)
        lig2dec = similarity_fn(ligands_fps, decoys_fps)
        roc_aucs, ef1s, roc_auc_fused, ef1_fused = _score_target(labels, lig2act, lig2dec)

        row = {
            'target': target_dir.stem,
            'auroc_min': np.min(roc_aucs),
            'auroc_max': np.max(roc_aucs),
            'auroc_mean': np.mean(roc_aucs),
            'auroc_std': np.std(roc_aucs),
            'auroc_fused': roc_auc_fused,
            'ef1_min': np.min(ef1s),
            'ef1_max': np.max(ef1s),
            'ef1_mean': np.mean(ef1s),
            'ef1_std': np.std(ef1s),
            'ef1_fused': ef1_fused,
        }
        rows.append(row)
        print(f'[{target_dir.stem}] ROC AUC min: {row["auroc_min"]:.3f}, max: {row["auroc_max"]:.3f}, '
              f'mean: {row["auroc_mean"]:.3f}, std: {row["auroc_std"]:.3f} fused: {roc_auc_fused:.3f}')
        print(f'[{target_dir.stem}] EF1 min: {row["ef1_min"]:.3f}, max: {row["ef1_max"]:.3f}, '
              f'mean: {row["ef1_mean"]:.3f}, std: {row["ef1_std"]:.3f} fused: {ef1_fused:.3f}')

    # Passing ``columns`` keeps the column order (and the header of an empty
    # result) identical to the dict-of-lists layout.
    return pd.DataFrame(rows, columns=columns)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment