Forked from iwatobipen/AttentiveFP with color bar.ipynb
Created
February 17, 2022 13:51
-
-
Save leelasd/7c4a6a3abcbfeb91440d78473cb3aaff to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using backend: pytorch\n" | |
] | |
} | |
], | |
"source": [ | |
"%matplotlib inline \n", | |
"import matplotlib.pyplot as plt\n", | |
"import os\n", | |
"from rdkit import Chem\n", | |
"from rdkit.Chem import rdmolops, rdmolfiles\n", | |
"from rdkit import RDPaths\n", | |
" \n", | |
"import dgl\n", | |
"import numpy as np\n", | |
"import random\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"from torch.utils.data import DataLoader\n", | |
"from torch.utils.data import Dataset\n", | |
"from dgl import model_zoo\n", | |
"from dgllife.model import AttentiveFPPredictor\n", | |
"from dgllife.utils import mol_to_complete_graph, mol_to_bigraph\n", | |
"from dgllife.utils import atom_type_one_hot\n", | |
"from dgllife.utils import atom_degree_one_hot\n", | |
"from dgllife.utils import atom_formal_charge\n", | |
"from dgllife.utils import atom_num_radical_electrons\n", | |
"from dgllife.utils import atom_hybridization_one_hot\n", | |
"from dgllife.utils import atom_total_num_H_one_hot\n", | |
"from dgllife.utils import one_hot_encoding\n", | |
"from dgllife.utils import CanonicalAtomFeaturizer\n", | |
"from dgllife.utils import CanonicalBondFeaturizer\n", | |
"from dgllife.utils import ConcatFeaturizer\n", | |
"from dgllife.utils import BaseAtomFeaturizer\n", | |
"from dgllife.utils import BaseBondFeaturizer\n", | |
"from dgllife.utils import one_hot_encoding \n", | |
"from dgl.data.utils import split_dataset\n", | |
" \n", | |
"from functools import partial\n", | |
"from sklearn.metrics import roc_auc_score\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Build an AttentiveFP regressor.
# NOTE(review): this instance is shadowed — the training cell below
# re-creates `model` (and moves it to GPU), so the object built here
# is never used and this cell could be removed.
model = AttentiveFPPredictor(node_feat_size=39,   # presumably the concatenated atom-feature length — TODO confirm
                             edge_feat_size=10,   # matches the 10-dim constant bond featurizer below
                             num_layers=2,
                             num_timesteps=2,
                             graph_feat_size=200,
                             n_tasks=1,           # single regression target (solubility)
                             dropout=0.2)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch 1/100, training 7.6881\n", | |
"epoch 2/100, training 4.9003\n", | |
"epoch 3/100, training 2.9593\n", | |
"epoch 4/100, training 3.5842\n", | |
"epoch 5/100, training 3.0182\n", | |
"epoch 6/100, training 2.7759\n", | |
"epoch 7/100, training 2.3658\n", | |
"epoch 8/100, training 1.9188\n" | |
] | |
} | |
], | |
"source": [ | |
def chirality(atom):
    """Featurize the chirality of an RDKit atom.

    Returns a 3-element list: a one-hot encoding of the atom's CIP code
    over ['R', 'S'] (both False when no CIP code is assigned), followed
    by a flag for whether chirality is possible for this atom.
    """
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except KeyError:
        # GetProp raises KeyError when '_CIPCode' is absent; catching
        # only that (instead of the original bare `except:`) lets real
        # errors surface instead of being silently swallowed.
        return [False, False] + [atom.HasProp('_ChiralityPossible')]
" \n", | |
def collate_molgraphs(data):
    """Batch a list of datapoints for a DataLoader.

    Parameters
    ----------
    data : list of 3-tuples or 4-tuples
        Each tuple holds a SMILES string, a DGLGraph, the all-task
        labels, and optionally a binary mask marking which labels exist.

    Returns
    -------
    smiles : list
        SMILES strings, one per datapoint.
    bg : batched DGLGraph
        All graphs merged into one batched graph.
    labels : Tensor of dtype float32 and shape (B, T)
        Stacked labels; B = len(data), T = number of tasks.
    masks : Tensor of dtype float32 and shape (B, T)
        Stacked label-existence masks, or all ones when none were given.
    """
    tuple_len = len(data[0])
    assert tuple_len in [3, 4], \
        'Expect the tuple to be of length 3 or 4, got {:d}'.format(tuple_len)

    fields = [list(column) for column in zip(*data)]
    masks = None
    if tuple_len == 3:
        smiles, graphs, labels = fields
    else:
        smiles, graphs, labels, masks = fields

    # Merge the individual graphs and make sure missing node/edge
    # features default to zeros inside the batch.
    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)

    labels = torch.stack(labels, dim=0)
    masks = torch.ones(labels.shape) if masks is None else torch.stack(masks, dim=0)
    return smiles, bg, labels, masks
" \n", | |
# Atom featurizer: builds one 'hv' vector per atom by concatenating
# several descriptors (atom type, degree, formal charge, radical
# electrons, hybridization, an aromaticity placeholder, H count,
# chirality). Presumably the concatenated length is 39, matching the
# model's node_feat_size — TODO confirm.
atom_featurizer = BaseAtomFeaturizer(
    {'hv': ConcatFeaturizer([
        partial(atom_type_one_hot, allowable_set=[
            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
            encode_unknown=True),
        partial(atom_degree_one_hot, allowable_set=list(range(6))),
        atom_formal_charge, atom_num_radical_electrons,
        partial(atom_hybridization_one_hot, encode_unknown=True),
        lambda atom: [0],  # A placeholder for aromatic information,
        atom_total_num_H_one_hot, chirality
    ],
    )})
# Bond featurizer: every bond gets a constant 10-element zero vector
# 'he' (edge features carry no real information here; 10 matches the
# model's edge_feat_size).
bond_featurizer = BaseBondFeaturizer({
    'he': lambda bond: [0 for _ in range(10)]
})

# Solubility SDF files shipped with the RDKit documentation.
train=os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.train.sdf')
test=os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.test.sdf')

# NOTE(review): each supplier is iterated several times below —
# presumably SDMolSupplier supports re-iteration; also assumes every
# record parses (no None mols) and carries a 'SOL' property — verify.
train_mols = Chem.SDMolSupplier(train)
train_smi =[Chem.MolToSmiles(m) for m in train_mols]
train_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in train_mols]).reshape(-1,1)

test_mols = Chem.SDMolSupplier(test)
test_smi = [Chem.MolToSmiles(m) for m in test_mols]
test_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in test_mols]).reshape(-1,1)

# Convert each molecule to a bidirected DGL graph using the
# featurizers defined above.
train_graph =[mol_to_bigraph(mol,
                             node_featurizer=atom_featurizer,
                             edge_featurizer=bond_featurizer) for mol in train_mols]

test_graph =[mol_to_bigraph(mol,
                            node_featurizer=atom_featurizer,
                            edge_featurizer=bond_featurizer) for mol in test_mols]
" \n", | |
def run_a_train_epoch(n_epochs, epoch, model, data_loader, loss_criterion, optimizer):
    """Run one optimization pass over `data_loader` and report the mean loss.

    Parameters
    ----------
    n_epochs : int
        Total number of epochs (used only in the progress message).
    epoch : int
        Zero-based index of the current epoch.
    model : nn.Module
        Predictor called as model(bg, node_feats, edge_feats).
    data_loader : DataLoader
        Yields (smiles, batched_graph, labels, masks), as produced by
        collate_molgraphs.
    loss_criterion : callable
        Element-wise loss, e.g. nn.MSELoss(reduction='none').
    optimizer : torch.optim.Optimizer

    Returns
    -------
    float
        Mean of the per-batch masked losses for this epoch.
    """
    model.train()
    losses = []

    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        if torch.cuda.is_available():
            # DGLGraph.to() returns the moved graph (it is not in-place
            # in recent DGL releases), so the result must be reassigned —
            # the original discarded it.
            bg = bg.to(torch.device('cuda:0'))
            labels = labels.to('cuda:0')
            masks = masks.to('cuda:0')

        prediction = model(bg, bg.ndata['hv'], bg.edata['he'])
        # Zero out loss terms whose label is missing (mask == 0), then
        # average over all elements (masked entries still count in the
        # denominator, matching the original behaviour).
        loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.data.item())

    total_score = np.mean(losses)
    print('epoch {:d}/{:d}, training {:.4f}'.format(epoch + 1, n_epochs, total_score))
    return total_score
" \n", | |
# Fresh AttentiveFP model for training (replaces the one built earlier).
model = AttentiveFPPredictor(node_feat_size=39,
                             edge_feat_size=10,
                             num_layers=2,
                             num_timesteps=2,
                             graph_feat_size=200,
                             n_tasks=1,
                             dropout=0.2)
# Only move to GPU when one is present: the original unconditional
# model.to('cuda:0') crashes on CPU-only machines, even though the
# train loop already guards its transfers with torch.cuda.is_available().
if torch.cuda.is_available():
    model = model.to('cuda:0')

train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)),
                          batch_size=128, collate_fn=collate_molgraphs)
test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)),
                         batch_size=128, collate_fn=collate_molgraphs)

# Element-wise MSE so the per-label mask can be applied inside
# run_a_train_epoch before reduction.
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0))
n_epochs = 100
epochs = []
scores = []
for e in range(n_epochs):
    score = run_a_train_epoch(n_epochs, e, model, train_loader, loss_fn, optimizer)
    epochs.append(e)
    scores.append(score)
model.eval()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Training-loss curve: per-epoch mean losses collected in the loop above.
plt.plot(epochs, scores)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import copy\n", | |
"from rdkit.Chem import rdDepictor\n", | |
"from rdkit.Chem.Draw import rdMolDraw2D\n", | |
"from IPython.display import SVG\n", | |
"from IPython.display import display\n", | |
"import matplotlib\n", | |
"import matplotlib.cm as cm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def drawmol(idx, dataset, timestep):
    """Draw molecule `idx` from `dataset` with atoms colored by the
    AttentiveFP readout attention weights at round `timestep`, and show
    a standalone horizontal color bar as the legend.

    Parameters
    ----------
    idx : int
        Index into `dataset`.
    dataset : sequence of (smiles, DGLGraph, label) tuples
    timestep : int
        Which readout round's node weights to visualize; must be
        smaller than the number of rounds returned by the model.

    Returns
    -------
    (mol, weights, svg) : (rdkit Mol, np.ndarray, str)
        The molecule, its min-max normalized atom weights, and the
        rendered SVG text.
    """
    smiles, graph, _ = dataset[idx]
    print(smiles)
    bg = dgl.batch([graph])
    atom_feats, bond_feats = bg.ndata['hv'], bg.edata['he']
    if torch.cuda.is_available():
        print('use cuda')
        # .to() returns the moved graph; reassign it (not in-place in
        # recent DGL) — the original discarded the result.
        bg = bg.to(torch.device('cuda:0'))
        atom_feats = atom_feats.to('cuda:0')
        bond_feats = bond_feats.to('cuda:0')

    # Uses the notebook-global `model`; get_node_weight=True makes it
    # also return the per-round node attention weights.
    _, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)
    assert timestep < len(atom_weights), 'Unexpected id for the readout round'
    atom_weights = atom_weights[timestep]
    # Min-max normalize the selected round's weights to [0, 1].
    min_value = torch.min(atom_weights)
    max_value = torch.max(atom_weights)
    spread = max_value - min_value
    if spread > 0:
        atom_weights = (atom_weights - min_value) / spread
    else:
        # All weights identical: the original divided by zero here; map
        # everything to 0 instead.
        atom_weights = torch.zeros_like(atom_weights)

    # vmax=1.28 (> 1) keeps the largest normalized weight short of the
    # colormap's fully saturated red.
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1.28)
    cmap = cm.get_cmap('bwr')
    plt_colors = cm.ScalarMappable(norm=norm, cmap=cmap)
    atom_colors = {i: plt_colors.to_rgba(atom_weights[i].data.item())
                   for i in range(bg.number_of_nodes())}

    mol = Chem.MolFromSmiles(smiles)
    rdDepictor.Compute2DCoords(mol)
    drawer = rdMolDraw2D.MolDraw2DSVG(280, 280)
    drawer.SetFontSize(1)
    op = drawer.drawOptions()

    mol = rdMolDraw2D.PrepareMolForDrawing(mol)
    drawer.DrawMolecule(mol, highlightAtoms=range(bg.number_of_nodes()),
                        highlightBonds=[],
                        highlightAtomColors=atom_colors)
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText()
    svg = svg.replace('svg:', '')
    if torch.cuda.is_available():
        atom_weights = atom_weights.to('cpu')

    # Standalone horizontal color bar acting as the legend for the
    # atom highlight colors.
    a = np.array([[0, 1]])
    plt.figure(figsize=(9, 1.5))
    img = plt.imshow(a, cmap="bwr")
    plt.gca().set_visible(False)
    cax = plt.axes([0.1, 0.2, 0.8, 0.2])
    plt.colorbar(orientation='horizontal', cax=cax)
    plt.show()
    return (Chem.MolFromSmiles(smiles), atom_weights.data.numpy(), svg)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Visualize attention on the held-out test set (the raw dataset behind
# the test DataLoader, so items can be indexed individually).
target = test_loader.dataset
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Render the first few test molecules with their round-0 attention
# weights and the color-bar legend.
# (The original sliced a range object — `range(len(target))[:5]` — which
# works but obscures intent; iterate the first min(5, len) indices directly.)
for i in range(min(5, len(target))):
    mol, aw, svg = drawmol(i, target, 0)
    print(aw.min(), aw.max())  # extremes of the normalized atom weights
    display(SVG(svg))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment