Skip to content

Instantly share code, notes, and snippets.

@iwatobipen
Created April 26, 2020 14:00
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save iwatobipen/017fae146c1d67a2cfee455c0a427330 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using backend: pytorch\n"
]
}
],
"source": [
"%matplotlib inline \n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"from rdkit import Chem\n",
"from rdkit.Chem import rdmolops, rdmolfiles\n",
"from rdkit import RDPaths\n",
" \n",
"import dgl\n",
"import numpy as np\n",
"import random\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data import Dataset\n",
"from dgl import model_zoo\n",
"from dgllife.model import AttentiveFPPredictor\n",
"from dgllife.utils import mol_to_complete_graph, mol_to_bigraph\n",
"from dgllife.utils import atom_type_one_hot\n",
"from dgllife.utils import atom_degree_one_hot\n",
"from dgllife.utils import atom_formal_charge\n",
"from dgllife.utils import atom_num_radical_electrons\n",
"from dgllife.utils import atom_hybridization_one_hot\n",
"from dgllife.utils import atom_total_num_H_one_hot\n",
"from dgllife.utils import one_hot_encoding\n",
"from dgllife.utils import CanonicalAtomFeaturizer\n",
"from dgllife.utils import CanonicalBondFeaturizer\n",
"from dgllife.utils import ConcatFeaturizer\n",
"from dgllife.utils import BaseAtomFeaturizer\n",
"from dgllife.utils import BaseBondFeaturizer\n",
"from dgllife.utils import one_hot_encoding \n",
"from dgl.data.utils import split_dataset\n",
" \n",
"from functools import partial\n",
"from sklearn.metrics import roc_auc_score\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# AttentiveFP graph network: 39-dim atom features, 10-dim bond features,\n",
"# single regression task (solubility).\n",
"model = AttentiveFPPredictor(\n",
"    node_feat_size=39, edge_feat_size=10,\n",
"    num_layers=2, num_timesteps=2,\n",
"    graph_feat_size=200, n_tasks=1, dropout=0.2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1/100, training 7.6881\n",
"epoch 2/100, training 4.9003\n",
"epoch 3/100, training 2.9593\n",
"epoch 4/100, training 3.5842\n",
"epoch 5/100, training 3.0182\n",
"epoch 6/100, training 2.7759\n",
"epoch 7/100, training 2.3658\n",
"epoch 8/100, training 1.9188\n"
]
}
],
"source": [
"def chirality(atom):\n",
"    # One-hot CIP code (R/S) plus a flag for potentially chiral atoms.\n",
"    try:\n",
"        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \\\n",
"               [atom.HasProp('_ChiralityPossible')]\n",
"    except KeyError:  # GetProp raises KeyError when _CIPCode is unset\n",
"        return [False, False] + [atom.HasProp('_ChiralityPossible')]\n",
"\n",
"def collate_molgraphs(data):\n",
"    \"\"\"Batch a list of datapoints for a DataLoader.\n",
"    Parameters\n",
"    ----------\n",
"    data : list of 3-tuples or 4-tuples.\n",
"        Each tuple is for a single datapoint, consisting of\n",
"        a SMILES, a DGLGraph, all-task labels and optionally\n",
"        a binary mask indicating the existence of labels.\n",
"    Returns\n",
"    -------\n",
"    smiles : list\n",
"        List of smiles\n",
"    bg : BatchedDGLGraph\n",
"        Batched DGLGraphs\n",
"    labels : Tensor of dtype float32 and shape (B, T)\n",
"        Batched datapoint labels. B is len(data) and\n",
"        T is the number of total tasks.\n",
"    masks : Tensor of dtype float32 and shape (B, T)\n",
"        Batched datapoint binary mask, indicating the\n",
"        existence of labels. If binary masks are not\n",
"        provided, return a tensor with ones.\n",
"    \"\"\"\n",
"    assert len(data[0]) in [3, 4], \\\n",
"        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))\n",
"    if len(data[0]) == 3:\n",
"        smiles, graphs, labels = map(list, zip(*data))\n",
"        masks = None\n",
"    else:\n",
"        smiles, graphs, labels, masks = map(list, zip(*data))\n",
"\n",
"    bg = dgl.batch(graphs)\n",
"    bg.set_n_initializer(dgl.init.zero_initializer)\n",
"    bg.set_e_initializer(dgl.init.zero_initializer)\n",
"    labels = torch.stack(labels, dim=0)\n",
"\n",
"    if masks is None:\n",
"        masks = torch.ones(labels.shape)\n",
"    else:\n",
"        masks = torch.stack(masks, dim=0)\n",
"    return smiles, bg, labels, masks\n",
"\n",
"atom_featurizer = BaseAtomFeaturizer(\n",
"    {'hv': ConcatFeaturizer([\n",
"        partial(atom_type_one_hot, allowable_set=[\n",
"            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],\n",
"            encode_unknown=True),\n",
"        partial(atom_degree_one_hot, allowable_set=list(range(6))),\n",
"        atom_formal_charge, atom_num_radical_electrons,\n",
"        partial(atom_hybridization_one_hot, encode_unknown=True),\n",
"        lambda atom: [0],  # placeholder for aromatic information\n",
"        atom_total_num_H_one_hot, chirality\n",
"    ],\n",
"    )})\n",
"bond_featurizer = BaseBondFeaturizer({\n",
"    'he': lambda bond: [0 for _ in range(10)]  # 10 dummy bond features\n",
"})\n",
"\n",
"train = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.train.sdf')\n",
"test = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.test.sdf')\n",
"\n",
"# Materialize the suppliers as lists and drop records RDKit failed to parse\n",
"# (SDMolSupplier yields None for those) so smiles, labels and graphs stay\n",
"# index-aligned.\n",
"train_mols = [m for m in Chem.SDMolSupplier(train) if m is not None]\n",
"train_smi = [Chem.MolToSmiles(m) for m in train_mols]\n",
"train_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in train_mols]).reshape(-1, 1)\n",
"\n",
"test_mols = [m for m in Chem.SDMolSupplier(test) if m is not None]\n",
"test_smi = [Chem.MolToSmiles(m) for m in test_mols]\n",
"test_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in test_mols]).reshape(-1, 1)\n",
"\n",
"train_graph = [mol_to_bigraph(mol,\n",
"                              node_featurizer=atom_featurizer,\n",
"                              edge_featurizer=bond_featurizer) for mol in train_mols]\n",
"\n",
"test_graph = [mol_to_bigraph(mol,\n",
"                             node_featurizer=atom_featurizer,\n",
"                             edge_featurizer=bond_featurizer) for mol in test_mols]\n",
"\n",
"def run_a_train_epoch(n_epochs, epoch, model, data_loader, loss_criterion, optimizer):\n",
"    \"\"\"Run one training epoch and return the mean masked batch loss.\"\"\"\n",
"    model.train()\n",
"    losses = []\n",
"\n",
"    for batch_id, batch_data in enumerate(data_loader):\n",
"        smiles, bg, labels, masks = batch_data\n",
"        if torch.cuda.is_available():\n",
"            # NOTE(review): DGL >= 0.5 returns a new graph from .to(); assign\n",
"            # the result if this in-place call stops working after an upgrade.\n",
"            bg.to(torch.device('cuda:0'))\n",
"            labels = labels.to('cuda:0')\n",
"            masks = masks.to('cuda:0')\n",
"\n",
"        prediction = model(bg, bg.ndata['hv'], bg.edata['he'])\n",
"        # Zero out the loss of missing labels via the mask before averaging.\n",
"        loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        losses.append(loss.data.item())\n",
"\n",
"    total_score = np.mean(losses)\n",
"    print('epoch {:d}/{:d}, training {:.4f}'.format(epoch + 1, n_epochs, total_score))\n",
"    return total_score\n",
"\n",
"model = AttentiveFPPredictor(node_feat_size=39,\n",
"                             edge_feat_size=10,\n",
"                             num_layers=2,\n",
"                             num_timesteps=2,\n",
"                             graph_feat_size=200,\n",
"                             n_tasks=1,\n",
"                             dropout=0.2)\n",
"# Only move to GPU when one is present; unconditional .to('cuda:0')\n",
"# crashes on CPU-only hosts (the training loop already guards on this).\n",
"if torch.cuda.is_available():\n",
"    model = model.to('cuda:0')\n",
"\n",
"train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)), batch_size=128, collate_fn=collate_molgraphs)\n",
"test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)), batch_size=128, collate_fn=collate_molgraphs)\n",
"\n",
"loss_fn = nn.MSELoss(reduction='none')\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0))\n",
"n_epochs = 100\n",
"epochs = []\n",
"scores = []\n",
"for e in range(n_epochs):\n",
"    score = run_a_train_epoch(n_epochs, e, model, train_loader, loss_fn, optimizer)\n",
"    epochs.append(e)\n",
"    scores.append(score)\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Learning curve: mean training loss per epoch (unlabeled axes were hard to read).\n",
"plt.plot(epochs, scores)\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('mean training loss')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"from rdkit.Chem import rdDepictor\n",
"from rdkit.Chem.Draw import rdMolDraw2D\n",
"from IPython.display import SVG\n",
"from IPython.display import display\n",
"import matplotlib\n",
"import matplotlib.cm as cm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def drawmol(idx, dataset, timestep):\n",
"    \"\"\"Draw molecule *idx* with atoms colored by AttentiveFP node weights.\n",
"\n",
"    Shows a standalone colorbar and returns (mol, atom-weight ndarray, svg).\n",
"    \"\"\"\n",
"    smiles, graph, _ = dataset[idx]\n",
"    print(smiles)\n",
"    bg = dgl.batch([graph])\n",
"    atom_feats, bond_feats = bg.ndata['hv'], bg.edata['he']\n",
"    if torch.cuda.is_available():\n",
"        print('use cuda')\n",
"        # NOTE(review): DGL >= 0.5 returns a new graph from .to(); assign\n",
"        # the result if this in-place call stops working after an upgrade.\n",
"        bg.to(torch.device('cuda:0'))\n",
"        atom_feats = atom_feats.to('cuda:0')\n",
"        bond_feats = bond_feats.to('cuda:0')\n",
"\n",
"    _, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)\n",
"    assert timestep < len(atom_weights), 'Unexpected id for the readout round'\n",
"    # Min-max normalize the selected readout round's weights to [0, 1].\n",
"    atom_weights = atom_weights[timestep]\n",
"    min_value = torch.min(atom_weights)\n",
"    max_value = torch.max(atom_weights)\n",
"    atom_weights = (atom_weights - min_value) / (max_value - min_value)\n",
"\n",
"    norm = matplotlib.colors.Normalize(vmin=0, vmax=1.28)\n",
"    cmap = cm.get_cmap('bwr')\n",
"    plt_colors = cm.ScalarMappable(norm=norm, cmap=cmap)\n",
"    atom_colors = {i: plt_colors.to_rgba(atom_weights[i].data.item()) for i in range(bg.number_of_nodes())}\n",
"\n",
"    mol = Chem.MolFromSmiles(smiles)\n",
"    rdDepictor.Compute2DCoords(mol)\n",
"    drawer = rdMolDraw2D.MolDraw2DSVG(280, 280)\n",
"    drawer.SetFontSize(1)\n",
"\n",
"    mol = rdMolDraw2D.PrepareMolForDrawing(mol)\n",
"    drawer.DrawMolecule(mol, highlightAtoms=range(bg.number_of_nodes()),\n",
"                        highlightBonds=[],\n",
"                        highlightAtomColors=atom_colors)\n",
"    drawer.FinishDrawing()\n",
"    svg = drawer.GetDrawingText()\n",
"    svg = svg.replace('svg:', '')\n",
"    if torch.cuda.is_available():\n",
"        atom_weights = atom_weights.to('cpu')\n",
"\n",
"    # Standalone horizontal colorbar for the 'bwr' colormap.\n",
"    a = np.array([[0, 1]])\n",
"    plt.figure(figsize=(9, 1.5))\n",
"    plt.imshow(a, cmap=\"bwr\")\n",
"    plt.gca().set_visible(False)\n",
"    cax = plt.axes([0.1, 0.2, 0.8, 0.2])\n",
"    plt.colorbar(orientation='horizontal', cax=cax)\n",
"    plt.show()\n",
"    return (Chem.MolFromSmiles(smiles), atom_weights.data.numpy(), svg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"target = test_loader.dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize the first five test molecules (readout round 0).\n",
"for i in range(min(5, len(target))):\n",
"    mol, aw, svg = drawmol(i, target, 0)\n",
"    print(aw.min(), aw.max())\n",
"    display(SVG(svg))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment