@dominicrufa
Created December 6, 2021 20:01
{
"cells": [
{
"cell_type": "markdown",
"id": "0144c8d4-6bf0-4758-b880-83ffe84e6d84",
"metadata": {},
"source": [
"# Turning an `OpenMM` `System` object into a `jax` potential energy function\n",
"We'll also compare the resulting potential energies against `OpenMM` for alanine dipeptide, first in vacuum and then in explicit solvent with periodic boundary conditions."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b925b4bb-6b96-4029-9eff-ac572f81f0eb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Warning: importing 'simtk.openmm' is deprecated. Import 'openmm' instead.\n",
"/home/dominic/anaconda3/envs/aquaregia/lib/python3.9/site-packages/google/colab/data_table.py:30: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
" from IPython.utils import traitlets as _traitlets\n",
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"from openmmtools.testsystems import AlanineDipeptideVacuum, AlanineDipeptideExplicit\n",
"from jax import numpy as jnp\n",
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)  # float64, so energies can be compared to OpenMM's\n",
"from simtk import unit, openmm\n",
"import numpy\n",
"from aquaregia.utils import Array, ArrayTree"
]
},
{
"cell_type": "markdown",
"id": "04242ae5-52d1-4488-9fa7-30144cbbf0e0",
"metadata": {},
"source": [
"## Vacuum\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f84c9300-ad6e-4afd-8b8e-8587bfe916f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[<openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7f83582695a0> >, <openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7f83db19b6f0> >, <openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7f83db19b870> >, <openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7f835450f390> >, <openmm.openmm.CMMotionRemover; proxy of <Swig Object of type 'OpenMM::CMMotionRemover *' at 0x7f835450f6f0> >]\n"
]
}
],
"source": [
"adp_vacuum = AlanineDipeptideVacuum(constraints=None)  # the jax energy function does not support constraints yet\n",
"print(adp_vacuum.system.getForces())\n",
"adp_vacuum.system.removeForce(4)  # remove the CMMotionRemover (index 4 in the list printed above)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "0af7a43a-ce5a-4c3d-8ed0-136ddef553a8",
"metadata": {},
"outputs": [],
"source": [
"from aquaregia.openmm import make_canonical_energy_fn\n",
"import jax_md\n",
"from aquaregia.utils import get_vacuum_neighbor_list\n",
"import jax"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "10b3d407-589b-4fbc-be89-3fed66710f02",
"metadata": {},
"outputs": [],
"source": [
"displacement_fn, shift_fn = jax_md.space.free()  # non-periodic space: get the displacement and shift fns\n",
"vacuum_neighbor_list = get_vacuum_neighbor_list(num_particles=adp_vacuum.system.getNumParticles())"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "2d347ba1-8506-4242-8751-db7ecc498f52",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"querying nonbonded particles: 100%|██████████████████████████████████████████████████| 22/22 [00:00<00:00, 37632.42it/s]\n",
"querying nonbonded exception particles...: 100%|█████████████████████████████████████| 98/98 [00:00<00:00, 54162.84it/s]\n"
]
}
],
"source": [
"vacuum_params, vacuum_u = make_canonical_energy_fn(system=adp_vacuum.system,\n",
"                                                   displacement_fn=displacement_fn)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "58df6737-1097-439b-ad4b-5cf2ff6a57c4",
"metadata": {},
"outputs": [],
"source": [
"jax_posits = Array(adp_vacuum.positions.value_in_unit_system(unit.md_unit_system))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "a8d3b1e6-7977-4b63-922f-ca595d211cb9",
"metadata": {},
"outputs": [],
"source": [
"jax_energy = jax.jit(vacuum_u)(jax_posits, vacuum_neighbor_list, vacuum_params)"
]
},
{
"cell_type": "markdown",
"id": "07ee3e06-2786-4085-afee-cca93f6b3e9e",
"metadata": {},
"source": [
"Now compute the same energy with `OpenMM`."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e5c1b324-7634-490e-8893-9edbee8cdd6b",
"metadata": {},
"outputs": [],
"source": [
"from openmmtools.integrators import LangevinIntegrator"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "d3c1c30f-1ca2-4c84-b83a-2d64722767de",
"metadata": {},
"outputs": [],
"source": [
"integrator = LangevinIntegrator()\n",
"context = openmm.Context(adp_vacuum.system, integrator)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "147a950a-0640-4e02-b6fd-27efc59754aa",
"metadata": {},
"outputs": [],
"source": [
"context.setPositions(adp_vacuum.positions)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "e50c4289-4e3e-4094-a822-39ca246be3c9",
"metadata": {},
"outputs": [],
"source": [
"omm_energy = context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(unit.md_unit_system)"
]
},
{
"cell_type": "markdown",
"id": "36a9a6bb-023c-4ebf-b6e4-3241705ddcae",
"metadata": {},
"source": [
"Compute the energy difference (`OpenMM` minus `jax`):"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "bb352d6f-f071-4b73-b41c-b1dab052af61",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(3.78971047e-05, dtype=float64)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"omm_energy - jax_energy"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "8e616183-41b1-47c2-9e29-375222088da4",
"metadata": {},
"outputs": [],
"source": [
"del context"
]
},
{
"cell_type": "markdown",
"id": "85b01037-64e2-42a0-a6c3-5b9b696dbd35",
"metadata": {},
"source": [
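"The energy difference computed above is in kJ/mol. The `jax` and `OpenMM` vacuum energies therefore agree to within about 4e-5 kJ/mol."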
]
},
{
"cell_type": "markdown",
"id": "120f5735-ab6a-4776-931a-5f33419a2f9b",
"metadata": {},
"source": [
"## Periodic boundary conditions (NVT) with a rectilinear box\n",
"Now let's look at `AlanineDipeptideExplicit` with PBCs. Here we expect a significant energy discrepancy. The `jax` energy function uses a reaction field for electrostatics and a polynomial decay function for sterics, which is not the same treatment as the `OpenMM` system's."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "98b53d02-52ab-4531-a778-54d6dbc1fa14",
"metadata": {},
"outputs": [],
"source": [
"adp_explicit = AlanineDipeptideExplicit(constraints=None, rigid_water=False)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "56204e24-cc2a-4594-b6d2-3197ecb0f942",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[<openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7f82fc7739f0> >, <openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7f82fc7738a0> >, <openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7f82fc773c90> >, <openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7f82fc755c60> >, <openmm.openmm.CMMotionRemover; proxy of <Swig Object of type 'OpenMM::CMMotionRemover *' at 0x7f82ac0406f0> >]\n"
]
}
],
"source": [
"print(adp_explicit.system.getForces())\n",
"adp_explicit.system.removeForce(4)  # remove the CMMotionRemover"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "5ac259a7-6936-46d3-bb1d-76e02075de53",
"metadata": {},
"outputs": [],
"source": [
"from aquaregia.openmm import get_box_vectors_from_vec3s"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "f7fcc7e4-dc90-47cf-9451-299450d511d8",
"metadata": {},
"outputs": [],
"source": [
"box_vectors = get_box_vectors_from_vec3s(adp_explicit.system.getDefaultPeriodicBoxVectors())"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "6885a53a-11a3-479e-b150-2dab6fdb4ebf",
"metadata": {},
"outputs": [],
"source": [
"pbc_displacement_fn, pbc_shift_fn = jax_md.space.periodic(box_vectors)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "5e53667c-a6a0-4df7-bcf6-4c6337af633c",
"metadata": {},
"outputs": [],
"source": [
"nbr_fns = jax_md.partition.neighbor_list(displacement_or_metric=pbc_displacement_fn,\n",
"                                         box_size=box_vectors,\n",
"                                         r_cutoff=1.0,  # 1 nm cutoff\n",
"                                         dr_threshold=0.2,  # max displacement (nm) before a rebuild\n",
"                                         capacity_multiplier=1.25)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "de62af29-eb0c-4eb2-90c7-746798b6e7e3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"querying nonbonded particles: 100%|██████████████████████████████████████████████| 2269/2269 [00:00<00:00, 77004.85it/s]\n",
"querying nonbonded exception particles...: 100%|█████████████████████████████████| 2345/2345 [00:00<00:00, 63748.18it/s]\n"
]
}
],
"source": [
"explicit_params, explicit_u = make_canonical_energy_fn(system=adp_explicit.system,\n",
"                                                       displacement_fn=pbc_displacement_fn)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "cb0aeafb-f50b-41db-a7fc-d3f890c5888a",
"metadata": {},
"outputs": [],
"source": [
"jax_explicit_positions = Array(adp_explicit.positions.value_in_unit_system(unit.md_unit_system))"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "3e492dfd-a17f-46e4-adb3-741fd67bd1b1",
"metadata": {},
"outputs": [],
"source": [
"init_neighbor_list = nbr_fns.allocate(jax_explicit_positions)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "799b5da8-5d53-4be8-a318-b452f39fb24c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(2269, 856)"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"init_neighbor_list.idx.shape"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "3fa0ba0d-319d-4bf2-b6c1-32570c5989f1",
"metadata": {},
"outputs": [],
"source": [
"jax_explicit_energy = jax.jit(explicit_u)(jax_explicit_positions, init_neighbor_list, explicit_params)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "f5e894b2-6e87-48bb-9c04-2f2e656cea33",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(-21838.40004107, dtype=float64)"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jax_explicit_energy"
]
},
{
"cell_type": "markdown",
"id": "6d37b471-0eed-49d0-af5a-bec154832f97",
"metadata": {},
"source": [
"Now get the `OpenMM` energy."
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "bdba2014-03e3-47fe-9c98-4eb674c4bfef",
"metadata": {},
"outputs": [],
"source": [
"explicit_integrator = LangevinIntegrator()\n",
"explicit_context = openmm.Context(adp_explicit.system, explicit_integrator)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "ed719fa8-4e48-4b3b-8d96-ce7640992779",
"metadata": {},
"outputs": [],
"source": [
"explicit_context.setPositions(adp_explicit.positions)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "c7dc28fb-14ff-4d00-bddb-9ad99b20c7ed",
"metadata": {},
"outputs": [],
"source": [
"explicit_context.setPeriodicBoxVectors(*adp_explicit.system.getDefaultPeriodicBoxVectors())"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "c4d0baa9-7518-402b-ad2d-5782c2c596c0",
"metadata": {},
"outputs": [],
"source": [
"omm_explicit_energy = explicit_context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(unit.md_unit_system)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "0b816020-ea8f-4d8b-ab3c-1bca13d5e06d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(-2820.75967401, dtype=float64)"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"omm_explicit_energy - jax_explicit_energy"
]
},
{
"cell_type": "markdown",
"id": "d6598fb1-7a28-4e2f-917a-401aa894bec0",
"metadata": {},
"source": [
"As expected, the energies disagree, here by about 2.8e3 kJ/mol, with `OpenMM` lower. The `OpenMM` context uses `PME` for electrostatics, while the `jax` energy function uses a reaction-field implementation."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
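The final notebook cell attributes the explicit-solvent gap to `OpenMM` using `PME` while the `jax` function uses a reaction field. The sketch below is a rough, hedged illustration of why a cutoff reaction field shifts pairwise electrostatics relative to full Coulomb. It implements the reaction-field pair energy that the `OpenMM` documentation gives for the `CutoffPeriodic` method of `NonbondedForce`.

Assumptions:

* This is not aquaregia's code. Its exact functional form, including the polynomial decay applied to sterics, may differ.
* The helper names `coulomb_energy` and `reaction_field_energy` are hypothetical.
* `eps_solvent = 78.3` is `OpenMM`'s default reaction-field dielectric.
* 138.935456 is (approximately) `OpenMM`'s Coulomb constant in kJ mol^-1 nm e^-2.

```python
import math

# Coulomb prefactor 1/(4*pi*eps0) in kJ mol^-1 nm e^-2 (approximately OpenMM's value).
ONE_4PI_EPS0 = 138.935456


def coulomb_energy(q1, q2, r):
    """Plain Coulomb pair energy (kJ/mol); charges in e, distance in nm."""
    return ONE_4PI_EPS0 * q1 * q2 / r


def reaction_field_energy(q1, q2, r, r_cutoff=1.0, eps_solvent=78.3):
    """Reaction-field pair energy (kJ/mol) in the form OpenMM documents for
    NonbondedForce.CutoffPeriodic. Hypothetical helper, not aquaregia's API.

    The constant c_rf is chosen so the energy goes to zero at r_cutoff;
    pairs at or beyond the cutoff contribute nothing.
    """
    if r >= r_cutoff:
        return 0.0
    k_rf = (eps_solvent - 1.0) / ((2.0 * eps_solvent + 1.0) * r_cutoff**3)
    c_rf = 3.0 * eps_solvent / ((2.0 * eps_solvent + 1.0) * r_cutoff)
    return ONE_4PI_EPS0 * q1 * q2 * (1.0 / r + k_rf * r**2 - c_rf)


if __name__ == "__main__":
    # Illustrative TIP3P-like oxygen/hydrogen charges; not taken from the notebook.
    q_o, q_h = -0.834, 0.417
    for r in (0.3, 0.6, 0.9, 0.999):
        print(f"r={r:.3f} nm  coulomb={coulomb_energy(q_o, q_h, r):9.3f}  "
              f"reaction_field={reaction_field_energy(q_o, q_h, r):9.3f}")
```

Each pair's energy is offset by `k_rf * r**2 - c_rf` and truncated at the cutoff. Summed over a few thousand solvent atoms, such offsets can plausibly produce a large total difference from PME.

For a closer comparison, one untested option is to put the `OpenMM` system on a similar scheme before creating the context. You could call `setNonbondedMethod(openmm.NonbondedForce.CutoffPeriodic)` and `setCutoffDistance(1.0 * unit.nanometer)` on its `NonbondedForce`. How much of the gap that removes depends on how closely aquaregia's electrostatics and sterics match `OpenMM`'s.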