Created
April 27, 2021 03:41
-
-
Save dominicrufa/a4e1ed2188de53fb74449e4ce7bee04d to your computer and use it in GitHub Desktop.
Reproducing an error wherein adding a `TorchForce` to a system gives force matrices that differ from those of the MM-only system, even when the scaling factor is set to zero.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "95a44844-eff2-4546-a480-5ba0ea749154", | |
"metadata": {}, | |
"source": [ | |
"Reproducing a strange force error: after adding a torch force (scaled to zero) to the system and setting the positions, the force matrix differs substantially from that of the MM-only system, even though the potential energies agree." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "8e095550-10d2-4166-bfd0-4f1300a68e6c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import simtk\n", | |
"from simtk import openmm\n", | |
"from simtk import unit\n", | |
"import copy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "e17c9e97-40d3-4cea-8407-4aec0b4dd782", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from openmmtools.testsystems import HostGuestExplicit" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "db2101fe-5884-4794-9331-b3bf30ab6553", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# thermostat temperature shared by both Langevin integrators\n", | |
"T = unit.Quantity(300, unit.kelvin)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "b01e046a-3929-4353-9316-c5ab5b146341", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# host-guest test system, built without constraints\n", | |
"hge = HostGuestExplicit(constraints=None)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "493d272e-581f-4e8b-bfde-c7eb779ec0d3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"system = hge.system\n", | |
"positions = hge.positions\n", | |
"topology = hge.topology" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1c84f408-70ee-4aa3-aeaf-c338c921fe81", | |
"metadata": {}, | |
"source": [ | |
"make a deepcopy of the system to yield an ml_system that will have a torch force added to it" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "6ea3de76-fc92-4fbb-a7d1-8aab7a237df5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# independent copy: the TorchForce is added to ml_system only, never to system\n", | |
"ml_system = copy.deepcopy(system)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "db31ebee-ce60-4697-8b9e-706163ca4b14", | |
"metadata": {}, | |
"source": [ | |
"wrap a simple torchani force that has a scaling factor that is a global parameter" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "a22362da-4228-4251-97b8-b7ce23bd15a2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def make_torchforce(topology,\n", | |
"                    atoms,\n", | |
"                    model_name='ani2x',\n", | |
"                    save_filename='animodel.pt',\n", | |
"                    torch_scale_name='torch_scale',\n", | |
"                    torch_scale_default_value=0.):\n", | |
"    \"\"\"\n", | |
"    creates a scalable (via a global parameter) openmmtorch.TorchForce.\n", | |
"    Note, the TorchForce is non periodic.\n", | |
"\n", | |
"    arguments\n", | |
"        topology : openmm.Topology\n", | |
"            topology corresponding to the openmm.System object\n", | |
"        atoms : list(int) or None\n", | |
"            list of particle indices that will be included in the TorchForce;\n", | |
"            if None, every atom of `topology` is included\n", | |
"        model_name : str, default `ani2x`\n", | |
"            the name of the model that will build the torchforce\n", | |
"            (supported: `ani1ccx`, `ani2x`)\n", | |
"        save_filename : str, default `animodel.pt`\n", | |
"            file to which the TorchScript module is saved and from which\n", | |
"            the TorchForce is loaded\n", | |
"        torch_scale_name : str, default 'torch_scale'\n", | |
"            the name of the global parameter that scales the TorchForce\n", | |
"        torch_scale_default_value : float, default 0.\n", | |
"            the default value of the `torch_scale_name` global parameter;\n", | |
"            0. scales the TorchForce energy to zero\n", | |
"\n", | |
"    returns\n", | |
"        torch_force : openmmtorch.TorchForce\n", | |
"            the generated TorchForce\n", | |
"\n", | |
"    raises\n", | |
"        Exception\n", | |
"            if `model_name` is not supported\n", | |
"    \"\"\"\n", | |
"    import torch\n", | |
"    import torchani\n", | |
"    from openmmtorch import TorchForce\n", | |
"\n", | |
"    if model_name == 'ani1ccx':\n", | |
"        model = torchani.models.ANI1ccx()\n", | |
"    elif model_name == 'ani2x':\n", | |
"        model = torchani.models.ANI2x()\n", | |
"    else:\n", | |
"        raise Exception(f\"model name {model_name} is not currently supported\")\n", | |
"\n", | |
"    # Resolve the particle indices handled by the model. `atoms=None` means all\n", | |
"    # atoms; the indices are materialized because ANIForce gathers positions with them.\n", | |
"    all_atoms = list(topology.atoms())\n", | |
"    if atoms is None:\n", | |
"        atoms = list(range(len(all_atoms)))\n", | |
"    included_atoms = [all_atoms[i] for i in atoms]\n", | |
"    elements = [atom.element.symbol for atom in included_atoms]\n", | |
"    print(f\"elements: {elements}\")\n", | |
"    species = model.species_to_tensor(elements).unsqueeze(0)\n", | |
"    print(f\"species: {species}\")\n", | |
"\n", | |
"    class ANIForce(torch.nn.Module):\n", | |
"        \"\"\"Scaled ANI energy (kJ/mol) of the selected particles.\"\"\"\n", | |
"\n", | |
"        def __init__(self, indices, model, species):\n", | |
"            super().__init__()\n", | |
"            # torchani energies are in Hartree; OpenMM expects kJ/mol\n", | |
"            self.energyScale = torchani.units.hartree2kjoulemol(1)\n", | |
"            self.indices = torch.tensor(indices, dtype=torch.int64)\n", | |
"            self.model = model\n", | |
"            self.species = species\n", | |
"\n", | |
"        def forward(self, positions, scale):\n", | |
"            # `scale` receives the value of the `torch_scale_name` global parameter\n", | |
"            positions = positions.to(torch.float32)\n", | |
"            in_positions = positions[self.indices]\n", | |
"            # 10.0 converts OpenMM nanometers to the Angstrom used by torchani\n", | |
"            _, energy = self.model((self.species, 10.0 * in_positions.unsqueeze(0)))\n", | |
"            return energy * scale * self.energyScale\n", | |
"\n", | |
"    module = torch.jit.script(ANIForce(atoms, model, species))\n", | |
"\n", | |
"    # Serialize the compute graph to a file, then build the TorchForce from it\n", | |
"    module.save(save_filename)\n", | |
"    torch_force = TorchForce(save_filename)\n", | |
"    torch_force.addGlobalParameter(torch_scale_name, torch_scale_default_value)\n", | |
"    return torch_force" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b0cc5949-2edc-471a-9a84-6128f632b570", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ead45147-a5ee-4a2d-8f04-403c4334a806", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "887456b3-d89a-498c-89db-84593f54af20", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "912fec8e-aa6a-4392-8424-8553e8caf63d", | |
"metadata": {}, | |
"source": [ | |
"make integrators" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "1b8d6000-d1e2-43c8-b593-922475e0bba5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from openmmtools.integrators import LangevinIntegrator\n", | |
"\n", | |
"# one Langevin integrator (at temperature T) for each of the two Contexts\n", | |
"old_int, new_int = (LangevinIntegrator(temperature=T) for _ in range(2))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "5beb009f-5682-4bad-9f5f-ce1d4102dde8", | |
"metadata": {}, | |
"source": [ | |
"we only want to treat the ligand (the residue with index 1, i.e. the second residue) with a torch-implemented force" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "efdc22b1-fdd4-465b-9711-285ec2b91549", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# indices of all atoms in the residue whose index is 1 (the ligand)\n", | |
"_atoms = [atom.index\n", | |
"          for res in topology.residues() if res.index == 1\n", | |
"          for atom in res.atoms()]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "be3fa11f-c944-4cb7-8ccf-f251c5d212f9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"elements: ['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'O', 'C', 'O', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H']\n", | |
"species: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |
" 0, 0, 0, 0, 0, 0]])\n" | |
] | |
} | |
], | |
"source": [ | |
"# torch_scale_default_value=0. scales the TorchForce energy to zero\n", | |
"torchforce = make_torchforce(\n", | |
"    topology,\n", | |
"    _atoms,\n", | |
"    model_name='ani2x',\n", | |
"    save_filename='animodel.pt',\n", | |
"    torch_scale_name='torch_scale',\n", | |
"    torch_scale_default_value=0.,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8d1d8fd4-29b3-4923-817b-f87bb1eee241", | |
"metadata": {}, | |
"source": [ | |
"add the force to the ml_system" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "63757e99-5cd2-4351-a4a4-c7b6ebb61e12", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"5" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_system.addForce(torchforce)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "14d76aba-6c2b-424a-84c4-a3eb12a312c3", | |
"metadata": {}, | |
"source": [ | |
"create contexts for each system, set positions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "0f8627c7-5625-4bab-855c-63abac2dfb03", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# reference Context: MM-only system\n", | |
"mm_context = openmm.Context(system, old_int)\n", | |
"mm_context.setPositions(positions)\n", | |
"\n", | |
"# test Context: the same MM system plus the TorchForce (scale 0)\n", | |
"ml_context = openmm.Context(ml_system, new_int)\n", | |
"ml_context.setPositions(positions)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "77c0baeb-fd03-4b21-b08b-ed5f510dd8c4", | |
"metadata": {}, | |
"source": [ | |
"get the energies. they should be the same since the torch scaling factor is zero" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "a063e98e-f73e-436b-846a-26ad323fa0a5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mm_energy, ml_energy = (\n", | |
"    context.getState(getEnergy=True).getPotentialEnergy()\n", | |
"    for context in (mm_context, ml_context)\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "2240dbcf-e677-49b7-ad28-d5cfff3d704a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Quantity(value=-51901.500729082625, unit=kilojoule/mole),\n", | |
" Quantity(value=-51901.500729082625, unit=kilojoule/mole))" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mm_energy, ml_energy" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "07831865-b9d7-4605-898f-5c3d5bd82ad7", | |
"metadata": {}, | |
"source": [ | |
"the energies agree (identical at the printed precision), so any discrepancy here is tolerable." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1ff7326e-81b0-4b31-908e-a936f0ad6edc", | |
"metadata": {}, | |
"source": [ | |
"get the force matrices. the ml_forces should be identical to the mm forces" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "0dedb9a5-814d-4996-bb67-1efe3261d099", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mm_forces, ml_forces = (\n", | |
"    context.getState(getForces=True).getForces(asNumpy=True)\n", | |
"    for context in (mm_context, ml_context)\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "163f60be-ebcb-4795-bbc7-d5718346cea4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=array([[ -691.92553075, -2685.24385669, -959.2384907 ],\n", | |
" [ 1235.77752176, -1925.81224399, -1589.20924309],\n", | |
" [-1047.95250458, 871.6749267 , 2459.26792072],\n", | |
" ...,\n", | |
" [ -173.27987671, -38.89228821, -813.84509277],\n", | |
" [ 95.71780396, 153.92840576, 484.84619141],\n", | |
" [ 306.42907715, -62.71115494, 457.38729858]]), unit=kilojoule/(nanometer*mole))" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mm_forces" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "98cbaddf-cdcf-4620-b413-3426525479e4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=array([[ -77.9934845 , 91.68309021, -259.86657715],\n", | |
" [ 200.26818848, 212.14093018, -407.83255005],\n", | |
" [ -88.87184143, 433.3135376 , 50.88692474],\n", | |
" ...,\n", | |
" [-173.27993774, -38.89231873, -813.84539795],\n", | |
" [ 95.71783447, 153.92840576, 484.84622192],\n", | |
" [ 306.42907715, -62.71115875, 457.3873291 ]]), unit=kilojoule/(nanometer*mole))" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_forces" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b8b55bda-08ec-4e1c-b2be-238542e2c746", | |
"metadata": {}, | |
"source": [ | |
"these are very discrepant! what is the component-wise discrepancy?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "75352a45-662e-4e09-94e8-730ef5da7ef7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Quantity(value=array([[ 6.13932046e+02, 2.77692695e+03, 6.99371914e+02],\n", | |
" [-1.03550933e+03, 2.13795317e+03, 1.18137669e+03],\n", | |
" [ 9.59080663e+02, -4.38361389e+02, -2.40838100e+03],\n", | |
" ...,\n", | |
" [-6.10351562e-05, -3.05175781e-05, -3.05175781e-04],\n", | |
" [ 3.05175781e-05, 0.00000000e+00, 3.05175781e-05],\n", | |
" [ 0.00000000e+00, -3.81469727e-06, 3.05175781e-05]]), unit=kilojoule/(nanometer*mole))" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ml_forces - mm_forces" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "f8451d06-c521-43a0-9b02-6ef1b8dca9fe", | |
"metadata": {}, | |
"source": [ | |
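"this is a big problem: in the rows shown, the first atoms differ by up to ~2.8e3 kJ/(nm*mol), while the last atoms agree to within ~3e-4 kJ/(nm*mol)." | |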
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "42a2f3a0-5564-4a65-a4fa-a8e504739f6b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment