Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save maxentile/fa8d173b0217370c3866bf150e85056d to your computer and use it in GitHub Desktop.
Save maxentile/fa8d173b0217370c3866bf150e85056d to your computer and use it in GitHub Desktop.
Try to assign a unique label to each class of symmetry-equivalent atoms in a partially symmetric molecule
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# atoms: 25, # symmetry classes: 15\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"RDKit WARNING: [23:09:46] Enabling RDKit 2019.09.3 jupyter extensions\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<rdkit.Chem.rdchem.Mol at 0x12edc7430>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import dgl\n",
"from dgl.nn.pytorch import GraphConv, GINConv, GATConv, TAGConv\n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
"\n",
"import numpy as np\n",
"from openeye import oechem\n",
"from openforcefield.topology import Molecule\n",
"from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"\n",
"class TAG(nn.Module):\n",
"    \"\"\"Three stacked TAGConv layers with ReLU in between, producing per-node class scores.\"\"\"\n",
"\n",
"    def __init__(self, in_feats, h_feats, num_classes, k=2):\n",
"        super(TAG, self).__init__()\n",
"        # `k` is forwarded unchanged to every TAGConv layer.\n",
"        self.layer1 = TAGConv(in_feats, h_feats, k)\n",
"        self.layer2 = TAGConv(h_feats, h_feats, k)\n",
"        self.layer3 = TAGConv(h_feats, num_classes, k)\n",
"\n",
"    def forward(self, graph, inputs):\n",
"        \"\"\"Return raw (un-normalized) scores with `num_classes` columns, one row per node.\"\"\"\n",
"        h = self.layer1(graph, inputs)\n",
"        h = F.relu(h)\n",
"        h = self.layer2(graph, h)\n",
"        h = F.relu(h)\n",
"        h = self.layer3(graph, h)\n",
"        return h\n",
"\n",
"\n",
"def perceive_symmetry_classes(mol):\n",
"    \"\"\"Return a numpy array holding the OpenEye symmetry class of each atom in `mol`.\n",
"\n",
"    Atoms are listed in `mol.to_openeye()` order; symmetry-equivalent atoms share a class.\n",
"    \"\"\"\n",
"    oemol = mol.to_openeye()\n",
"    oechem.OEPerceiveSymmetry(oemol)\n",
"    return np.array([atom.GetSymmetryClass() for atom in oemol.GetAtoms()])\n",
"\n",
"\n",
"def construct_node_classification_targets(mol):\n",
"    \"\"\"One-hot encode each atom's symmetry class.\n",
"\n",
"    Returns a float array of shape (n_atoms, n_classes) whose columns are ordered by\n",
"    ascending symmetry-class id, i.e. the same layout `OneHotEncoder` produced.\n",
"    \"\"\"\n",
"    # `OneHotEncoder(sparse=False)` raises TypeError on scikit-learn >= 1.4 (the keyword\n",
"    # became `sparse_output`); numpy gives the identical dense encoding on any version.\n",
"    classes, class_index = np.unique(perceive_symmetry_classes(mol), return_inverse=True)\n",
"    return np.eye(len(classes))[class_index]\n",
"\n",
"\n",
"smiles = 'NC(N1)=NC2(CCCCC2)C1=O'\n",
"mol = Molecule.from_smiles(smiles, hydrogens_are_explicit=False, allow_undefined_stereo=True)\n",
"\n",
"symmetry_classes = perceive_symmetry_classes(mol)\n",
"print('# atoms: {}, # symmetry classes: {}'.format(len(symmetry_classes), len(set(symmetry_classes))))\n",
"\n",
"radius = 5\n",
"fpSize = 512\n",
"morgan_generator = GetMorganGenerator(radius=radius, fpSize=fpSize)\n",
"\n",
"\n",
"def compute_atom_centered_morgan_fingerprints(rdmol, morgan_generator, fpSize):\n",
"    \"\"\"Return an (n_atoms, fpSize) int array of per-atom Morgan count fingerprints.\n",
"\n",
"    Row i holds the nonzero counts of the fingerprint restricted to atom i\n",
"    (`fromAtoms=[i]`). `fpSize` should equal the generator's fingerprint size:\n",
"    a smaller value can put bit indices outside the array (IndexError).\n",
"    \"\"\"\n",
"    n_atoms = rdmol.GetNumAtoms()\n",
"    fingerprints = np.zeros((n_atoms, fpSize), dtype=int)\n",
"    for i in range(n_atoms):\n",
"        fingerprint = morgan_generator.GetCountFingerprint(rdmol, fromAtoms=[i])\n",
"        for key, val in fingerprint.GetNonzeroElements().items():\n",
"            fingerprints[i, key] = val\n",
"    return fingerprints\n",
"\n",
"\n",
"rdmol = mol.to_rdkit()\n",
"fingerprints = compute_atom_centered_morgan_fingerprints(rdmol, morgan_generator, fpSize)\n",
"y = construct_node_classification_targets(mol)\n",
"\n",
"input_dim = fingerprints.shape[1]\n",
"hidden_dim = 100\n",
"output_dim = y.shape[1]\n",
"net = TAG(input_dim, hidden_dim, output_dim)\n",
"mol.to_rdkit()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DGLGraph(num_nodes=25, num_edges=52,\n",
" ndata_schemes={}\n",
" edata_schemes={})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# `DGLGraph.from_networkx` was removed in DGL 0.5 in favour of `dgl.from_networkx`;\n",
"# use the module-level function when available and fall back for older DGL.\n",
"nx_graph = mol.to_networkx()\n",
"if hasattr(dgl, 'from_networkx'):\n",
"    graph = dgl.from_networkx(nx_graph)\n",
"else:\n",
"    graph = dgl.DGLGraph()\n",
"    graph.from_networkx(nx_graph)\n",
"graph"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Convert the numpy features and one-hot targets to float32 tensors\n",
"# (what `torch.Tensor(...)` yields under the default dtype).\n",
"input_tensor = torch.tensor(fingerprints, dtype=torch.float32)\n",
"output_tensor = torch.tensor(y, dtype=torch.float32)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [00:09<00:00, 102.06it/s, accuracy=0.92]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
"# Integer class index per atom, recovered from the one-hot targets.\n",
"torch_target = torch.argmax(output_tensor, dim=1)\n",
"# CrossEntropyLoss applies log-softmax internally, so it must receive raw logits.\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"\n",
"trange = tqdm(range(1000))\n",
"predictions = []\n",
"accuracy_traj = []\n",
"loss_traj = []\n",
"for epoch in trange:\n",
"    net.train()\n",
"    logits = net(graph, input_tensor)\n",
"    # Do not softmax the logits first: that applies softmax twice, which bounds\n",
"    # the loss and shrinks its gradients, slowing or stalling training.\n",
"    loss = loss_fn(logits, torch_target)\n",
"\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"    loss_traj.append(loss.item())\n",
"\n",
"    pred = torch.argmax(logits, dim=1)\n",
"    predictions.append(pred.numpy())\n",
"\n",
"    accuracy = float(torch.sum(pred == torch_target)) / len(input_tensor)\n",
"    accuracy_traj.append(accuracy)\n",
"\n",
"    trange.set_postfix(accuracy=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'cross-entropy loss')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Training loss per iteration; a log x-axis spreads out the early iterations.\n",
"plt.plot(loss_traj)\n",
"plt.xscale('log')\n",
"plt.xlabel('iteration')\n",
"plt.ylabel('cross-entropy loss')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'accuracy')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Per-iteration training accuracy with a dashed reference line at 1.0 (perfect labelling).\n",
"plt.plot(accuracy_traj)\n",
"plt.hlines(1.0, 0, len(accuracy_traj), linestyles='--')\n",
"plt.xlabel('iteration')\n",
"plt.xscale('log')\n",
"plt.ylabel('accuracy')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:dgl]",
"language": "python",
"name": "conda-env-dgl-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment