Created
April 29, 2021 21:39
-
-
Save dominicrufa/94397a083ad5f8265f1729bd58f40c6e to your computer and use it in GitHub Desktop.
Placing the `TorchForce` into force group 1 and querying the forces of each force group separately seems to work on a non-periodic (non-PBC) system with the `Reference` platform.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "650ecdde-a4c6-4443-97f9-c5d6014b5d34", | |
"metadata": {}, | |
"source": [ | |
"Demonstrate force problem..." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bf4d72bd-b848-4f39-99b4-2f6468e7c36a", | |
"metadata": {}, | |
"source": [ | |
"# Imports" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "e412c2c7-f12c-40cc-a0d5-f78511aed521", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from openmmtools.testsystems import AlanineDipeptideVacuum\n", | |
"from simtk import unit, openmm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "1c64d66b-a553-489e-8371-27a77ee5d589", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"adp = AlanineDipeptideVacuum(constraints=None)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "50384e7a-b839-4940-b117-110325665868", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"system, topology, pos = adp.system, adp.topology, adp.positions" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "009ebae8-2da5-4ae9-9167-9cad2d48d9e2", | |
"metadata": {}, | |
"source": [ | |
"remove the `CMMRotation` Force" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "d7f669fd-0238-4d76-81aa-de3d301099a2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"system.removeForce(system.getNumForces() - 1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "74c64573-0561-465d-82e7-fbcb3aedf112", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import copy" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "be38bc62-af8f-41f2-abda-34df815c9822", | |
"metadata": {}, | |
"source": [ | |
"make a copy of the system for comparison" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "689a800f-244e-4bbb-bff3-ba2069d102eb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ml_system = copy.deepcopy(system)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "8cd160ca-e52c-4cff-8ed8-9d7a5f7cd0e0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torchani\n", | |
"import openmmtorch" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "135b4652-e4e6-4ced-9444-93d752bc06d8", | |
"metadata": {}, | |
"source": [ | |
"# `TorchForce` \n", | |
"Make a really simple `ForceModule` that is just a harmonic force centered at zero that acts on a subset of atoms: `subset_indices`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "14d938f8-0182-452d-b6cc-a777ff0eca4e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"class ForceModule(torch.nn.Module):\n", | |
" \"\"\"A central harmonic potential as a static compute graph\"\"\"\n", | |
" def __init__(self, subset_indices):\n", | |
" super().__init__()\n", | |
" self.indices = torch.tensor(subset_indices, dtype=torch.int64)\n", | |
" \n", | |
" def forward(self, positions, scale):\n", | |
" \"\"\"The forward method returns the energy computed from positions.\n", | |
"\n", | |
" Parameters\n", | |
" ----------\n", | |
" positions : torch.Tensor with shape (nparticles,3)\n", | |
" positions[i,k] is the position (in nanometers) of spatial dimension k of particle i\n", | |
"\n", | |
" Returns\n", | |
" -------\n", | |
" potential : torch.Scalar\n", | |
" The potential energy (in kJ/mol)\n", | |
" \"\"\"\n", | |
" return scale * (positions[self.indices]**2).sum()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6af8dabf-7998-4fac-af3d-a1195d7ab510", | |
"metadata": {}, | |
"source": [ | |
"make a `Torchforce` that acts on the first 3 particles. the torch force will be scaled to zero by default." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "4fe71b43-506b-445b-8d59-933348b1b84c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"indices = list(range(3))\n", | |
"f_gen = ForceModule(indices)\n", | |
"module = torch.jit.script(f_gen)\n", | |
"\n", | |
"# Serialize the compute graph to a file\n", | |
"save_filename = f\"adp_model.pt\"\n", | |
"module.save(save_filename)\n", | |
"\n", | |
"# Create the TorchForce from the serialized compute graph\n", | |
"from openmmtorch import TorchForce\n", | |
"torch_force = TorchForce(save_filename)\n", | |
"torch_force.setForceGroup(1) #default 0th force group\n", | |
"torch_force.addGlobalParameter('scale', 1.)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "5b08abe9-52ff-452d-b265-28dcccff2a14", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"4" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_system.addForce(torch_force)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "3b114b92-c996-4b6c-aa64-329d8ab94314", | |
"metadata": {}, | |
"source": [ | |
"print the forces in each system" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "b7fb422e-c1f1-4228-a096-df3a801cbf4b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7f2ac76102a0> >,\n", | |
" <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7f2ac7610f30> >,\n", | |
" <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7f2ac7610f60> >,\n", | |
" <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7f2ac7610b40> >]" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"system.getForces()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "cac6cbec-ff2f-47e4-afe7-c375b350bfa7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7f2ac7610720> >,\n", | |
" <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7f2ac76106c0> >,\n", | |
" <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7f2b9c9d6600> >,\n", | |
" <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7f2b9c9d6360> >,\n", | |
" <simtk.openmm.openmm.Force; proxy of <Swig Object of type 'OpenMM::Force *' at 0x7f2ac93bac90> >]" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_system.getForces()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "9b74487e-b538-45dc-89f2-a1fa60436950", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0\n", | |
"0\n", | |
"0\n", | |
"0\n" | |
] | |
} | |
], | |
"source": [ | |
"for force in system.getForces():\n", | |
" print(force.getForceGroup())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "415f3620-2cd4-4eeb-a5d4-ff50cabbb85a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0\n", | |
"0\n", | |
"0\n", | |
"0\n", | |
"1\n" | |
] | |
} | |
], | |
"source": [ | |
"for force in ml_system.getForces():\n", | |
" print(force.getForceGroup())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "895943b8-efa5-4b9e-a8af-6feb8ce067af", | |
"metadata": {}, | |
"source": [ | |
"make some integrators" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "82380227-a342-4e85-801a-d60650483a42", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from openmmtools.integrators import LangevinIntegrator\n", | |
"_int = LangevinIntegrator()\n", | |
"_ml_int = LangevinIntegrator()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a382bebc-bd74-4351-a91a-cbace75aec8b", | |
"metadata": {}, | |
"source": [ | |
"make contexts" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "5bcc1fe5-ae09-4382-a76b-b1286ea758e4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"platform = openmm.Platform.getPlatformByName('Reference')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "6b5635af-9858-41f2-878a-97f6b1c9656e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mm_context = openmm.Context(system, _int, platform)\n", | |
"ml_context = openmm.Context(ml_system, _ml_int, platform)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ef8b1ff3-c651-40e4-b655-3918ba423015", | |
"metadata": {}, | |
"source": [ | |
"set (same) positions to the contexts " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "d22f7235-7b7e-4ec9-a45c-2afbd9d84430", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mm_context.setPositions(pos)\n", | |
"ml_context.setPositions(pos)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bfc1033b-300e-4aaf-a727-9c27b00ec604", | |
"metadata": {}, | |
"source": [ | |
"get potential energies. they should be the same since the `TorchForce` is contributing no energy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "544017c8-2e1b-48ab-9cae-9d79362c282d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=-88.08858855730922, unit=kilojoule/mole)" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mm_context.getState(getEnergy=True, groups={0}).getPotentialEnergy()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "f077bca5-fd24-46e5-adaa-fb136a19ce3a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=-88.08858855730922, unit=kilojoule/mole)" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_context.getState(getEnergy=True, groups={0}).getPotentialEnergy()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "ca632cc5-902f-4b48-b0a8-3e2cf2af9b50", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=0.22390250343476703, unit=kilojoule/mole)" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_context.getState(getEnergy=True, groups={1}).getPotentialEnergy()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9c4881eb-df3a-4bbb-a446-6682766a0238", | |
"metadata": {}, | |
"source": [ | |
"alright, so energies seem to match...what about forces?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "48247ee5-8467-4cf0-8998-f5ba377e44f1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mm_forces = mm_context.getState(getForces=True).getForces(asNumpy=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "ad7a1b02-1c40-441a-b840-34c6e9d1a309", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ml_forces_g0 = ml_context.getState(getForces=True, groups={0}).getForces(asNumpy=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "6f364a0f-075d-4eb4-a583-84e4e8d5dd8d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ml_forces_g1 = ml_context.getState(getForces=True, groups={1}).getForces(asNumpy=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "59dbee28-2291-468d-b0ba-0050301126ea", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=array([[0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.],\n", | |
" [0., 0., 0.]]), unit=kilojoule/(nanometer*mole))" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_forces_g0 - mm_forces" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"id": "c8877a2d-52d2-4b41-ae41-0c841e4b1488", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=array([[-4.000002e-01, -2.000000e-01, 2.600000e-07],\n", | |
" [-4.000002e-01, -4.180000e-01, -2.000000e-08],\n", | |
" [-2.972528e-01, -4.907698e-01, -1.779648e-01],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00],\n", | |
" [-0.000000e+00, -0.000000e+00, -0.000000e+00]]), unit=kilojoule/(nanometer*mole))" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_forces_g1" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "911f7b96-d9ca-49b3-8a58-f91222e764c1", | |
"metadata": {}, | |
"source": [ | |
"what happens if we get _all_ of the forces from the ml_system and compare to the mm_system?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "46589e7e-cc03-418f-8437-d9001f83f185", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ml_forces_all = ml_context.getState(getForces=True).getForces(asNumpy=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "343c9b03-06ec-4bf5-90cc-d6ea149be071", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=array([[-1.72264425e+02, -3.20532920e+01, 6.93394701e-01],\n", | |
" [-2.48678096e+02, -4.06443066e+02, -2.29517361e+00],\n", | |
" [-1.41608541e+01, 6.58371137e+01, 2.84173879e+01],\n", | |
" [-1.53737623e+01, 6.55453843e+01, -2.69094910e+01],\n", | |
" [ 4.48643202e+02, 2.27886698e+02, -1.37703175e+02],\n", | |
" [ 6.63423220e+02, 4.01344208e+02, -3.88143417e+02],\n", | |
" [ 8.31217553e+01, 7.90119651e+02, 2.40789270e+02],\n", | |
" [-4.80724161e+01, -4.43541416e+01, 7.38061773e+00],\n", | |
" [-3.89952746e+02, -3.96431949e+02, -5.37281758e+01],\n", | |
" [ 4.11710161e+01, 1.08638512e+01, -5.63739210e+01],\n", | |
" [-4.00936776e+01, 1.73426599e+02, 2.98724464e+02],\n", | |
" [-1.77504486e+01, 2.76649441e+01, 6.28452536e+01],\n", | |
" [-8.25925159e+01, 2.01081195e+02, 1.30535198e+02],\n", | |
" [-4.08666997e+02, -2.61975999e+02, 3.33849425e+02],\n", | |
" [ 5.72445758e+01, -1.93925039e+01, -2.79065483e+02],\n", | |
" [ 2.16593316e+02, 4.39662705e+01, -8.95231221e+00],\n", | |
" [ 6.24423858e+01, -3.44132577e+02, -1.59405641e+02],\n", | |
" [ 1.40967719e+02, 1.05188962e+02, 6.88440585e+00],\n", | |
" [-2.74913367e+02, -2.02534311e+02, 1.60224135e-01],\n", | |
" [-3.92009457e+01, -2.84703180e+02, 9.29432744e-01],\n", | |
" [ 1.83047125e+01, -6.15969357e+01, -2.15378355e+01],\n", | |
" [ 1.87110953e+01, -6.04156914e+01, 2.27275861e+01]]), unit=kilojoule/(nanometer*mole))" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_forces_all - mm_forces" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4697a8fd-06b6-45cc-9f6f-aacd5d6e9173", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment