@dominicrufa
Created April 29, 2021 19:57
A simpler `TorchForce` example on alanine dipeptide in vacuum: the potential energy matches the reference system, but the force calculation fails.
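For reference, the `ForceModule` in the notebook below has energy `E = scale * sum_i |x_i|^2` over the particles in the subset, so its force is `F_i = -dE/dx_i = -2 * scale * x_i` on those particles and zero on all others. At `scale = 0` its force contribution should therefore vanish everywhere. The stdlib-only sketch below checks that analytic force against a central finite difference. The helper names and the toy positions are hypothetical (they are not the dipeptide coordinates), and nothing here calls OpenMM or openmm-torch.

```python
# Sketch: forces of the notebook's central harmonic term,
#   E = scale * sum_{i in subset} |x_i|^2   (kJ/mol, positions in nm).
# Helper names are illustrative; nothing here calls OpenMM or openmm-torch.


def harmonic_energy(positions, subset, scale):
    """Energy of the central harmonic term acting on the subset of particles."""
    return scale * sum(x * x for i in subset for x in positions[i])


def analytic_forces(positions, subset, scale):
    """F_i = -dE/dx_i = -2 * scale * x_i on the subset, zero on all other particles."""
    forces = [[0.0, 0.0, 0.0] for _ in positions]
    for i in subset:
        forces[i] = [-2.0 * scale * x for x in positions[i]]
    return forces


def fd_forces(positions, subset, scale, h=1e-6):
    """Central finite-difference estimate of -dE/dx for every coordinate."""
    forces = []
    for i in range(len(positions)):
        row = []
        for k in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][k] += h
            minus[i][k] -= h
            dE = harmonic_energy(plus, subset, scale) - harmonic_energy(minus, subset, scale)
            row.append(-dE / (2.0 * h))
        forces.append(row)
    return forces


def max_abs_diff(a, b):
    """Largest absolute element-wise difference between two (n, 3) nested lists."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Hypothetical toy positions (nm); particle 3 lies outside the subset.
positions = [[0.1, -0.2, 0.3], [0.05, 0.4, -0.1], [0.2, 0.2, 0.2], [1.0, 1.0, 1.0]]
subset = [0, 1, 2]  # the first 3 particles, as in the notebook

for scale in (0.0, 10.0):
    ana = analytic_forces(positions, subset, scale)
    num = fd_forces(positions, subset, scale)
    print(f"scale={scale}: max |analytic - finite difference| = {max_abs_diff(ana, num):.1e}")
```

At `scale = 0` every analytic force is exactly zero, which is the expectation the notebook's final cell tests; a particle outside the subset gets zero force at any `scale`.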
{
"cells": [
{
"cell_type": "markdown",
"id": "650ecdde-a4c6-4443-97f9-c5d6014b5d34",
"metadata": {},
"source": [
"Demonstrate a force-calculation problem with a `TorchForce`: energies agree with the reference system, but forces do not."
]
},
{
"cell_type": "markdown",
"id": "bf4d72bd-b848-4f39-99b4-2f6468e7c36a",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e412c2c7-f12c-40cc-a0d5-f78511aed521",
"metadata": {},
"outputs": [],
"source": [
"from openmmtools.testsystems import AlanineDipeptideVacuum\n",
"from simtk import unit, openmm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1c64d66b-a553-489e-8371-27a77ee5d589",
"metadata": {},
"outputs": [],
"source": [
"adp = AlanineDipeptideVacuum(constraints=None)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "50384e7a-b839-4940-b117-110325665868",
"metadata": {},
"outputs": [],
"source": [
"system, topology, pos = adp.system, adp.topology, adp.positions"
]
},
{
"cell_type": "markdown",
"id": "009ebae8-2da5-4ae9-9167-9cad2d48d9e2",
"metadata": {},
"source": [
"Remove the `CMMotionRemover` force (the last force in the system)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d7f669fd-0238-4d76-81aa-de3d301099a2",
"metadata": {},
"outputs": [],
"source": [
"system.removeForce(system.getNumForces() - 1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "74c64573-0561-465d-82e7-fbcb3aedf112",
"metadata": {},
"outputs": [],
"source": [
"import copy"
]
},
{
"cell_type": "markdown",
"id": "be38bc62-af8f-41f2-abda-34df815c9822",
"metadata": {},
"source": [
"Make a deep copy of the system; the `TorchForce` will be added only to the copy, and the original serves as the reference for comparison."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "689a800f-244e-4bbb-bff3-ba2069d102eb",
"metadata": {},
"outputs": [],
"source": [
"ml_system = copy.deepcopy(system)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8cd160ca-e52c-4cff-8ed8-9d7a5f7cd0e0",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchani\n",
"import openmmtorch"
]
},
{
"cell_type": "markdown",
"id": "135b4652-e4e6-4ced-9444-93d752bc06d8",
"metadata": {},
"source": [
"# `TorchForce` \n",
"Make a very simple `ForceModule`: a harmonic potential centered at the origin that acts only on a subset of atoms, `subset_indices`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "14d938f8-0182-452d-b6cc-a777ff0eca4e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"class ForceModule(torch.nn.Module):\n",
"    \"\"\"A central harmonic potential as a static compute graph\"\"\"\n",
"    def __init__(self, subset_indices):\n",
"        super().__init__()\n",
"        self.indices = torch.tensor(subset_indices, dtype=torch.int64)\n",
"\n",
"    def forward(self, positions, scale):\n",
"        \"\"\"The forward method returns the energy computed from positions.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        positions : torch.Tensor with shape (nparticles, 3)\n",
"            positions[i,k] is the position (in nanometers) of spatial dimension k of particle i\n",
"        scale : torch.Tensor\n",
"            value of the `scale` global parameter; multiplies the harmonic energy\n",
"\n",
"        Returns\n",
"        -------\n",
"        potential : torch.Tensor (0-d)\n",
"            The potential energy (in kJ/mol)\n",
"        \"\"\"\n",
"        return scale * (positions[self.indices]**2).sum()"
]
},
{
"cell_type": "markdown",
"id": "6af8dabf-7998-4fac-af3d-a1195d7ab510",
"metadata": {},
"source": [
"Make a `TorchForce` that acts on the first 3 particles. The `scale` global parameter defaults to zero, so the `TorchForce` should contribute no energy and no force."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4fe71b43-506b-445b-8d59-933348b1b84c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"indices = list(range(3))\n",
"f_gen = ForceModule(indices)\n",
"module = torch.jit.script(f_gen)\n",
"\n",
"# Serialize the compute graph to a file\n",
"save_filename = \"adp_model.pt\"\n",
"module.save(save_filename)\n",
"\n",
"# Create the TorchForce from the serialized compute graph\n",
"from openmmtorch import TorchForce\n",
"torch_force = TorchForce(save_filename)\n",
"torch_force.setForceGroup(0) #default 0th force group\n",
"torch_force.addGlobalParameter('scale', 0.)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5b08abe9-52ff-452d-b265-28dcccff2a14",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_system.addForce(torch_force)"
]
},
{
"cell_type": "markdown",
"id": "3b114b92-c996-4b6c-aa64-329d8ab94314",
"metadata": {},
"source": [
"Print the forces in each system; `ml_system` should have one extra force (the `TorchForce`, shown as a generic `Force` proxy)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b7fb422e-c1f1-4228-a096-df3a801cbf4b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7fbd66f30c30> >,\n",
" <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7fbd66f30810> >,\n",
" <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7fbd66f30360> >,\n",
" <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7fbd66f30960> >]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"system.getForces()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cac6cbec-ff2f-47e4-afe7-c375b350bfa7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x7fbd66f30a50> >,\n",
" <simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x7fbd66f309c0> >,\n",
" <simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x7fbd66f30c00> >,\n",
" <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x7fbd66f30bd0> >,\n",
" <simtk.openmm.openmm.Force; proxy of <Swig Object of type 'OpenMM::Force *' at 0x7fbd66f304e0> >]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_system.getForces()"
]
},
{
"cell_type": "markdown",
"id": "895943b8-efa5-4b9e-a8af-6feb8ce067af",
"metadata": {},
"source": [
"Make one integrator per context (an OpenMM integrator can be bound to only one `Context`)."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "82380227-a342-4e85-801a-d60650483a42",
"metadata": {},
"outputs": [],
"source": [
"from openmmtools.integrators import LangevinIntegrator\n",
"_int = LangevinIntegrator()\n",
"_ml_int = LangevinIntegrator()"
]
},
{
"cell_type": "markdown",
"id": "a382bebc-bd74-4351-a91a-cbace75aec8b",
"metadata": {},
"source": [
"Make a context for each system"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6b5635af-9858-41f2-878a-97f6b1c9656e",
"metadata": {},
"outputs": [],
"source": [
"mm_context = openmm.Context(system, _int)\n",
"ml_context = openmm.Context(ml_system, _ml_int)"
]
},
{
"cell_type": "markdown",
"id": "ef8b1ff3-c651-40e4-b655-3918ba423015",
"metadata": {},
"source": [
"Set the same positions in both contexts"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d22f7235-7b7e-4ec9-a45c-2afbd9d84430",
"metadata": {},
"outputs": [],
"source": [
"mm_context.setPositions(pos)\n",
"ml_context.setPositions(pos)"
]
},
{
"cell_type": "markdown",
"id": "bfc1033b-300e-4aaf-a727-9c27b00ec604",
"metadata": {},
"source": [
"Get the potential energies. They should be identical, since the `TorchForce` contributes no energy while `scale` is 0."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "544017c8-2e1b-48ab-9cae-9d79362c282d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=-88.08855399568873, unit=kilojoule/mole)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mm_context.getState(getEnergy=True).getPotentialEnergy()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f077bca5-fd24-46e5-adaa-fb136a19ce3a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=-88.08855399568873, unit=kilojoule/mole)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_context.getState(getEnergy=True).getPotentialEnergy()"
]
},
{
"cell_type": "markdown",
"id": "9c4881eb-df3a-4bbb-a446-6682766a0238",
"metadata": {},
"source": [
"Alright, the energies match. What about the forces?"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "48247ee5-8467-4cf0-8998-f5ba377e44f1",
"metadata": {},
"outputs": [],
"source": [
"mm_forces = mm_context.getState(getForces=True).getForces(asNumpy=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "6f364a0f-075d-4eb4-a583-84e4e8d5dd8d",
"metadata": {},
"outputs": [],
"source": [
"ml_forces = ml_context.getState(getForces=True).getForces(asNumpy=True)"
]
},
{
"cell_type": "markdown",
"id": "8ebf6252-3861-4099-bbec-874fb2fcb6bb",
"metadata": {},
"source": [
"Since the `TorchForce` contributes no force while `scale` is 0, the difference between the two force arrays should be a zero matrix, right?"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "c0deb3b9-b48f-4d01-9428-ff606d777594",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=array([[-1.87751330e+02, -6.41684342e+01, 1.86482088e-03],\n",
" [-2.00084371e+02, -4.13052296e+02, 2.83677570e-03],\n",
" [-5.44396156e+01, 2.32343436e+01, 4.86549842e+01],\n",
" [-5.44425922e+01, 2.32387553e+01, -4.86600834e+01],\n",
" [ 3.03484463e+02, 4.96798282e+02, -3.76411002e+01],\n",
" [ 3.34757006e+02, 3.37026485e+01, -6.39082298e-03],\n",
" [ 1.72092962e+02, 8.75110505e+02, 2.40789389e+02],\n",
" [ 1.61614234e+02, 5.05062095e+02, 3.20766650e+01],\n",
" [-3.86606478e+02, -3.97636949e+02, -5.37281961e+01],\n",
" [-1.41979770e+01, -9.60443248e+01, -6.50965966e+01],\n",
" [ 1.57314813e+02, 5.52322164e+02, 1.94980970e+02],\n",
" [-1.21625327e+01, 5.51552203e+01, 7.97986801e+01],\n",
" [-6.05816985e+01, 5.90779803e+01, 4.91643422e+01],\n",
" [-1.04175384e+01, 3.85627822e+01, 2.42168774e+01],\n",
" [ 2.62386815e+01, -3.09807864e+02, -2.79066010e+02],\n",
" [ 1.23115638e+02, -5.06091871e+02, -5.50782819e+01],\n",
" [-1.03733248e+02, -3.55097607e+02, -1.30408658e+02],\n",
" [ 1.30134765e+02, 8.77416694e+01, 1.02573344e-03],\n",
" [-2.80516409e+02, -2.07134668e+02, -9.20385053e-04],\n",
" [ 1.54334843e+01, -2.02740860e+02, -6.03112253e-03],\n",
" [-2.96217783e+01, -9.91151295e+01, -3.83140799e+01],\n",
" [-2.96305036e+01, -9.91166618e+01, 3.83187311e+01]]), unit=kilojoule/(nanometer*mole))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_forces - mm_forces"
]
},
{
"cell_type": "markdown",
"id": "1eb07d15-0499-4e43-8e4c-820e156d5239",
"metadata": {},
"source": [
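"Well, that isn't right: the force difference is nonzero for every particle, not just for the first 3 particles that the `TorchForce` acts on."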
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cfdcdf0f-3833-41ed-8a34-837a38796ced",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
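The difference matrix in the last cell is hard to scan by eye. A hypothetical helper, `force_mismatch` (not part of OpenMM or openmm-torch), reduces two force arrays to per-particle deviations and flags the particles above a tolerance. It assumes plain nested lists; for the notebook's `Quantity` arrays, one option is to strip units first, e.g. `ml_forces.value_in_unit(unit.kilojoule_per_mole / unit.nanometer).tolist()`, and likewise for `mm_forces`. The default `tol=1e-3` is an arbitrary choice, and the example data below is made up.

```python
import math


def force_mismatch(forces_a, forces_b, tol=1e-3):
    """Compare two per-particle force arrays given as sequences of (fx, fy, fz).

    Both arrays must be in the same units (e.g. kJ/mol/nm). Returns the largest
    per-particle deviation |F_a - F_b| and the indices of particles whose
    deviation exceeds `tol`.
    """
    if len(forces_a) != len(forces_b):
        raise ValueError("force arrays have different numbers of particles")
    # Euclidean norm of the force difference for each particle (Python 3.8+).
    deviations = [math.dist(fa, fb) for fa, fb in zip(forces_a, forces_b)]
    flagged = [i for i, d in enumerate(deviations) if d > tol]
    return max(deviations, default=0.0), flagged


# Toy data (not from the notebook): particle 2 differs by 0.5 along z.
reference = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]
candidate = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.5]]
max_dev, flagged = force_mismatch(candidate, reference)
print(max_dev, flagged)  # 0.5 [2]
```

With `scale` at 0, `flagged` would be expected to be empty for the notebook's two contexts; given the last output, it would instead list every particle.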