Skip to content

Instantly share code, notes, and snippets.

@mpds
Last active November 24, 2022 05:06
Show Gist options
  • Save mpds/058d5310f9e6a3b21d353d22599c6760 to your computer and use it in GitHub Desktop.
Save mpds/058d5310f9e6a3b21d353d22599c6760 to your computer and use it in GitHub Desktop.
Represent molecules with voxel grids
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Represent molecules with voxel grids"
],
"metadata": {
"id": "QS2-kcNlhxhN"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TVtH7JnrfogI"
},
"outputs": [],
"source": [
"# Install the notebook's dependencies with the %pip magic so that they land in\n",
"# the environment of the running kernel (biopandas parses the PDB file format).\n",
"%pip install -q numpy biopandas plotly"
]
},
{
"cell_type": "code",
"source": [
"from biopandas.pdb import PandasPdb\n",
"import numpy as np\n",
"import plotly.graph_objects as go"
],
"metadata": {
"id": "FFxzdrGzgPXl"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Download the PDB entry 8DDM and parse it with biopandas.\n",
"mol = PandasPdb().fetch_pdb('8DDM')\n",
"\n",
"# Only the ligand is needed for this example: keep the HETATM records of\n",
"# chain A, then narrow them down to the residue named K36.\n",
"hetatm = mol.df[\"HETATM\"]\n",
"het = hetatm[hetatm[\"chain_id\"] == \"A\"]\n",
"ligand = het[het[\"residue_name\"] == \"K36\"]\n",
"display(ligand)"
],
"metadata": {
"id": "tzNgBIr5griT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Represent the molecule with one grid channel per element symbol.\n",
"# By default the channels are, in order: carbon, hydrogen, oxygen, nitrogen.\n",
"\n",
"def channels_mask(atoms, types=(\"C\", \"H\", \"O\", \"N\")):\n",
"    \"\"\"Build one boolean atom mask per channel.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    atoms : DataFrame-like\n",
"        Atom table; only its 'element_symbol' column is read (one row per atom).\n",
"    types : sequence of str, optional\n",
"        Element symbols, one per channel, in channel order.\n",
"\n",
"    Returns\n",
"    -------\n",
"    numpy.ndarray\n",
"        Boolean array of shape (len(types), n_atoms); entry [i, j] is True when\n",
"        atom j has the element symbol types[i].\n",
"    \"\"\"\n",
"    atom_symbs = atoms[\"element_symbol\"].to_list()\n",
"    # np.isin is the replacement for np.in1d, which is deprecated since NumPy 2.0.\n",
"    return np.asarray([np.isin(atom_symbs, [symb]) for symb in types])"
],
"metadata": {
"id": "trRwSTbigsj7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# van der Waals radius lookup\n",
"\n",
"def vdw_radius(atom_type):\n",
"    \"\"\"Return the van der Waals radius tabulated for `atom_type`.\n",
"\n",
"    Radii are tabulated for the symbols C, H, O and N; any other value\n",
"    yields 0. The values are presumably in angstroms -- confirm.\n",
"    \"\"\"\n",
"    radii = {\"C\": 1.7, \"H\": 1.1, \"O\": 1.52, \"N\": 1.55}\n",
"    try:\n",
"        return radii[atom_type]\n",
"    except KeyError:\n",
"        # symbol not in the table\n",
"        return 0"
],
"metadata": {
"id": "TAp45kzzvMeo"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# generate a 3-dimensional grid of point coords\n",
"\n",
"def build_grid(box_dims, vox_size):\n",
"    \"\"\"Generate the flattened point coordinates of a regular 3-D grid.\n",
"\n",
"    Along each axis the points start at -dim / 2 and are spaced vox_size\n",
"    apart, so the box is positioned around the origin.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    box_dims : sequence of numbers\n",
"        Edge lengths of the box along the three axes.\n",
"    vox_size : number\n",
"        Spacing between consecutive grid points; must be positive.\n",
"\n",
"    Returns\n",
"    -------\n",
"    tuple of 3 numpy.ndarray\n",
"        Flat coordinate arrays for the first, second and third axis. Each has\n",
"        n0 * n1 * n2 entries, with the third axis varying fastest, where\n",
"        n = int(dim / vox_size) for the corresponding axis.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If vox_size is not positive.\n",
"    \"\"\"\n",
"    if vox_size <= 0:\n",
"        raise ValueError(\"vox_size must be positive\")\n",
"    dims = np.asarray(box_dims, dtype=float)\n",
"    # Use an explicit integer point count per axis. np.arange with a float stop\n",
"    # rounds the count up, which disagrees with the int(dim / vox_size) shape\n",
"    # that voxelize() reshapes to whenever dim / vox_size is not a whole number.\n",
"    counts = [int(dim / vox_size) for dim in dims]\n",
"    axis = [np.arange(n) * vox_size - dim / 2 for n, dim in zip(counts, dims)]\n",
"    grid = (\n",
"        np.repeat(axis[0], axis[1].shape[0] * axis[2].shape[0]),\n",
"        np.tile(np.repeat(axis[1], axis[2].shape[0]), axis[0].shape[0]),\n",
"        np.tile(axis[2], axis[0].shape[0] * axis[1].shape[0]),\n",
"    )\n",
"    return grid\n"
],
"metadata": {
"id": "Tfkc_6JSlj3q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# generate voxel features\n",
"\n",
"def voxelize(atoms, channels_mask, box_dims, vox_size):\n",
"    \"\"\"Compute per-channel voxel occupancies for a set of atoms.\n",
"\n",
"    The grid returned by build_grid(box_dims, vox_size) is translated so that\n",
"    its origin sits at the mean of the atom coordinates. For a grid point at\n",
"    distance d from an atom with van der Waals radius r, the atom's occupancy\n",
"    is 1 - exp(-(r / d)**12); each channel keeps the maximum occupancy over\n",
"    the atoms selected by its mask.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    atoms : DataFrame-like\n",
"        Atom table with 'x_coord', 'y_coord', 'z_coord' and 'element_symbol'\n",
"        columns.\n",
"    channels_mask : numpy.ndarray\n",
"        Boolean array of shape (n_channels, n_atoms) selecting the atoms of\n",
"        each channel. Inside this function the name shadows the notebook's\n",
"        channels_mask() helper.\n",
"    box_dims : sequence of numbers\n",
"        Edge lengths of the box along the three axes.\n",
"    vox_size : number\n",
"        Spacing between consecutive grid points.\n",
"\n",
"    Returns\n",
"    -------\n",
"    numpy.ndarray\n",
"        Array of shape (n_channels, n0, n1, n2) with n = int(dim / vox_size).\n",
"        Channels whose mask selects no atom stay all zero.\n",
"    \"\"\"\n",
"    coords = atoms[[\"x_coord\", \"y_coord\", \"z_coord\"]].to_numpy()  # get atoms coords\n",
"    grid_center = coords.mean(axis=0)  # molecule geometrical center as the grid center\n",
"    grid = build_grid(box_dims, vox_size)\n",
"    # translate grid points and add a trailing axis so they broadcast against atoms\n",
"    tgrid = [(u + v)[..., np.newaxis] for u, v in zip(grid, grid_center)]\n",
"\n",
"    # distance of every grid point to every atom, shape (n_points, n_atoms)\n",
"    dist = np.sqrt(\n",
"        np.power((coords[:, 0] - tgrid[0]), 2)\n",
"        + np.power((coords[:, 1] - tgrid[1]), 2)\n",
"        + np.power((coords[:, 2] - tgrid[2]), 2)\n",
"    )\n",
"\n",
"    vdws = np.array([vdw_radius(atype) for atype in atoms[\"element_symbol\"]])\n",
"\n",
"    # calculate voxel occupancies\n",
"    # A grid point may coincide exactly with an atom (dist == 0). The limit of\n",
"    # the occupancy there is 1 for an atom with a positive radius and 0 for an\n",
"    # atom with radius 0, so set the ratio explicitly instead of letting\n",
"    # 0 / 0 produce NaN, and silence the expected floating-point warnings.\n",
"    with np.errstate(divide=\"ignore\", invalid=\"ignore\", over=\"ignore\"):\n",
"        ratio = np.where(dist == 0, np.where(vdws > 0, np.inf, 0.0), vdws / dist)\n",
"        occs = 1 - np.exp(-1 * np.power(ratio, 12))\n",
"\n",
"    voxel_grid = np.zeros((channels_mask.shape[0], grid[0].shape[0]))\n",
"    for i, channel in enumerate(channels_mask):\n",
"        if np.any(channel):\n",
"            # strongest occupancy among the atoms of this channel\n",
"            np.amax(occs[:, channel], axis=1, out=voxel_grid[i])\n",
"\n",
"    # reshape to (n_channels, dim1, dim2, dim3)\n",
"    return voxel_grid.reshape(\n",
"        [channels_mask.shape[0], *[int(dim / vox_size) for dim in box_dims]]\n",
"    )\n"
],
"metadata": {
"id": "KSc91pW5oC7Y"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# plot the voxel grids\n",
"\n",
"def plot(grid_channel, grid, isomin=0.001, isomax=1., opacity=0.12,\n",
"         surface_count=18, title=None):\n",
"    \"\"\"Render one voxel-grid channel as an interactive Plotly volume.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    grid_channel : numpy.ndarray\n",
"        Occupancy values of a single channel; flattened before plotting.\n",
"    grid : sequence of 3 arrays\n",
"        Flat coordinates of the grid points along the three axes, in the same\n",
"        order as the flattened grid_channel.\n",
"    isomin, isomax : float, optional\n",
"        Lower and upper value bounds passed to go.Volume.\n",
"    opacity : float, optional\n",
"        Surface opacity; needs to be small to see through all surfaces.\n",
"    surface_count : int, optional\n",
"        Number of iso-surfaces; needs to be large for good volume rendering.\n",
"    title : str, optional\n",
"        Figure title. No title is set when None.\n",
"\n",
"    Returns\n",
"    -------\n",
"    None\n",
"        The figure is displayed with fig.show().\n",
"    \"\"\"\n",
"    volume = go.Volume(\n",
"        x=grid[0],\n",
"        y=grid[1],\n",
"        z=grid[2],\n",
"        value=grid_channel.flatten(),\n",
"        isomin=isomin,\n",
"        isomax=isomax,\n",
"        opacity=opacity,\n",
"        surface_count=surface_count,\n",
"    )\n",
"    fig = go.Figure(data=volume)\n",
"    if title is not None:\n",
"        fig.update_layout(title_text=title)\n",
"    fig.show()"
],
"metadata": {
"id": "LenSXI8-zk1o"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Box edge lengths and voxel spacing shared by the grid and the voxelization\n",
"# (presumably in angstroms, like the PDB coordinates -- confirm).\n",
"box_dims = [24, 24, 24]\n",
"vox_size = 0.5\n",
"\n",
"# Untranslated grid point coordinates, used as plot axes. voxelize() shifts its\n",
"# own copy to the ligand's center, so these axes are relative to that center.\n",
"grid = build_grid(box_dims, vox_size)\n",
"\n",
"# Occupancy grid of the ligand, one channel per element type.\n",
"ligand_mask = channels_mask(ligand)\n",
"voxel_grid = voxelize(ligand, ligand_mask, box_dims, vox_size)"
],
"metadata": {
"id": "h4gl4Hnz5tij"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Plot the carbon channel (1st channel, index 0 in the order C, H, O, N).\n",
"# Note: the plot is interactive!\n",
"plot(voxel_grid[0], grid)"
],
"metadata": {
"id": "KwKVVNTQkilb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Plot the oxygen channel (3rd channel, index 2 in the order C, H, O, N).\n",
"plot(voxel_grid[2], grid)"
],
"metadata": {
"id": "UtebogimkkkY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Plot the nitrogen channel (4th channel, index 3 in the order C, H, O, N).\n",
"plot(voxel_grid[3], grid)"
],
"metadata": {
"id": "dV7M7AJVuCLE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "bbxnpKjE6Yqs"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment