@dominicrufa
Created April 29, 2021 21:39
Placing the `TorchForce` into force group 1 and querying each group's forces separately seems to work on a non-periodic system with the `Reference` platform.
{
"cells": [
{
"cell_type": "markdown",
"id": "650ecdde-a4c6-4443-97f9-c5d6014b5d34",
"metadata": {},
"source": [
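"Demonstrate a force problem with a `TorchForce` placed in force group 1: querying each force group separately appears to give the expected forces, but querying all force groups at once does not."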
]
},
{
"cell_type": "markdown",
"id": "bf4d72bd-b848-4f39-99b4-2f6468e7c36a",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e412c2c7-f12c-40cc-a0d5-f78511aed521",
"metadata": {},
"outputs": [],
"source": [
"from openmmtools.testsystems import AlanineDipeptideVacuum\n",
"from simtk import unit, openmm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1c64d66b-a553-489e-8371-27a77ee5d589",
"metadata": {},
"outputs": [],
"source": [
"adp = AlanineDipeptideVacuum(constraints=None)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "50384e7a-b839-4940-b117-110325665868",
"metadata": {},
"outputs": [],
"source": [
"system, topology, pos = adp.system, adp.topology, adp.positions"
]
},
{
"cell_type": "markdown",
"id": "009ebae8-2da5-4ae9-9167-9cad2d48d9e2",
"metadata": {},
"source": [
"Remove the `CMMotionRemover` force (the last force in the system)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d7f669fd-0238-4d76-81aa-de3d301099a2",
"metadata": {},
"outputs": [],
"source": [
"system.removeForce(system.getNumForces() - 1)"
]
},
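{
"cell_type": "markdown",
"id": "c2d4e6f8-1a3b-4c5d-8e7f-901a2b3c4d5e",
"metadata": {},
"source": [
"The cell above assumes the `CMMotionRemover` is the last force in the system. A minimal sketch of a more defensive alternative (not the original workflow) removes every `CMMotionRemover` by type; run after the cell above, it is a harmless no-op."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3e5f7a9-2b4c-4d6e-9f80-a12b3c4d5e6f",
"metadata": {},
"outputs": [],
"source": [
"# Remove CMMotionRemover by type rather than by position; iterate in reverse so\n",
"# that removing a force does not shift the indices still to be visited.\n",
"for i in reversed(range(system.getNumForces())):\n",
"    if isinstance(system.getForce(i), openmm.CMMotionRemover):\n",
"        system.removeForce(i)\n",
"print([type(f).__name__ for f in system.getForces()])"
]
},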
{
"cell_type": "code",
"execution_count": 5,
"id": "74c64573-0561-465d-82e7-fbcb3aedf112",
"metadata": {},
"outputs": [],
"source": [
"import copy"
]
},
{
"cell_type": "markdown",
"id": "be38bc62-af8f-41f2-abda-34df815c9822",
"metadata": {},
"source": [
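"Make a deep copy of the system. The `TorchForce` is added only to the copy (`ml_system`), so `system` remains the MM-only reference for comparison."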
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "689a800f-244e-4bbb-bff3-ba2069d102eb",
"metadata": {},
"outputs": [],
"source": [
"ml_system = copy.deepcopy(system)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8cd160ca-e52c-4cff-8ed8-9d7a5f7cd0e0",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchani\n",
"import openmmtorch"
]
},
{
"cell_type": "markdown",
"id": "135b4652-e4e6-4ced-9444-93d752bc06d8",
"metadata": {},
"source": [
"# `TorchForce` \n",
"Make a minimal `ForceModule`: a harmonic potential centered at the origin, with energy `scale * sum(|x_i|^2)` over the atoms `i` in `subset_indices`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "14d938f8-0182-452d-b6cc-a777ff0eca4e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"class ForceModule(torch.nn.Module):\n",
"    \"\"\"A central harmonic potential as a static compute graph\"\"\"\n",
"    def __init__(self, subset_indices):\n",
"        super().__init__()\n",
"        # indices of the particles the potential acts on\n",
"        self.indices = torch.tensor(subset_indices, dtype=torch.int64)\n",
"\n",
"    def forward(self, positions, scale):\n",
"        \"\"\"The forward method returns the energy computed from positions.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        positions : torch.Tensor with shape (nparticles,3)\n",
"            positions[i,k] is the position (in nanometers) of spatial dimension k of particle i\n",
"        scale : torch.Tensor\n",
"            scalar multiplier of the energy (the `scale` global parameter of the TorchForce)\n",
"\n",
"        Returns\n",
"        -------\n",
"        potential : torch.Tensor (0-dimensional)\n",
"            The potential energy (in kJ/mol)\n",
"        \"\"\"\n",
"        return scale * (positions[self.indices]**2).sum()"
]
},
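{
"cell_type": "markdown",
"id": "e4f6a8b0-3c5d-4e7f-a091-b23c4d5e6f70",
"metadata": {},
"source": [
"A quick sanity check of `ForceModule` (a sketch, run eagerly rather than through TorchScript). Autograd should return the analytic gradient `2 * scale * x_i` for atoms in the subset and zero elsewhere, so the corresponding OpenMM force is `-2 * scale * x_i`. The random positions and the names `check_pos` and `check_scale` are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5a7b9c1-4d6e-4f80-b1a2-c34d5e6f7081",
"metadata": {},
"outputs": [],
"source": [
"# Compare the autograd gradient of ForceModule with the analytic gradient\n",
"# dE/dx_i = 2 * scale * x_i (subset atoms) and 0 (all other atoms).\n",
"check_pos = torch.rand(5, 3, dtype=torch.float64, requires_grad=True)\n",
"check_scale = torch.tensor(1.0, dtype=torch.float64)\n",
"check_energy = ForceModule([0, 1, 2])(check_pos, check_scale)\n",
"(grad,) = torch.autograd.grad(check_energy, check_pos)\n",
"expected = torch.zeros_like(grad)\n",
"expected[:3] = 2.0 * check_scale * check_pos.detach()[:3]\n",
"assert torch.allclose(grad, expected)"
]
},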
{
"cell_type": "markdown",
"id": "6af8dabf-7998-4fac-af3d-a1195d7ab510",
"metadata": {},
"source": [
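"Make a `TorchForce` that acts on the first 3 particles and put it in force group 1 so it can be queried separately. The global parameter `scale` defaults to 1, so the `TorchForce` does contribute energy and forces."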
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4fe71b43-506b-445b-8d59-933348b1b84c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"indices = list(range(3))\n",
"f_gen = ForceModule(indices)\n",
"module = torch.jit.script(f_gen)\n",
"\n",
"# Serialize the compute graph to a file\n",
"save_filename = \"adp_model.pt\"\n",
"module.save(save_filename)\n",
"\n",
"# Create the TorchForce from the serialized compute graph\n",
"from openmmtorch import TorchForce\n",
"torch_force = TorchForce(save_filename)\n",
"torch_force.setForceGroup(1)  # forces default to group 0; group 1 lets the TorchForce be queried separately\n",
"torch_force.addGlobalParameter('scale', 1.)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5b08abe9-52ff-452d-b265-28dcccff2a14",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_system.addForce(torch_force)"
]
},
{
"cell_type": "markdown",
"id": "3b114b92-c996-4b6c-aa64-329d8ab94314",
"metadata": {},
"source": [
"List the forces in each system, then their force groups."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b7fb422e-c1f1-4228-a096-df3a801cbf4b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7f2ac76102a0> >,\n",
" <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7f2ac7610f30> >,\n",
" <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7f2ac7610f60> >,\n",
" <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7f2ac7610b40> >]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"system.getForces()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cac6cbec-ff2f-47e4-afe7-c375b350bfa7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7f2ac7610720> >,\n",
" <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7f2ac76106c0> >,\n",
" <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7f2b9c9d6600> >,\n",
" <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7f2b9c9d6360> >,\n",
" <simtk.openmm.openmm.Force; proxy of <Swig Object of type 'OpenMM::Force *' at 0x7f2ac93bac90> >]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_system.getForces()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9b74487e-b538-45dc-89f2-a1fa60436950",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"0\n",
"0\n",
"0\n"
]
}
],
"source": [
"for force in system.getForces():\n",
"    print(force.getForceGroup())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "415f3620-2cd4-4eeb-a5d4-ff50cabbb85a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"0\n",
"0\n",
"0\n",
"1\n"
]
}
],
"source": [
"for force in ml_system.getForces():\n",
"    print(force.getForceGroup())"
]
},
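{
"cell_type": "markdown",
"id": "06b8cad2-5e7f-4091-82b3-d45e6f708192",
"metadata": {},
"source": [
"For reference (a sketch, not part of the original workflow): `getState(groups=...)` accepts either a set of group indices or an integer bitmask in which bit `g` selects force group `g`, so `groups={1}` is equivalent to `groups=1 << 1`. The loop prints each force with its group and the matching bitmask; the `TorchForce` shows up as a generic `Force` proxy, as in the `getForces()` output above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17c9dbe3-6f80-41a2-93c4-e56f708192a3",
"metadata": {},
"outputs": [],
"source": [
"# Print each force's class name, force group, and the equivalent groups bitmask.\n",
"for force in ml_system.getForces():\n",
"    g = force.getForceGroup()\n",
"    print(f'{type(force).__name__:<22} group {g}  bitmask {1 << g}')"
]
},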
{
"cell_type": "markdown",
"id": "895943b8-efa5-4b9e-a8af-6feb8ce067af",
"metadata": {},
"source": [
"Make one integrator per context (an OpenMM integrator can be bound to only one `Context`)."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "82380227-a342-4e85-801a-d60650483a42",
"metadata": {},
"outputs": [],
"source": [
"from openmmtools.integrators import LangevinIntegrator\n",
"_int = LangevinIntegrator()\n",
"_ml_int = LangevinIntegrator()"
]
},
{
"cell_type": "markdown",
"id": "a382bebc-bd74-4351-a91a-cbace75aec8b",
"metadata": {},
"source": [
"Make the contexts on the `Reference` platform."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "5bcc1fe5-ae09-4382-a76b-b1286ea758e4",
"metadata": {},
"outputs": [],
"source": [
"platform = openmm.Platform.getPlatformByName('Reference')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "6b5635af-9858-41f2-878a-97f6b1c9656e",
"metadata": {},
"outputs": [],
"source": [
"mm_context = openmm.Context(system, _int, platform)\n",
"ml_context = openmm.Context(ml_system, _ml_int, platform)"
]
},
{
"cell_type": "markdown",
"id": "ef8b1ff3-c651-40e4-b655-3918ba423015",
"metadata": {},
"source": [
"Set the same positions in both contexts."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "d22f7235-7b7e-4ec9-a45c-2afbd9d84430",
"metadata": {},
"outputs": [],
"source": [
"mm_context.setPositions(pos)\n",
"ml_context.setPositions(pos)"
]
},
{
"cell_type": "markdown",
"id": "bfc1033b-300e-4aaf-a727-9c27b00ec604",
"metadata": {},
"source": [
"Get the group-0 potential energies. They should be identical, since the `TorchForce` is in force group 1 and is therefore excluded from a `groups={0}` query."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "544017c8-2e1b-48ab-9cae-9d79362c282d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=-88.08858855730922, unit=kilojoule/mole)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mm_context.getState(getEnergy=True, groups={0}).getPotentialEnergy()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "f077bca5-fd24-46e5-adaa-fb136a19ce3a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=-88.08858855730922, unit=kilojoule/mole)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_context.getState(getEnergy=True, groups={0}).getPotentialEnergy()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "ca632cc5-902f-4b48-b0a8-3e2cf2af9b50",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=0.22390250343476703, unit=kilojoule/mole)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_context.getState(getEnergy=True, groups={1}).getPotentialEnergy()"
]
},
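{
"cell_type": "markdown",
"id": "28daecf4-7091-42b3-a4d5-f6708192a3b4",
"metadata": {},
"source": [
"A consistency check on the energies (a sketch; `energy_kjmol` is a hypothetical helper, not part of OpenMM). If force groups are additive, the all-groups energy of `ml_context` should equal its group-0 plus group-1 energies, and `mm_context`, whose forces are all in group 0, should match the group-0 value."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39ebfd05-81a2-43c4-b5e6-08192a3b4c5d",
"metadata": {},
"outputs": [],
"source": [
"def energy_kjmol(context, groups=-1):\n",
"    \"\"\"Potential energy in kJ/mol; groups=-1 (OpenMM's default) includes every force group.\"\"\"\n",
"    state = context.getState(getEnergy=True, groups=groups)\n",
"    return state.getPotentialEnergy().value_in_unit(unit.kilojoule_per_mole)\n",
"\n",
"# Compare the all-groups energy with the sum of the per-group energies.\n",
"print('ml, all groups :', energy_kjmol(ml_context))\n",
"print('ml, group 0 + 1:', energy_kjmol(ml_context, {0}) + energy_kjmol(ml_context, {1}))\n",
"print('mm, all groups :', energy_kjmol(mm_context))"
]
},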
{
"cell_type": "markdown",
"id": "9c4881eb-df3a-4bbb-a446-6682766a0238",
"metadata": {},
"source": [
"Alright, so the group-0 energies match, and the group-1 energy is the `TorchForce` contribution. What about the forces?"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "48247ee5-8467-4cf0-8998-f5ba377e44f1",
"metadata": {},
"outputs": [],
"source": [
"mm_forces = mm_context.getState(getForces=True).getForces(asNumpy=True)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ad7a1b02-1c40-441a-b840-34c6e9d1a309",
"metadata": {},
"outputs": [],
"source": [
"ml_forces_g0 = ml_context.getState(getForces=True, groups={0}).getForces(asNumpy=True)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "6f364a0f-075d-4eb4-a583-84e4e8d5dd8d",
"metadata": {},
"outputs": [],
"source": [
"ml_forces_g1 = ml_context.getState(getForces=True, groups={1}).getForces(asNumpy=True)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "59dbee28-2291-468d-b0ba-0050301126ea",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=array([[0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.],\n",
" [0., 0., 0.]]), unit=kilojoule/(nanometer*mole))"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_forces_g0 - mm_forces"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "c8877a2d-52d2-4b41-ae41-0c841e4b1488",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=array([[-4.000002e-01, -2.000000e-01, 2.600000e-07],\n",
" [-4.000002e-01, -4.180000e-01, -2.000000e-08],\n",
" [-2.972528e-01, -4.907698e-01, -1.779648e-01],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n",
" [-0.000000e+00, -0.000000e+00, -0.000000e+00]]), unit=kilojoule/(nanometer*mole))"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_forces_g1"
]
},
{
"cell_type": "markdown",
"id": "911f7b96-d9ca-49b3-8a58-f91222e764c1",
"metadata": {},
"source": [
"What happens if we get _all_ of the forces from `ml_context` (no `groups` argument) and compare them to `mm_forces`?"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "46589e7e-cc03-418f-8437-d9001f83f185",
"metadata": {},
"outputs": [],
"source": [
"ml_forces_all = ml_context.getState(getForces=True).getForces(asNumpy=True)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "343c9b03-06ec-4bf5-90cc-d6ea149be071",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=array([[-1.72264425e+02, -3.20532920e+01, 6.93394701e-01],\n",
" [-2.48678096e+02, -4.06443066e+02, -2.29517361e+00],\n",
" [-1.41608541e+01, 6.58371137e+01, 2.84173879e+01],\n",
" [-1.53737623e+01, 6.55453843e+01, -2.69094910e+01],\n",
" [ 4.48643202e+02, 2.27886698e+02, -1.37703175e+02],\n",
" [ 6.63423220e+02, 4.01344208e+02, -3.88143417e+02],\n",
" [ 8.31217553e+01, 7.90119651e+02, 2.40789270e+02],\n",
" [-4.80724161e+01, -4.43541416e+01, 7.38061773e+00],\n",
" [-3.89952746e+02, -3.96431949e+02, -5.37281758e+01],\n",
" [ 4.11710161e+01, 1.08638512e+01, -5.63739210e+01],\n",
" [-4.00936776e+01, 1.73426599e+02, 2.98724464e+02],\n",
" [-1.77504486e+01, 2.76649441e+01, 6.28452536e+01],\n",
" [-8.25925159e+01, 2.01081195e+02, 1.30535198e+02],\n",
" [-4.08666997e+02, -2.61975999e+02, 3.33849425e+02],\n",
" [ 5.72445758e+01, -1.93925039e+01, -2.79065483e+02],\n",
" [ 2.16593316e+02, 4.39662705e+01, -8.95231221e+00],\n",
" [ 6.24423858e+01, -3.44132577e+02, -1.59405641e+02],\n",
" [ 1.40967719e+02, 1.05188962e+02, 6.88440585e+00],\n",
" [-2.74913367e+02, -2.02534311e+02, 1.60224135e-01],\n",
" [-3.92009457e+01, -2.84703180e+02, 9.29432744e-01],\n",
" [ 1.83047125e+01, -6.15969357e+01, -2.15378355e+01],\n",
" [ 1.87110953e+01, -6.04156914e+01, 2.27275861e+01]]), unit=kilojoule/(nanometer*mole))"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_forces_all - mm_forces"
]
},
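{
"cell_type": "markdown",
"id": "4afc0e16-92b3-44d5-86f7-192a3b4c5d6e",
"metadata": {},
"source": [
"The same additivity check for forces (a sketch, assuming an absolute tolerance of `1e-4` kJ/(mol nm) is adequate). If the force groups combined correctly, `ml_forces_all - mm_forces` would equal `ml_forces_g1`, which is nonzero only for the first 3 atoms. Given the large differences on every atom in the output above, this check would print `False`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b0d1f27-a3c4-45e6-97a8-2a3b4c5d6e7f",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Residual between the all-groups force and the per-group decomposition,\n",
"# stripped of units so numpy can compare it.\n",
"force_unit = unit.kilojoule_per_mole / unit.nanometer\n",
"residual = (ml_forces_all - mm_forces - ml_forces_g1).value_in_unit(force_unit)\n",
"print('consistent with group decomposition:', np.allclose(residual, 0.0, atol=1e-4))\n",
"print('max |residual| per atom:', np.abs(residual).max(axis=1))"
]
},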
{
"cell_type": "code",
"execution_count": null,
"id": "4697a8fd-06b6-45cc-9f6f-aacd5d6e9173",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}