Created
December 6, 2021 20:01
-
-
Save dominicrufa/8dfe8d865bc2f33a2fe7870aece7cc6c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "0144c8d4-6bf0-4758-b880-83ffe84e6d84", | |
"metadata": {}, | |
"source": [ | |
"# Demonstration of how to turn an `OpenMM` `System` object into a `jax` potential energy function.\n", | |
"We'll also compare the resulting potential energies against `OpenMM`, first for a vacuum system and then for a periodic explicit-solvent system." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "b925b4bb-6b96-4029-9eff-ac572f81f0eb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Warning: importing 'simtk.openmm' is deprecated. Import 'openmm' instead.\n", | |
"/home/dominic/anaconda3/envs/aquaregia/lib/python3.9/site-packages/google/colab/data_table.py:30: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n", | |
" from IPython.utils import traitlets as _traitlets\n", | |
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | |
] | |
} | |
], | |
"source": [ | |
"from openmmtools.testsystems import AlanineDipeptideVacuum, AlanineDipeptideExplicit\n", | |
"from jax import numpy as jnp\n", | |
"from jax.config import config\n", | |
"config.update(\"jax_enable_x64\", True)\n", | |
"from simtk import unit, openmm\n", | |
"import numpy\n", | |
"from aquaregia.utils import Array, ArrayTree\n", | |
"from simtk import unit, openmm" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "04242ae5-52d1-4488-9fa7-30144cbbf0e0", | |
"metadata": {}, | |
"source": [ | |
"## Vacuum\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "f84c9300-ad6e-4afd-8b8e-8587bfe916f2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[<openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7f83582695a0> >, <openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7f83db19b6f0> >, <openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7f83db19b870> >, <openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7f835450f390> >, <openmm.openmm.CMMotionRemover; proxy of <Swig Object of type 'OpenMM::CMMotionRemover *' at 0x7f835450f6f0> >]\n" | |
] | |
} | |
], | |
"source": [ | |
"adp_vacuum = AlanineDipeptideVacuum(constraints=None) # jax simulation doesnt support constraints atm\n", | |
"print(adp_vacuum.system.getForces())\n", | |
"adp_vacuum.system.removeForce(4) # remove the CMMotionRemoverForce" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "0af7a43a-ce5a-4c3d-8ed0-136ddef553a8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from aquaregia.openmm import make_canonical_energy_fn\n", | |
"import jax_md\n", | |
"from aquaregia.utils import get_vacuum_neighbor_list\n", | |
"import jax" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "10b3d407-589b-4fbc-be89-3fed66710f02", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"displacement_fn, shift_fn = jax_md.space.free() #get the displacement and shift fns\n", | |
"vacuum_neighbor_list = get_vacuum_neighbor_list(num_particles=adp_vacuum.system.getNumParticles())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "2d347ba1-8506-4242-8751-db7ecc498f52", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"querying nonbonded particles: 100%|██████████████████████████████████████████████████| 22/22 [00:00<00:00, 37632.42it/s]\n", | |
"querying nonbonded exception particles...: 100%|█████████████████████████████████████| 98/98 [00:00<00:00, 54162.84it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"vacuum_params, vacuum_u = make_canonical_energy_fn(system = adp_vacuum.system,\n", | |
" displacement_fn=displacement_fn,\n", | |
" ) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "58df6737-1097-439b-ad4b-5cf2ff6a57c4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"jax_posits = Array(adp_vacuum.positions.value_in_unit_system(unit.md_unit_system))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "a8d3b1e6-7977-4b63-922f-ca595d211cb9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"jax_energy = jax.jit(vacuum_u)(jax_posits, vacuum_neighbor_list, vacuum_params)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "07ee3e06-2786-4085-afee-cca93f6b3e9e", | |
"metadata": {}, | |
"source": [ | |
"Now compute the reference `OpenMM` energy for the same positions." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "e5c1b324-7634-490e-8893-9edbee8cdd6b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from openmmtools.integrators import LangevinIntegrator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "d3c1c30f-1ca2-4c84-b83a-2d64722767de", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"integrator = LangevinIntegrator()\n", | |
"context = openmm.Context(adp_vacuum.system, integrator)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "147a950a-0640-4e02-b6fd-27efc59754aa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"context.setPositions(adp_vacuum.positions)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "e50c4289-4e3e-4094-a822-39ca246be3c9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"omm_energy = context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(unit.md_unit_system)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "36a9a6bb-023c-4ebf-b6e4-3241705ddcae", | |
"metadata": {}, | |
"source": [ | |
"Let's compute the energy difference (`OpenMM` minus `jax`)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"id": "bb352d6f-f071-4b73-b41c-b1dab052af61", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray(3.78971047e-05, dtype=float64)" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"omm_energy - jax_energy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "8e616183-41b1-47c2-9e29-375222088da4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"del context" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "85b01037-64e2-42a0-a6c3-5b9b696dbd35", | |
"metadata": {}, | |
"source": [ | |
"The `omm_energy - jax_energy` difference above is in kJ/mol, the energy unit of `unit.md_unit_system`." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "120f5735-ab6a-4776-931a-5f33419a2f9b", | |
"metadata": {}, | |
"source": [ | |
"## Periodic boundary conditions (NVT) with a rectilinear box\n", | |
"Let's take a look at `AlanineDipeptideExplicit` with periodic boundary conditions (PBCs). We expect a significant energy discrepancy here, since the `jax` energy function uses a reaction field for the electrostatics and a polynomial decay function for the sterics." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "98b53d02-52ab-4531-a778-54d6dbc1fa14", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"adp_explicit = AlanineDipeptideExplicit(constraints=None, rigid_water=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "56204e24-cc2a-4594-b6d2-3197ecb0f942", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[<openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7f82fc7739f0> >, <openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7f82fc7738a0> >, <openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7f82fc773c90> >, <openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7f82fc755c60> >, <openmm.openmm.CMMotionRemover; proxy of <Swig Object of type 'OpenMM::CMMotionRemover *' at 0x7f82ac0406f0> >]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(adp_explicit.system.getForces())\n", | |
"adp_explicit.system.removeForce(4)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "5ac259a7-6936-46d3-bb1d-76e02075de53", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from aquaregia.openmm import get_box_vectors_from_vec3s" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "f7fcc7e4-dc90-47cf-9451-299450d511d8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"box_vectors = get_box_vectors_from_vec3s(adp_explicit.system.getDefaultPeriodicBoxVectors())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "6885a53a-11a3-479e-b150-2dab6fdb4ebf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pbc_displacement_fn, pbc_shift_fn = jax_md.space.periodic(box_vectors)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "5e53667c-a6a0-4df7-bcf6-4c6337af633c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"nbr_fns = jax_md.partition.neighbor_list(displacement_or_metric = pbc_displacement_fn,\n", | |
" box_size = box_vectors,\n", | |
" r_cutoff = 1., #1nm cutoff\n", | |
" dr_threshold = 0.2,\n", | |
" capacity_multiplier = 1.25)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"id": "de62af29-eb0c-4eb2-90c7-746798b6e7e3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"querying nonbonded particles: 100%|██████████████████████████████████████████████| 2269/2269 [00:00<00:00, 77004.85it/s]\n", | |
"querying nonbonded exception particles...: 100%|█████████████████████████████████| 2345/2345 [00:00<00:00, 63748.18it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"explicit_params, explicit_u = make_canonical_energy_fn(system = adp_explicit.system,\n", | |
" displacement_fn=pbc_displacement_fn,\n", | |
" ) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"id": "cb0aeafb-f50b-41db-a7fc-d3f890c5888a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"jax_explicit_positions = Array(adp_explicit.positions.value_in_unit_system(unit.md_unit_system))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"id": "3e492dfd-a17f-46e4-adb3-741fd67bd1b1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"init_neighbor_list = nbr_fns.allocate(jax_explicit_positions)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"id": "799b5da8-5d53-4be8-a318-b452f39fb24c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(2269, 856)" | |
] | |
}, | |
"execution_count": 39, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"init_neighbor_list.idx.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"id": "3fa0ba0d-319d-4bf2-b6c1-32570c5989f1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"jax_explicit_energy = jax.jit(explicit_u)(jax_explicit_positions, init_neighbor_list, explicit_params)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"id": "f5e894b2-6e87-48bb-9c04-2f2e656cea33", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray(-21838.40004107, dtype=float64)" | |
] | |
}, | |
"execution_count": 45, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"jax_explicit_energy" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6d37b471-0eed-49d0-af5a-bec154832f97", | |
"metadata": {}, | |
"source": [ | |
"Now compute the reference `OpenMM` energy." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"id": "bdba2014-03e3-47fe-9c98-4eb674c4bfef", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"explicit_integrator = LangevinIntegrator()\n", | |
"explicit_context = openmm.Context(adp_explicit.system, explicit_integrator)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"id": "ed719fa8-4e48-4b3b-8d96-ce7640992779", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"explicit_context.setPositions(adp_explicit.positions)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"id": "c7dc28fb-14ff-4d00-bddb-9ad99b20c7ed", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"explicit_context.setPeriodicBoxVectors(*adp_explicit.system.getDefaultPeriodicBoxVectors())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 50, | |
"id": "c4d0baa9-7518-402b-ad2d-5782c2c596c0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"omm_explicit_energy = explicit_context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(unit.md_unit_system)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 52, | |
"id": "0b816020-ea8f-4d8b-ab3c-1bca13d5e06d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"DeviceArray(-2820.75967401, dtype=float64)" | |
] | |
}, | |
"execution_count": 52, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"omm_explicit_energy - jax_explicit_energy" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d6598fb1-7a28-4e2f-917a-401aa894bec0", | |
"metadata": {}, | |
"source": [ | |
"So the two energies disagree, as expected: the `OpenMM` context evaluates electrostatics with `PME`, while the `jax` energy function uses a reaction-field implementation." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5e9fd809-3c1e-45e7-b84c-6da2243455d8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment