Created
April 29, 2021 19:57
-
-
Save dominicrufa/93997eb77c365022712e9f1beabcaddf to your computer and use it in GitHub Desktop.
A minimal `TorchForce` example on vacuum alanine dipeptide that produces incorrect forces.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "650ecdde-a4c6-4443-97f9-c5d6014b5d34", | |
"metadata": {}, | |
"source": [ | |
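"Demonstrate a force problem: adding an `openmmtorch.TorchForce` whose `scale` global parameter is zero leaves the potential energy of vacuum alanine dipeptide unchanged, but changes the forces." | |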
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bf4d72bd-b848-4f39-99b4-2f6468e7c36a", | |
"metadata": {}, | |
"source": [ | |
"# Imports" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "e412c2c7-f12c-40cc-a0d5-f78511aed521", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from openmmtools.testsystems import AlanineDipeptideVacuum\n", | |
"from simtk import unit, openmm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "1c64d66b-a553-489e-8371-27a77ee5d589", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"adp = AlanineDipeptideVacuum(constraints=None)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "50384e7a-b839-4940-b117-110325665868", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"system, topology, pos = adp.system, adp.topology, adp.positions" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "009ebae8-2da5-4ae9-9167-9cad2d48d9e2", | |
"metadata": {}, | |
"source": [ | |
"Remove the `CMMotionRemover` force" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "d7f669fd-0238-4d76-81aa-de3d301099a2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Remove the CMMotionRemover by type rather than by position, so no other force\n", | |
"# (e.g. NonbondedForce) is dropped by accident if the force order differs;\n", | |
"# iterate in reverse so the remaining indices stay valid after each removal.\n", | |
"cmm_indices = [index for index, force in enumerate(system.getForces())\n", | |
"               if isinstance(force, openmm.CMMotionRemover)]\n", | |
"for index in reversed(cmm_indices):\n", | |
"    system.removeForce(index)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "74c64573-0561-465d-82e7-fbcb3aedf112", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import copy" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "be38bc62-af8f-41f2-abda-34df815c9822", | |
"metadata": {}, | |
"source": [ | |
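"Make a copy of the system for comparison: the copy (`ml_system`) gets the `TorchForce`, while the original `system` serves as the reference." | |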
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "689a800f-244e-4bbb-bff3-ba2069d102eb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ml_system = copy.deepcopy(system)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "8cd160ca-e52c-4cff-8ed8-9d7a5f7cd0e0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torchani\n", | |
"import openmmtorch" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "135b4652-e4e6-4ced-9444-93d752bc06d8", | |
"metadata": {}, | |
"source": [ | |
"# `TorchForce` \n", | |
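"Make a minimal `ForceModule`: a harmonic potential centered at the origin, $U = s \\sum_{i \\in S} \\lVert \\mathbf{x}_i \\rVert^2$, that acts only on the subset of atoms $S$ given by `subset_indices`. The prefactor $s$ is the `scale` argument of `forward`." | |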
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "14d938f8-0182-452d-b6cc-a777ff0eca4e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"class ForceModule(torch.nn.Module):\n", | |
"    \"\"\"A central harmonic potential, U = scale * sum_i |x_i|^2, on a subset of particles.\n", | |
"\n", | |
"    Intended to be compiled with ``torch.jit.script``, saved, and loaded by an\n", | |
"    ``openmmtorch.TorchForce``.\n", | |
"    \"\"\"\n", | |
"    def __init__(self, subset_indices):\n", | |
"        \"\"\"\n", | |
"        Parameters\n", | |
"        ----------\n", | |
"        subset_indices : sequence of int\n", | |
"            Indices of the particles the potential acts on.\n", | |
"        \"\"\"\n", | |
"        super().__init__()\n", | |
"        # Register as a buffer (not a plain attribute) so the index tensor follows the\n", | |
"        # module through `.to(device)` and shows up in `state_dict()`.\n", | |
"        self.register_buffer(\"indices\", torch.tensor(subset_indices, dtype=torch.int64))\n", | |
"\n", | |
"    def forward(self, positions, scale):\n", | |
"        \"\"\"The forward method returns the energy computed from positions.\n", | |
"\n", | |
"        Parameters\n", | |
"        ----------\n", | |
"        positions : torch.Tensor with shape (nparticles,3)\n", | |
"            positions[i,k] is the position (in nanometers) of spatial dimension k of particle i\n", | |
"        scale : torch.Tensor\n", | |
"            Scalar prefactor of the potential (in kJ/mol/nm^2). In this notebook it is\n", | |
"            expected to come from the ``TorchForce`` global parameter ``scale``.\n", | |
"\n", | |
"        Returns\n", | |
"        -------\n", | |
"        potential : torch.Tensor\n", | |
"            0-d tensor holding the potential energy (in kJ/mol)\n", | |
"        \"\"\"\n", | |
"        return scale * (positions[self.indices]**2).sum()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6af8dabf-7998-4fac-af3d-a1195d7ab510", | |
"metadata": {}, | |
"source": [ | |
"Make a `TorchForce` that acts on the first 3 particles. Its global parameter `scale` defaults to zero, so the force should contribute no energy and no forces." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "4fe71b43-506b-445b-8d59-933348b1b84c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"indices = list(range(3))  # the first 3 particles\n", | |
"f_gen = ForceModule(indices)\n", | |
"module = torch.jit.script(f_gen)\n", | |
"\n", | |
"# Serialize the compute graph to a file; TorchForce loads the model from this path\n", | |
"save_filename = \"adp_model.pt\"\n", | |
"module.save(save_filename)\n", | |
"\n", | |
"# Create the TorchForce from the serialized compute graph\n", | |
"torch_force = openmmtorch.TorchForce(save_filename)\n", | |
"torch_force.setForceGroup(0)  # explicit, although 0 is already the default force group\n", | |
"# Global parameter meant to feed the `scale` argument of ForceModule.forward;\n", | |
"# a default of 0 switches the potential off.\n", | |
"torch_force.addGlobalParameter('scale', 0.)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "5b08abe9-52ff-452d-b265-28dcccff2a14", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"4" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_system.addForce(torch_force)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "3b114b92-c996-4b6c-aa64-329d8ab94314", | |
"metadata": {}, | |
"source": [ | |
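"List the forces in each system. The ML system has one extra force, the `TorchForce`, whose SWIG proxy is shown as a generic `Force`." | |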
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "b7fb422e-c1f1-4228-a096-df3a801cbf4b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7fbd66f30c30> >,\n", | |
" <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7fbd66f30810> >,\n", | |
" <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7fbd66f30360> >,\n", | |
" <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7fbd66f30960> >]" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"system.getForces()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "cac6cbec-ff2f-47e4-afe7-c375b350bfa7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7fbd66f30a50> >,\n", | |
" <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7fbd66f309c0> >,\n", | |
" <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7fbd66f30c00> >,\n", | |
" <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7fbd66f30bd0> >,\n", | |
" <simtk.openmm.openmm.Force; proxy of <Swig Object of type 'OpenMM::Force *' at 0x7fbd66f304e0> >]" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_system.getForces()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "895943b8-efa5-4b9e-a8af-6feb8ce067af", | |
"metadata": {}, | |
"source": [ | |
"Make one integrator per context." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "82380227-a342-4e85-801a-d60650483a42", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from openmmtools.integrators import LangevinIntegrator\n", | |
"\n", | |
"# Two separate instances: an OpenMM Integrator can be bound to only one Context.\n", | |
"_int, _ml_int = LangevinIntegrator(), LangevinIntegrator()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a382bebc-bd74-4351-a91a-cbace75aec8b", | |
"metadata": {}, | |
"source": [ | |
"make contexts" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "6b5635af-9858-41f2-878a-97f6b1c9656e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mm_context = openmm.Context(system, _int)\n", | |
"ml_context = openmm.Context(ml_system, _ml_int)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ef8b1ff3-c651-40e4-b655-3918ba423015", | |
"metadata": {}, | |
"source": [ | |
"Set the same positions in both contexts." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "d22f7235-7b7e-4ec9-a45c-2afbd9d84430", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mm_context.setPositions(pos)\n", | |
"ml_context.setPositions(pos)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bfc1033b-300e-4aaf-a727-9c27b00ec604", | |
"metadata": {}, | |
"source": [ | |
"Get the potential energies. They should be identical, since the `TorchForce` is scaled to zero and contributes no energy." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "544017c8-2e1b-48ab-9cae-9d79362c282d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=-88.08855399568873, unit=kilojoule/mole)" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mm_context.getState(getEnergy=True).getPotentialEnergy()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "f077bca5-fd24-46e5-adaa-fb136a19ce3a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=-88.08855399568873, unit=kilojoule/mole)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_context.getState(getEnergy=True).getPotentialEnergy()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9c4881eb-df3a-4bbb-a446-6682766a0238", | |
"metadata": {}, | |
"source": [ | |
"The energies match exactly. What about the forces?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "48247ee5-8467-4cf0-8998-f5ba377e44f1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mm_forces = mm_context.getState(getForces=True).getForces(asNumpy=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "6f364a0f-075d-4eb4-a583-84e4e8d5dd8d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ml_forces = ml_context.getState(getForces=True).getForces(asNumpy=True)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8ebf6252-3861-4099-bbec-874fb2fcb6bb", | |
"metadata": {}, | |
"source": [ | |
"The difference between the two force arrays should be a zero matrix, right?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "c0deb3b9-b48f-4d01-9428-ff606d777594", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=array([[-1.87751330e+02, -6.41684342e+01, 1.86482088e-03],\n", | |
" [-2.00084371e+02, -4.13052296e+02, 2.83677570e-03],\n", | |
" [-5.44396156e+01, 2.32343436e+01, 4.86549842e+01],\n", | |
" [-5.44425922e+01, 2.32387553e+01, -4.86600834e+01],\n", | |
" [ 3.03484463e+02, 4.96798282e+02, -3.76411002e+01],\n", | |
" [ 3.34757006e+02, 3.37026485e+01, -6.39082298e-03],\n", | |
" [ 1.72092962e+02, 8.75110505e+02, 2.40789389e+02],\n", | |
" [ 1.61614234e+02, 5.05062095e+02, 3.20766650e+01],\n", | |
" [-3.86606478e+02, -3.97636949e+02, -5.37281961e+01],\n", | |
" [-1.41979770e+01, -9.60443248e+01, -6.50965966e+01],\n", | |
" [ 1.57314813e+02, 5.52322164e+02, 1.94980970e+02],\n", | |
" [-1.21625327e+01, 5.51552203e+01, 7.97986801e+01],\n", | |
" [-6.05816985e+01, 5.90779803e+01, 4.91643422e+01],\n", | |
" [-1.04175384e+01, 3.85627822e+01, 2.42168774e+01],\n", | |
" [ 2.62386815e+01, -3.09807864e+02, -2.79066010e+02],\n", | |
" [ 1.23115638e+02, -5.06091871e+02, -5.50782819e+01],\n", | |
" [-1.03733248e+02, -3.55097607e+02, -1.30408658e+02],\n", | |
" [ 1.30134765e+02, 8.77416694e+01, 1.02573344e-03],\n", | |
" [-2.80516409e+02, -2.07134668e+02, -9.20385053e-04],\n", | |
" [ 1.54334843e+01, -2.02740860e+02, -6.03112253e-03],\n", | |
" [-2.96217783e+01, -9.91151295e+01, -3.83140799e+01],\n", | |
" [-2.96305036e+01, -9.91166618e+01, 3.83187311e+01]]), unit=kilojoule/(nanometer*mole))" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_forces - mm_forces" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1eb07d15-0499-4e43-8e4c-820e156d5239", | |
"metadata": {}, | |
"source": [ | |
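"Well, that isn't right: the difference is non-zero for every atom, not just the three atoms the `TorchForce` acts on, even though the `TorchForce` is scaled to zero and the energies agree." | |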
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "cfdcdf0f-3833-41ed-8a34-837a38796ced", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment