@VladVin (last active July 19, 2023)
ConPLex evaluation for protein-ligand and ligand-ligand matching
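For orientation: the notebook loads a pretrained ConPLex `SimpleCoembedding` model, embeds drugs (Morgan fingerprints) and protein targets (ProtBert features) into a shared space, and ranks actives against decoys by cosine similarity. A minimal sketch of that recipe, assuming the `conplex_dti` API and checkpoint path used in the cells below; the featurizer constructor arguments are simplified (the notebook also passes `save_dir`), and the SMILES/sequence strings are arbitrary examples:

import torch

from conplex_dti.model.architectures import SimpleCoembedding
from conplex_dti.featurizer.molecule import MorganFeaturizer
from conplex_dti.featurizer.protein import ProtBertFeaturizer

device = torch.device("cpu")
model = SimpleCoembedding()
model.load_state_dict(torch.load("./ConPlex/weights/BindingDB_ExperimentalValidModel.pt",
                                 map_location=device))
model.eval()

mol_featurizer = MorganFeaturizer()    # SMILES -> Morgan fingerprint tensor
prot_featurizer = ProtBertFeaturizer() # sequence -> ProtBert feature tensor

with torch.no_grad():
    drug = model.project_drug(mol_featurizer.transform("CC(=O)Oc1ccccc1C(=O)O").unsqueeze(0))
    target = model.project_target(prot_featurizer.transform("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ").unsqueeze(0))

# Cosine similarity rescaled to [0, 1], matching utils.calc_cosine_distance below.
score = (torch.nn.functional.cosine_similarity(drug, target).item() + 1.0) / 2.0
print(f"protein-ligand match score: {score:.3f}")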
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"from multiprocessing import Pool\n",
"import multiprocessing as mp\n",
"from pathlib import Path\n",
"import os\n",
"import sys\n",
"\n",
"from Bio import Seq, SeqIO\n",
"import numpy as np\n",
"import pandas as pd\n",
"from rdkit.Chem import AllChem\n",
"from rdkit import Chem\n",
"from rdkit.Chem.rdmolfiles import MolFromMol2File, MolToSmiles, MolFromSmiles\n",
"from sklearn.metrics import roc_auc_score\n",
"import torch\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"sys.path.append(\"./ConPlex/\")\n",
"from conplex_dti.model.architectures import SimpleCoembedding\n",
"from conplex_dti.featurizer.molecule import MorganFeaturizer\n",
"from conplex_dti.featurizer.protein import ProtBertFeaturizer\n",
"\n",
"from conplex_dataset import ConPlexDrugDataset\n",
"from metrics import ef_score\n",
"from utils import read_actives_decoys, read_ligands, calc_euclidean_distance, calc_cosine_distance, run_benchmark"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"model = SimpleCoembedding()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"DEVICE = torch.device(\"mps\")\n",
"# DEVICE = torch.device(\"cuda:0\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"ckpt = torch.load(\"./ConPlex/weights/BindingDB_ExperimentalValidModel.pt\", map_location=DEVICE)\n",
"model.load_state_dict(ckpt)\n",
"model.eval()\n",
"model = model.to(DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"DATA_SOURCE = \"litpcba\"\n",
"# DATA_SOURCE = \"dude\""
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"if DATA_SOURCE == \"litpcba\":\n",
" DATA_DIR = Path(\"./data/LIT-PCBA/full_data\")\n",
" OUTPUT_DIR = Path('./output/LIT-PCBA/full_conplex_feats')\n",
" BENCH_OUTPUT_PATH = Path('./output/LIT-PCBA/benchmark_full_conplex_prot2lig.csv')\n",
"elif DATA_SOURCE == \"dude\":\n",
" DATA_DIR = Path(\"./data/dud-e\")\n",
" OUTPUT_DIR = Path('./output/dud-e/conplex_feats')\n",
" BENCH_OUTPUT_PATH = Path('./output/dud-e/benchmark_conplex_prot2lig.csv')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"if DATA_SOURCE == \"litpcba\":\n",
" target_dir = DATA_DIR / \"ADRB2\"\n",
"elif DATA_SOURCE == \"dude\":\n",
" target_dir = DATA_DIR / \"abl1\"\n",
"\n",
"target_output_dir = OUTPUT_DIR / target_dir.stem\n",
"actives_smiles, decoys_smiles = read_actives_decoys(DATA_SOURCE, target_dir)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"mol_featurizer = MorganFeaturizer(save_dir=target_output_dir)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0., 1., 0., ..., 0., 0., 0.])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mol_featurizer.transform(actives_smiles[0])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"prot_featurizer = ProtBertFeaturizer(save_dir=target_output_dir)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"mol_feats = mol_featurizer.transform(actives_smiles[0])\n",
"mol_embed = model.project_drug(mol_feats.unsqueeze(0).to(DEVICE))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def extract_fingerprints_for_ligands(proteins_all_sequences: list[list[Seq]], prot_featurizer, model, save_dir):\n",
" proteins_sequences = []\n",
" for protein_sequences in proteins_all_sequences:\n",
" joint_sequence = ''.join([str(seq) for seq in protein_sequences])\n",
" proteins_sequences.append(joint_sequence)\n",
" \n",
" proteins_features = []\n",
" for protein_sequence in tqdm(proteins_sequences):\n",
" protein_features = prot_featurizer.transform(protein_sequence)\n",
" proteins_features.append(protein_features)\n",
" proteins_features = torch.stack(proteins_features)\n",
" \n",
" batch_size = 8\n",
" fps = []\n",
" for i in range(0, len(proteins_features), batch_size):\n",
" batch = proteins_features[i:i+batch_size]\n",
" batch = batch.to(DEVICE)\n",
" with torch.no_grad():\n",
" prot_embed = model.project_target(batch)\n",
" fps.append(prot_embed.cpu().numpy())\n",
"\n",
" return np.concatenate(fps)\n",
"\n",
"def extract_fingerprints(smiles, model, save_dir):\n",
" dataset = ConPlexDrugDataset(smiles, save_dir=save_dir)\n",
" dataloader = DataLoader(dataset,\n",
" batch_size=512,\n",
" shuffle=False,\n",
" num_workers=os.cpu_count())\n",
" \n",
" fps = []\n",
" for batch in tqdm(dataloader, total=len(dataloader)):\n",
" batch = batch.to(DEVICE)\n",
" with torch.no_grad():\n",
" mol_embed = model.project_drug(batch)\n",
" fps.append(mol_embed.cpu().numpy())\n",
" \n",
" return np.concatenate(fps)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fd7b4c5cea834a0aae3071036adcad49",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c98a81430c6145e4bb4875518c0a7f1c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/611 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"actives_fps = extract_fingerprints(actives_smiles, model, target_output_dir)\n",
"decoys_fps = extract_fingerprints(decoys_smiles, model, target_output_dir)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def read_proteins(target_dir):\n",
" proteins_all_sequences = []\n",
" for entry_path in sorted(target_dir.glob('*_protein.fasta')):\n",
" proteins_all_sequences.append(SeqIO.parse(str(entry_path), 'fasta'))\n",
" \n",
" return proteins_all_sequences"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d5c30b774ad140729cf16d9fd21ab045",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(8, 1024)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ligands_fps = extract_fingerprints([MolToSmiles(l) for l in ligands], model, target_output_dir)\n",
"# ligands_fps = np.array(ligands_fps)\n",
"\n",
"proteins_all_sequences = read_proteins(target_dir)\n",
"ligands_fps = extract_fingerprints_for_ligands(proteins_all_sequences, prot_featurizer, model, target_output_dir)\n",
"\n",
"ligands_fps.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"labels = np.concatenate([np.ones(len(actives_fps)), np.zeros(len(decoys_fps))]).astype(bool)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"lig2act = calc_cosine_distance(ligands_fps, actives_fps)\n",
"lig2dec = calc_cosine_distance(ligands_fps, decoys_fps)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "183ef4357bb8446fb361b6282473b6d4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"roc_aucs = []\n",
"for i in tqdm(range(lig2act.shape[0])):\n",
" roc_auc = roc_auc_score(labels, np.concatenate([lig2act[i], lig2dec[i]]))\n",
" roc_aucs.append(roc_auc)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ROC AUC min: 0.467, max: 0.581, mean: 0.494, std: 0.034\n"
]
}
],
"source": [
"print(f'ROC AUC min: {np.min(roc_aucs):.3f}, max: {np.max(roc_aucs):.3f}, mean: {np.mean(roc_aucs):.3f}, std: {np.std(roc_aucs):.3f}')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Protein-ligand matching"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"benchmark = {\n",
" 'target': [],\n",
" 'auroc_min': [],\n",
" 'auroc_max': [],\n",
" 'auroc_mean': [],\n",
" 'auroc_std': [],\n",
" 'auroc_fused': [],\n",
" 'ef1_min': [],\n",
" 'ef1_max': [],\n",
" 'ef1_mean': [],\n",
" 'ef1_std': [],\n",
" 'ef1_fused': []\n",
"}\n",
"\n",
"start_idx = None\n",
"similarity_fn = calc_cosine_distance\n",
"target_dirs = list(sorted(DATA_DIR.iterdir()))\n",
"for t, target_dir in enumerate(target_dirs):\n",
" print(f'{target_dir.stem} ({t+1}/{len(target_dirs)})')\n",
" target_output_dir = OUTPUT_DIR / target_dir.stem\n",
" \n",
" if start_idx is not None and t < start_idx:\n",
" if not target_output_dir.exists():\n",
" print(f'[{target_dir.stem}] No output dir found')\n",
" continue\n",
" \n",
" ligands_fps = np.load(target_output_dir / 'ligands_feats.npy')\n",
" actives_fps = np.load(target_output_dir / 'actives_feats.npy')\n",
" decoys_fps = np.load(target_output_dir / 'decoys_feats.npy')\n",
" else:\n",
" proteins_all_sequences = read_proteins(target_dir)\n",
"\n",
" actives_smiles, decoys_smiles = read_actives_decoys(DATA_SOURCE, target_dir)\n",
" actives_fps = extract_fingerprints(actives_smiles, model=model, save_dir=target_output_dir)\n",
" decoys_fps = extract_fingerprints(decoys_smiles, model=model, save_dir=target_output_dir)\n",
" target_output_dir.mkdir(exist_ok=True, parents=True)\n",
" np.save(target_output_dir / 'actives_feats.npy', actives_fps)\n",
" np.save(target_output_dir / 'decoys_feats.npy', decoys_fps)\n",
"\n",
" ligands_fps = extract_fingerprints_for_ligands(proteins_all_sequences, prot_featurizer, model, target_output_dir)\n",
" np.save(target_output_dir / 'ligands_feats.npy', ligands_fps)\n",
"\n",
" labels = np.concatenate([np.ones(len(actives_fps), dtype=bool), np.zeros(len(decoys_fps), dtype=bool)])\n",
" actives_fps = actives_fps.astype(np.float32)\n",
" decoys_fps = decoys_fps.astype(np.float32)\n",
" ligands_fps = ligands_fps.astype(np.float32)\n",
" lig2act = similarity_fn(ligands_fps, actives_fps)\n",
" lig2dec = similarity_fn(ligands_fps, decoys_fps)\n",
"\n",
" roc_aucs, ef1s = [], []\n",
" for i in tqdm(range(lig2act.shape[0])):\n",
" roc_auc = roc_auc_score(labels, np.concatenate([lig2act[i], lig2dec[i]]))\n",
" roc_aucs.append(roc_auc)\n",
" ef1 = ef_score(labels, np.concatenate([lig2act[i], lig2dec[i]]), 0.01)\n",
" ef1s.append(ef1)\n",
" roc_auc_fused = roc_auc_score(labels, np.concatenate([lig2act.max(axis=0), lig2dec.max(axis=0)]))\n",
" ef1_fused = ef_score(labels, np.concatenate([lig2act.max(axis=0), lig2dec.max(axis=0)]), 0.01)\n",
"\n",
" benchmark['target'].append(target_dir.stem)\n",
" benchmark['auroc_min'].append(np.min(roc_aucs))\n",
" benchmark['auroc_max'].append(np.max(roc_aucs))\n",
" benchmark['auroc_mean'].append(np.mean(roc_aucs))\n",
" benchmark['auroc_std'].append(np.std(roc_aucs))\n",
" benchmark['auroc_fused'].append(roc_auc_fused)\n",
" benchmark['ef1_min'].append(np.min(ef1s))\n",
" benchmark['ef1_max'].append(np.max(ef1s))\n",
" benchmark['ef1_mean'].append(np.mean(ef1s))\n",
" benchmark['ef1_std'].append(np.std(ef1s))\n",
" benchmark['ef1_fused'].append(ef1_fused)\n",
"\n",
" print(f'[{target_dir.stem}] ROC AUC min: {np.min(roc_aucs):.3f}, max: {np.max(roc_aucs):.3f}, '\n",
" f'mean: {np.mean(roc_aucs):.3f}, std: {np.std(roc_aucs):.3f} fused: {roc_auc_fused:.3f}')\n",
" print(f'[{target_dir.stem}] EF1 min: {np.min(ef1s):.3f}, max: {np.max(ef1s):.3f}, '\n",
" f'mean: {np.mean(ef1s):.3f}, std: {np.std(ef1s):.3f} fused: {ef1_fused:.3f}')\n",
" \n",
"benchmark = pd.DataFrame(benchmark)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>target</th>\n",
" <th>auroc_min</th>\n",
" <th>auroc_max</th>\n",
" <th>auroc_mean</th>\n",
" <th>auroc_std</th>\n",
" <th>auroc_fused</th>\n",
" <th>ef1_min</th>\n",
" <th>ef1_max</th>\n",
" <th>ef1_mean</th>\n",
" <th>ef1_std</th>\n",
" <th>ef1_fused</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ADRB2</td>\n",
" <td>0.467131</td>\n",
" <td>0.580804</td>\n",
" <td>0.494027</td>\n",
" <td>0.034033</td>\n",
" <td>0.497650</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ALDH1</td>\n",
" <td>0.525737</td>\n",
" <td>0.525737</td>\n",
" <td>0.525737</td>\n",
" <td>0.000000</td>\n",
" <td>0.525737</td>\n",
" <td>1.632254</td>\n",
" <td>1.632254</td>\n",
" <td>1.632254</td>\n",
" <td>0.000000e+00</td>\n",
" <td>1.632254</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ESR1_ago</td>\n",
" <td>0.431165</td>\n",
" <td>0.566769</td>\n",
" <td>0.455523</td>\n",
" <td>0.031804</td>\n",
" <td>0.560775</td>\n",
" <td>0.000000</td>\n",
" <td>15.384615</td>\n",
" <td>1.025641</td>\n",
" <td>3.837597e+00</td>\n",
" <td>7.692308</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ESR1_ant</td>\n",
" <td>0.463507</td>\n",
" <td>0.487013</td>\n",
" <td>0.473227</td>\n",
" <td>0.007324</td>\n",
" <td>0.471104</td>\n",
" <td>0.980392</td>\n",
" <td>4.901961</td>\n",
" <td>4.248366</td>\n",
" <td>9.912255e-01</td>\n",
" <td>3.921569</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>FEN1</td>\n",
" <td>0.578497</td>\n",
" <td>0.578497</td>\n",
" <td>0.578497</td>\n",
" <td>0.000000</td>\n",
" <td>0.578497</td>\n",
" <td>2.439024</td>\n",
" <td>2.439024</td>\n",
" <td>2.439024</td>\n",
" <td>0.000000e+00</td>\n",
" <td>2.439024</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>GBA</td>\n",
" <td>0.484791</td>\n",
" <td>0.493707</td>\n",
" <td>0.487763</td>\n",
" <td>0.004203</td>\n",
" <td>0.488888</td>\n",
" <td>0.602410</td>\n",
" <td>1.204819</td>\n",
" <td>1.004016</td>\n",
" <td>2.839786e-01</td>\n",
" <td>0.602410</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>IDH1</td>\n",
" <td>0.529484</td>\n",
" <td>0.553243</td>\n",
" <td>0.545412</td>\n",
" <td>0.008663</td>\n",
" <td>0.545899</td>\n",
" <td>5.128205</td>\n",
" <td>5.128205</td>\n",
" <td>5.128205</td>\n",
" <td>1.776357e-15</td>\n",
" <td>5.128205</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>KAT2A</td>\n",
" <td>0.538885</td>\n",
" <td>0.543235</td>\n",
" <td>0.540335</td>\n",
" <td>0.002050</td>\n",
" <td>0.541592</td>\n",
" <td>0.515464</td>\n",
" <td>1.030928</td>\n",
" <td>0.687285</td>\n",
" <td>2.429920e-01</td>\n",
" <td>0.515464</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>MAPK1</td>\n",
" <td>0.482031</td>\n",
" <td>0.549280</td>\n",
" <td>0.518390</td>\n",
" <td>0.018859</td>\n",
" <td>0.548797</td>\n",
" <td>0.324675</td>\n",
" <td>7.142857</td>\n",
" <td>2.337662</td>\n",
" <td>2.104804e+00</td>\n",
" <td>7.142857</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>MTORC1</td>\n",
" <td>0.529677</td>\n",
" <td>0.573706</td>\n",
" <td>0.545634</td>\n",
" <td>0.017579</td>\n",
" <td>0.567938</td>\n",
" <td>1.030928</td>\n",
" <td>3.092784</td>\n",
" <td>1.499531</td>\n",
" <td>6.758297e-01</td>\n",
" <td>2.061856</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>OPRK1</td>\n",
" <td>0.423600</td>\n",
" <td>0.423600</td>\n",
" <td>0.423600</td>\n",
" <td>0.000000</td>\n",
" <td>0.423600</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>PKM2</td>\n",
" <td>0.517113</td>\n",
" <td>0.550826</td>\n",
" <td>0.540974</td>\n",
" <td>0.011080</td>\n",
" <td>0.537580</td>\n",
" <td>0.732601</td>\n",
" <td>0.915751</td>\n",
" <td>0.895401</td>\n",
" <td>5.755855e-02</td>\n",
" <td>0.915751</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>PPARG</td>\n",
" <td>0.330807</td>\n",
" <td>0.381156</td>\n",
" <td>0.353604</td>\n",
" <td>0.017436</td>\n",
" <td>0.362762</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>TP53</td>\n",
" <td>0.429233</td>\n",
" <td>0.429233</td>\n",
" <td>0.429233</td>\n",
" <td>0.000000</td>\n",
" <td>0.429233</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>VDR</td>\n",
" <td>0.551515</td>\n",
" <td>0.551515</td>\n",
" <td>0.551515</td>\n",
" <td>0.000000</td>\n",
" <td>0.551515</td>\n",
" <td>2.036199</td>\n",
" <td>2.036199</td>\n",
" <td>2.036199</td>\n",
" <td>0.000000e+00</td>\n",
" <td>2.036199</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" target auroc_min auroc_max auroc_mean auroc_std auroc_fused \\\n",
"0 ADRB2 0.467131 0.580804 0.494027 0.034033 0.497650 \n",
"1 ALDH1 0.525737 0.525737 0.525737 0.000000 0.525737 \n",
"2 ESR1_ago 0.431165 0.566769 0.455523 0.031804 0.560775 \n",
"3 ESR1_ant 0.463507 0.487013 0.473227 0.007324 0.471104 \n",
"4 FEN1 0.578497 0.578497 0.578497 0.000000 0.578497 \n",
"5 GBA 0.484791 0.493707 0.487763 0.004203 0.488888 \n",
"6 IDH1 0.529484 0.553243 0.545412 0.008663 0.545899 \n",
"7 KAT2A 0.538885 0.543235 0.540335 0.002050 0.541592 \n",
"8 MAPK1 0.482031 0.549280 0.518390 0.018859 0.548797 \n",
"9 MTORC1 0.529677 0.573706 0.545634 0.017579 0.567938 \n",
"10 OPRK1 0.423600 0.423600 0.423600 0.000000 0.423600 \n",
"11 PKM2 0.517113 0.550826 0.540974 0.011080 0.537580 \n",
"12 PPARG 0.330807 0.381156 0.353604 0.017436 0.362762 \n",
"13 TP53 0.429233 0.429233 0.429233 0.000000 0.429233 \n",
"14 VDR 0.551515 0.551515 0.551515 0.000000 0.551515 \n",
"\n",
" ef1_min ef1_max ef1_mean ef1_std ef1_fused \n",
"0 0.000000 0.000000 0.000000 0.000000e+00 0.000000 \n",
"1 1.632254 1.632254 1.632254 0.000000e+00 1.632254 \n",
"2 0.000000 15.384615 1.025641 3.837597e+00 7.692308 \n",
"3 0.980392 4.901961 4.248366 9.912255e-01 3.921569 \n",
"4 2.439024 2.439024 2.439024 0.000000e+00 2.439024 \n",
"5 0.602410 1.204819 1.004016 2.839786e-01 0.602410 \n",
"6 5.128205 5.128205 5.128205 1.776357e-15 5.128205 \n",
"7 0.515464 1.030928 0.687285 2.429920e-01 0.515464 \n",
"8 0.324675 7.142857 2.337662 2.104804e+00 7.142857 \n",
"9 1.030928 3.092784 1.499531 6.758297e-01 2.061856 \n",
"10 0.000000 0.000000 0.000000 0.000000e+00 0.000000 \n",
"11 0.732601 0.915751 0.895401 5.755855e-02 0.915751 \n",
"12 0.000000 0.000000 0.000000 0.000000e+00 0.000000 \n",
"13 0.000000 0.000000 0.000000 0.000000e+00 0.000000 \n",
"14 2.036199 2.036199 2.036199 0.000000e+00 2.036199 "
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"benchmark"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"benchmark.to_csv(BENCH_OUTPUT_PATH, index=False)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.5087711290605641, 2.2725264454871454)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"benchmark['auroc_fused'].mean(), benchmark['ef1_fused'].mean()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Ligand-ligand matching"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# benchmark = run_benchmark(DATA_SOURCE, DATA_DIR, OUTPUT_DIR, \\\n",
"# partial(extract_fingerprints, model=model, save_dir=OUTPUT_DIR),\n",
"# calc_cosine_distance, start_idx=5)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>target</th>\n",
" <th>auroc_min</th>\n",
" <th>auroc_max</th>\n",
" <th>auroc_mean</th>\n",
" <th>auroc_std</th>\n",
" <th>auroc_fused</th>\n",
" <th>ef1_min</th>\n",
" <th>ef1_max</th>\n",
" <th>ef1_mean</th>\n",
" <th>ef1_std</th>\n",
" <th>ef1_fused</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ADRB2</td>\n",
" <td>0.507079</td>\n",
" <td>0.643049</td>\n",
" <td>0.595389</td>\n",
" <td>0.037026</td>\n",
" <td>0.586201</td>\n",
" <td>0.000000</td>\n",
" <td>17.647059</td>\n",
" <td>9.558824</td>\n",
" <td>7.166761</td>\n",
" <td>5.882353</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ALDH1</td>\n",
" <td>0.496826</td>\n",
" <td>0.527855</td>\n",
" <td>0.510553</td>\n",
" <td>0.011711</td>\n",
" <td>0.497737</td>\n",
" <td>0.725446</td>\n",
" <td>1.185826</td>\n",
" <td>1.018415</td>\n",
" <td>0.180015</td>\n",
" <td>0.739397</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ESR1_ago</td>\n",
" <td>0.531634</td>\n",
" <td>0.765724</td>\n",
" <td>0.647954</td>\n",
" <td>0.054484</td>\n",
" <td>0.636906</td>\n",
" <td>0.000000</td>\n",
" <td>23.076923</td>\n",
" <td>12.820513</td>\n",
" <td>8.268982</td>\n",
" <td>15.384615</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ESR1_ant</td>\n",
" <td>0.414653</td>\n",
" <td>0.522583</td>\n",
" <td>0.480341</td>\n",
" <td>0.026942</td>\n",
" <td>0.471374</td>\n",
" <td>0.000000</td>\n",
" <td>2.941176</td>\n",
" <td>1.330532</td>\n",
" <td>0.952484</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>FEN1</td>\n",
" <td>0.415627</td>\n",
" <td>0.415627</td>\n",
" <td>0.415627</td>\n",
" <td>0.000000</td>\n",
" <td>0.415627</td>\n",
" <td>0.813008</td>\n",
" <td>0.813008</td>\n",
" <td>0.813008</td>\n",
" <td>0.000000</td>\n",
" <td>0.813008</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>GBA</td>\n",
" <td>0.407294</td>\n",
" <td>0.496951</td>\n",
" <td>0.444367</td>\n",
" <td>0.035076</td>\n",
" <td>0.451171</td>\n",
" <td>1.204819</td>\n",
" <td>3.012048</td>\n",
" <td>1.957831</td>\n",
" <td>0.656461</td>\n",
" <td>2.409639</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>IDH1</td>\n",
" <td>0.460212</td>\n",
" <td>0.603055</td>\n",
" <td>0.520860</td>\n",
" <td>0.042039</td>\n",
" <td>0.570258</td>\n",
" <td>0.000000</td>\n",
" <td>5.128205</td>\n",
" <td>1.282051</td>\n",
" <td>1.720052</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>KAT2A</td>\n",
" <td>0.516656</td>\n",
" <td>0.532710</td>\n",
" <td>0.524683</td>\n",
" <td>0.008027</td>\n",
" <td>0.538545</td>\n",
" <td>0.515464</td>\n",
" <td>1.030928</td>\n",
" <td>0.773196</td>\n",
" <td>0.257732</td>\n",
" <td>1.030928</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>MAPK1</td>\n",
" <td>0.450370</td>\n",
" <td>0.580198</td>\n",
" <td>0.510623</td>\n",
" <td>0.038181</td>\n",
" <td>0.552040</td>\n",
" <td>0.324675</td>\n",
" <td>8.116883</td>\n",
" <td>2.479339</td>\n",
" <td>2.579902</td>\n",
" <td>2.922078</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>MTORC1</td>\n",
" <td>0.490265</td>\n",
" <td>0.506024</td>\n",
" <td>0.496381</td>\n",
" <td>0.005229</td>\n",
" <td>0.506741</td>\n",
" <td>1.030928</td>\n",
" <td>5.154639</td>\n",
" <td>4.123711</td>\n",
" <td>1.304032</td>\n",
" <td>3.092784</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>OPRK1</td>\n",
" <td>0.624212</td>\n",
" <td>0.624212</td>\n",
" <td>0.624212</td>\n",
" <td>0.000000</td>\n",
" <td>0.624212</td>\n",
" <td>4.166667</td>\n",
" <td>4.166667</td>\n",
" <td>4.166667</td>\n",
" <td>0.000000</td>\n",
" <td>4.166667</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>PKM2</td>\n",
" <td>0.509071</td>\n",
" <td>0.607490</td>\n",
" <td>0.543089</td>\n",
" <td>0.036738</td>\n",
" <td>0.585025</td>\n",
" <td>0.366300</td>\n",
" <td>6.959707</td>\n",
" <td>3.044872</td>\n",
" <td>2.568902</td>\n",
" <td>6.593407</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>PPARG</td>\n",
" <td>0.282177</td>\n",
" <td>0.678952</td>\n",
" <td>0.465009</td>\n",
" <td>0.118151</td>\n",
" <td>0.676571</td>\n",
" <td>0.000000</td>\n",
" <td>14.814815</td>\n",
" <td>2.160494</td>\n",
" <td>4.397471</td>\n",
" <td>11.111111</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>TP53</td>\n",
" <td>0.362562</td>\n",
" <td>0.547525</td>\n",
" <td>0.467482</td>\n",
" <td>0.066332</td>\n",
" <td>0.510018</td>\n",
" <td>0.000000</td>\n",
" <td>2.531646</td>\n",
" <td>1.265823</td>\n",
" <td>1.132186</td>\n",
" <td>1.265823</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>VDR</td>\n",
" <td>0.582428</td>\n",
" <td>0.582428</td>\n",
" <td>0.582428</td>\n",
" <td>0.000000</td>\n",
" <td>0.582428</td>\n",
" <td>4.185520</td>\n",
" <td>4.185520</td>\n",
" <td>4.185520</td>\n",
" <td>0.000000</td>\n",
" <td>4.185520</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" target auroc_min auroc_max auroc_mean auroc_std auroc_fused \\\n",
"0 ADRB2 0.507079 0.643049 0.595389 0.037026 0.586201 \n",
"1 ALDH1 0.496826 0.527855 0.510553 0.011711 0.497737 \n",
"2 ESR1_ago 0.531634 0.765724 0.647954 0.054484 0.636906 \n",
"3 ESR1_ant 0.414653 0.522583 0.480341 0.026942 0.471374 \n",
"4 FEN1 0.415627 0.415627 0.415627 0.000000 0.415627 \n",
"5 GBA 0.407294 0.496951 0.444367 0.035076 0.451171 \n",
"6 IDH1 0.460212 0.603055 0.520860 0.042039 0.570258 \n",
"7 KAT2A 0.516656 0.532710 0.524683 0.008027 0.538545 \n",
"8 MAPK1 0.450370 0.580198 0.510623 0.038181 0.552040 \n",
"9 MTORC1 0.490265 0.506024 0.496381 0.005229 0.506741 \n",
"10 OPRK1 0.624212 0.624212 0.624212 0.000000 0.624212 \n",
"11 PKM2 0.509071 0.607490 0.543089 0.036738 0.585025 \n",
"12 PPARG 0.282177 0.678952 0.465009 0.118151 0.676571 \n",
"13 TP53 0.362562 0.547525 0.467482 0.066332 0.510018 \n",
"14 VDR 0.582428 0.582428 0.582428 0.000000 0.582428 \n",
"\n",
" ef1_min ef1_max ef1_mean ef1_std ef1_fused \n",
"0 0.000000 17.647059 9.558824 7.166761 5.882353 \n",
"1 0.725446 1.185826 1.018415 0.180015 0.739397 \n",
"2 0.000000 23.076923 12.820513 8.268982 15.384615 \n",
"3 0.000000 2.941176 1.330532 0.952484 0.000000 \n",
"4 0.813008 0.813008 0.813008 0.000000 0.813008 \n",
"5 1.204819 3.012048 1.957831 0.656461 2.409639 \n",
"6 0.000000 5.128205 1.282051 1.720052 0.000000 \n",
"7 0.515464 1.030928 0.773196 0.257732 1.030928 \n",
"8 0.324675 8.116883 2.479339 2.579902 2.922078 \n",
"9 1.030928 5.154639 4.123711 1.304032 3.092784 \n",
"10 4.166667 4.166667 4.166667 0.000000 4.166667 \n",
"11 0.366300 6.959707 3.044872 2.568902 6.593407 \n",
"12 0.000000 14.814815 2.160494 4.397471 11.111111 \n",
"13 0.000000 2.531646 1.265823 1.132186 1.265823 \n",
"14 4.185520 4.185520 4.185520 0.000000 4.185520 "
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"benchmark"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"benchmark.to_csv(BENCH_OUTPUT_PATH, index=False)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.5469901404486206, 3.9731552741192093)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"benchmark['auroc_fused'].mean(), benchmark['ef1_fused'].mean()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chem",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
# utils.py: helper module imported by the notebook above.
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
from tqdm.notebook import tqdm
from rdkit.Chem.rdmolfiles import MolFromMol2File, MolToSmiles

from metrics import ef_score
def read_litpcba_smiles(file_path):
    smiles = pd.read_csv(file_path, header=None, names=['smiles', 'code'], sep=' ')
    return smiles['smiles'].tolist()


def read_dude_smiles(file_path):
    smiles = pd.read_csv(file_path, header=None, names=['smiles', 'code', 'chembl_code'], sep=' ')
    return smiles['smiles'].tolist()


def read_actives_decoys(data_source, target_dir):
    if data_source == "litpcba":
        actives_fpath = target_dir / "actives.smi"
        decoys_fpath = target_dir / "inactives.smi"
        read_smiles_fn = read_litpcba_smiles
    elif data_source == "dude":
        actives_fpath = target_dir / "actives_final.ism"
        decoys_fpath = target_dir / "decoys_final.ism"
        read_smiles_fn = read_dude_smiles

    actives_smiles = read_smiles_fn(actives_fpath)
    decoys_smiles = read_smiles_fn(decoys_fpath)
    return actives_smiles, decoys_smiles


def read_ligands(data_source, target_dir):
    if data_source == "litpcba":
        ligands_paths = list(target_dir.glob("*_ligand.mol2"))
    elif data_source == "dude":
        ligands_paths = [target_dir / "crystal_ligand.mol2"]

    # RDKit returns None for mol2 files it cannot parse; drop those.
    ligands = [MolFromMol2File(str(p)) for p in ligands_paths]
    total_ligands = len(ligands)
    ligands = [lig for lig in ligands if lig is not None]
    print(f"Loaded {len(ligands)} out of {total_ligands} ligands")
    return ligands
def calc_euclidean_distance(x, y):
    # Negated squared Euclidean distance, so larger values mean "more similar".
    dist = (x ** 2).sum(axis=1)[:, None] + (y ** 2).sum(axis=1)[None, :] - 2 * x @ y.T
    return -dist


def calc_cosine_distance(x, y):
    # Cosine similarity rescaled from [-1, 1] to [0, 1]; higher means more similar.
    dist = x @ y.T / (np.linalg.norm(x, axis=1)[:, None] * np.linalg.norm(y, axis=1)[None, :] + 1e-6)
    dist = (dist + 1.) / 2.
    return dist
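
# A quick shape check for the two helpers above (illustrative dimensions only,
# not from the original gist): both take row-wise embedding matrices and return
# a query-by-library score matrix where higher means more similar, which is why
# run_benchmark ranks actives by descending score.
#
#     x = np.random.rand(2, 128).astype(np.float32)   # 2 query embeddings
#     y = np.random.rand(3, 128).astype(np.float32)   # 3 library embeddings
#     sim = calc_cosine_distance(x, y)                # shape (2, 3), values in [0, 1]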
def run_benchmark(data_source, data_dir, output_dir, extract_fingerprints_fn, similarity_fn, start_idx=None):
    benchmark = {
        'target': [],
        'auroc_min': [],
        'auroc_max': [],
        'auroc_mean': [],
        'auroc_std': [],
        'auroc_fused': [],
        'ef1_min': [],
        'ef1_max': [],
        'ef1_mean': [],
        'ef1_std': [],
        'ef1_fused': []
    }

    target_dirs = list(sorted(data_dir.iterdir()))
    for t, target_dir in enumerate(target_dirs):
        print(f'{target_dir.stem} ({t+1}/{len(target_dirs)})')
        target_output_dir = output_dir / target_dir.stem

        # Targets before start_idx are assumed to be featurized already; reload from disk.
        if start_idx is not None and t < start_idx:
            if not target_output_dir.exists():
                print(f'[{target_dir.stem}] No output dir found')
                continue

            ligands_fps = np.load(target_output_dir / 'ligands_feats.npy')
            actives_fps = np.load(target_output_dir / 'actives_feats.npy')
            decoys_fps = np.load(target_output_dir / 'decoys_feats.npy')
        else:
            ligands = read_ligands(data_source, target_dir)
            if len(ligands) == 0:
                print(f'[{target_dir.stem}] No ligands found or they may be corrupted')
                continue

            actives_smiles, decoys_smiles = read_actives_decoys(data_source, target_dir)
            actives_fps = extract_fingerprints_fn(actives_smiles)
            decoys_fps = extract_fingerprints_fn(decoys_smiles)
            target_output_dir.mkdir(exist_ok=True, parents=True)
            np.save(target_output_dir / 'actives_feats.npy', actives_fps)
            np.save(target_output_dir / 'decoys_feats.npy', decoys_fps)

            ligands_smiles = [MolToSmiles(ligand) for ligand in ligands]
            ligands_fps = extract_fingerprints_fn(ligands_smiles)
            ligands_fps = np.stack(ligands_fps)
            np.save(target_output_dir / 'ligands_feats.npy', ligands_fps)

        labels = np.concatenate([np.ones(len(actives_fps), dtype=bool), np.zeros(len(decoys_fps), dtype=bool)])
        actives_fps = actives_fps.astype(np.float32)
        decoys_fps = decoys_fps.astype(np.float32)
        ligands_fps = ligands_fps.astype(np.float32)
        lig2act = similarity_fn(ligands_fps, actives_fps)
        lig2dec = similarity_fn(ligands_fps, decoys_fps)

        # Per-query metrics, plus a fused score that takes the max similarity over queries.
        roc_aucs, ef1s = [], []
        for i in tqdm(range(lig2act.shape[0])):
            roc_auc = roc_auc_score(labels, np.concatenate([lig2act[i], lig2dec[i]]))
            roc_aucs.append(roc_auc)
            ef1 = ef_score(labels, np.concatenate([lig2act[i], lig2dec[i]]), 0.01)
            ef1s.append(ef1)
        roc_auc_fused = roc_auc_score(labels, np.concatenate([lig2act.max(axis=0), lig2dec.max(axis=0)]))
        ef1_fused = ef_score(labels, np.concatenate([lig2act.max(axis=0), lig2dec.max(axis=0)]), 0.01)

        benchmark['target'].append(target_dir.stem)
        benchmark['auroc_min'].append(np.min(roc_aucs))
        benchmark['auroc_max'].append(np.max(roc_aucs))
        benchmark['auroc_mean'].append(np.mean(roc_aucs))
        benchmark['auroc_std'].append(np.std(roc_aucs))
        benchmark['auroc_fused'].append(roc_auc_fused)
        benchmark['ef1_min'].append(np.min(ef1s))
        benchmark['ef1_max'].append(np.max(ef1s))
        benchmark['ef1_mean'].append(np.mean(ef1s))
        benchmark['ef1_std'].append(np.std(ef1s))
        benchmark['ef1_fused'].append(ef1_fused)

        print(f'[{target_dir.stem}] ROC AUC min: {np.min(roc_aucs):.3f}, max: {np.max(roc_aucs):.3f}, '
              f'mean: {np.mean(roc_aucs):.3f}, std: {np.std(roc_aucs):.3f}, fused: {roc_auc_fused:.3f}')
        print(f'[{target_dir.stem}] EF1 min: {np.min(ef1s):.3f}, max: {np.max(ef1s):.3f}, '
              f'mean: {np.mean(ef1s):.3f}, std: {np.std(ef1s):.3f}, fused: {ef1_fused:.3f}')

    benchmark = pd.DataFrame(benchmark)
    return benchmark
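
The `ef_score` imported at the top of utils.py comes from a `metrics` module that is not included in this gist. A plausible sketch of an enrichment-factor implementation consistent with how it is called above (EF1 corresponds to `alpha=0.01`); the exact edge-case and tie handling is an assumption:

import numpy as np

def ef_score(labels, scores, alpha):
    # Hypothetical sketch of metrics.ef_score: enrichment factor at the top
    # `alpha` fraction of the ranking. EF = (hit rate among the top slice) /
    # (hit rate overall); ~1.0 is random, higher is better, 0.0 means no
    # actives were retrieved in the top slice.
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores)
    n_top = max(1, int(round(alpha * len(scores))))
    top = np.argsort(-scores)[:n_top]  # indices of the highest-scoring entries
    return float(labels[top].mean() / labels.mean())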