Skip to content

Instantly share code, notes, and snippets.

@dominicrufa
Created April 27, 2021 03:41
Show Gist options
  • Save dominicrufa/a4e1ed2188de53fb74449e4ce7bee04d to your computer and use it in GitHub Desktop.
Save dominicrufa/a4e1ed2188de53fb74449e4ce7bee04d to your computer and use it in GitHub Desktop.
reproducing error wherein adding a `TorchForce` to a system gives force matrices that are discrepant w.r.t. mm-only system when the scaling factor is set to zero.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "95a44844-eff2-4546-a480-5ba0ea749154",
"metadata": {},
"source": [
"Reproducing strange force error wherein adding a torch force to system, setting the positions, and getting the force matrix yields major discrepancies in energy."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8e095550-10d2-4166-bfd0-4f1300a68e6c",
"metadata": {},
"outputs": [],
"source": [
"import simtk\n",
"from simtk import openmm\n",
"from simtk import unit\n",
"import copy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e17c9e97-40d3-4cea-8407-4aec0b4dd782",
"metadata": {},
"outputs": [],
"source": [
"from openmmtools.testsystems import HostGuestExplicit"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "db2101fe-5884-4794-9331-b3bf30ab6553",
"metadata": {},
"outputs": [],
"source": [
"T = 300*unit.kelvin"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b01e046a-3929-4353-9316-c5ab5b146341",
"metadata": {},
"outputs": [],
"source": [
"hge = HostGuestExplicit(constraints=None)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "493d272e-581f-4e8b-bfde-c7eb779ec0d3",
"metadata": {},
"outputs": [],
"source": [
"system, positions, topology = hge.system, hge.positions, hge.topology"
]
},
{
"cell_type": "markdown",
"id": "1c84f408-70ee-4aa3-aeaf-c338c921fe81",
"metadata": {},
"source": [
"make a deepcopy of the system to yield an ml_system that will have a torch force added to it"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6ea3de76-fc92-4fbb-a7d1-8aab7a237df5",
"metadata": {},
"outputs": [],
"source": [
"ml_system = copy.deepcopy(system)"
]
},
{
"cell_type": "markdown",
"id": "db31ebee-ce60-4697-8b9e-706163ca4b14",
"metadata": {},
"source": [
"wrap a simple torchani force that has a scaling factor that is a global parameter"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a22362da-4228-4251-97b8-b7ce23bd15a2",
"metadata": {},
"outputs": [],
"source": [
"def make_torchforce(topology, \n",
" atoms, \n",
" model_name='ani2x', \n",
" save_filename = 'animodel.pt', \n",
" torch_scale_name='torch_scale', \n",
" torch_scale_default_value=0.):\n",
" \"\"\"\n",
" creates a scalable (via a global parameter) openmm.TorchForce.\n",
" Note, the openmm.TorchForce is non periodic.\n",
" \n",
" arguments\n",
" topology : openmm.Topology\n",
" topology corresponding to the openmm.System object\n",
" atoms : list(int)\n",
" list of particle indices that will be included in the TorchForce\n",
" model_name : str, default `ani2x`\n",
" the name of the model that wille build the torchforce\n",
" save_filename : str, default `animodel.pt`\n",
" torch module name to save\n",
" torch_scale_name : str, default 'torch_scale'\n",
" the name of the global parameter that scales the TorchForce\n",
" torch_scale_default_value : float, default 1.\n",
" the default value of the `torch_scale_name`\n",
" \n",
" returns \n",
" f_gen : openmm.TorchForce\n",
" the generated TorchForce\n",
" \n",
" \"\"\"\n",
" import torch\n",
" import torchani\n",
" import openmmtorch\n",
" if model_name == 'ani1ccx':\n",
" model = torchani.models.ANI1ccx()\n",
" elif model_name == 'ani2x':\n",
" model = torchani.models.ANI2x()\n",
" else:\n",
" raise Exception(f\"model name {model_name} is not currently supported\")\n",
"\n",
" \n",
" # Create the PyTorch model that will be invoked by OpenMM.\n",
" includedAtoms = list(topology.atoms())\n",
" if atoms is not None:\n",
" includedAtoms = [includedAtoms[i] for i in atoms]\n",
" elements = [atom.element.symbol for atom in includedAtoms]\n",
" print(f\"elements: {elements}\")\n",
" species = model.species_to_tensor(elements).unsqueeze(0)\n",
" print(f\"species: {species}\")\n",
" #indices = torch.tensor(atoms, dtype=torch.int64) #get the atom indices which to pull\n",
" \n",
" class ANIForce(torch.nn.Module):\n",
" def __init__(self, indices, model, species):\n",
" super().__init__()\n",
" self.energyScale = torchani.units.hartree2kjoulemol(1)\n",
" self.indices = torch.tensor(indices, dtype=torch.int64)\n",
" #self.torch_indices = torch.tensor(self.indices)\n",
" self.model = model\n",
" self.species = species\n",
" \n",
" def forward(self, positions, scale):\n",
" #print(f\"indices to query: {self.indices}\")\n",
" positions = positions.to(torch.float32) #to float\n",
" # in_positions = torch.index_select(positions, 0, self.indices) #select the appropriate indices\n",
" in_positions = positions[self.indices]\n",
" # print(f\"in_positions has the following shape: {in_positions.shape}\")\n",
" _, energy = self.model((self.species, 10.0 * in_positions.unsqueeze(0))) #get the energy\n",
" #energy = _energy.sum()\n",
" out = energy * scale * self.energyScale\n",
" # print(f\"energy: {energy}; out: {out}\")\n",
" return out\n",
" \n",
" f_gen = ANIForce(atoms, model, species)\n",
" module = torch.jit.script(f_gen)\n",
"\n",
" # Serialize the compute graph to a file\n",
" module.save(save_filename)\n",
" \n",
" # Create the TorchForce from the serialized compute graph\n",
" from openmmtorch import TorchForce\n",
" torch_force = TorchForce(save_filename)\n",
" #torch_force.setForceGroup(1) #default 0th force group\n",
" torch_force.addGlobalParameter(torch_scale_name, torch_scale_default_value)\n",
" return torch_force"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0cc5949-2edc-471a-9a84-6128f632b570",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "ead45147-a5ee-4a2d-8f04-403c4334a806",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "887456b3-d89a-498c-89db-84593f54af20",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "912fec8e-aa6a-4392-8424-8553e8caf63d",
"metadata": {},
"source": [
"make integrators"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1b8d6000-d1e2-43c8-b593-922475e0bba5",
"metadata": {},
"outputs": [],
"source": [
"from openmmtools.integrators import LangevinIntegrator\n",
"old_int = LangevinIntegrator(temperature=T)\n",
"new_int = LangevinIntegrator(temperature=T)"
]
},
{
"cell_type": "markdown",
"id": "5beb009f-5682-4bad-9f5f-ce1d4102dde8",
"metadata": {},
"source": [
"we only want to treat the ligand (the first residue index) with a torch-implemented force "
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "efdc22b1-fdd4-465b-9711-285ec2b91549",
"metadata": {},
"outputs": [],
"source": [
"_atoms = []\n",
"for res in topology.residues():\n",
" if res.index == 1:\n",
" for atom in res.atoms():\n",
" _atoms.append(atom.index)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "be3fa11f-c944-4cb7-8ccf-f251c5d212f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"elements: ['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'O', 'C', 'O', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H']\n",
"species: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0]])\n"
]
}
],
"source": [
"torchforce = make_torchforce(topology = topology, \n",
" atoms = _atoms, \n",
" model_name='ani2x', \n",
" save_filename = 'animodel.pt', \n",
" torch_scale_name='torch_scale', \n",
" torch_scale_default_value=0.)"
]
},
{
"cell_type": "markdown",
"id": "8d1d8fd4-29b3-4923-817b-f87bb1eee241",
"metadata": {},
"source": [
"add the force to the ml_system"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "63757e99-5cd2-4351-a4a4-c7b6ebb61e12",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_system.addForce(torchforce)"
]
},
{
"cell_type": "markdown",
"id": "14d76aba-6c2b-424a-84c4-a3eb12a312c3",
"metadata": {},
"source": [
"create contexts for each system, set positions"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0f8627c7-5625-4bab-855c-63abac2dfb03",
"metadata": {},
"outputs": [],
"source": [
"mm_context = openmm.Context(system, old_int)\n",
"ml_context = openmm.Context(ml_system, new_int)\n",
"\n",
"mm_context.setPositions(positions)\n",
"ml_context.setPositions(positions)"
]
},
{
"cell_type": "markdown",
"id": "77c0baeb-fd03-4b21-b08b-ed5f510dd8c4",
"metadata": {},
"source": [
"get the energies. they should be the same since the torch scaling factor is zero"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a063e98e-f73e-436b-846a-26ad323fa0a5",
"metadata": {},
"outputs": [],
"source": [
"mm_energy = mm_context.getState(getEnergy=True).getPotentialEnergy()\n",
"ml_energy = ml_context.getState(getEnergy=True).getPotentialEnergy()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "2240dbcf-e677-49b7-ad28-d5cfff3d704a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Quantity(value=-51901.500729082625, unit=kilojoule/mole),\n",
" Quantity(value=-51901.500729082625, unit=kilojoule/mole))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mm_energy, ml_energy"
]
},
{
"cell_type": "markdown",
"id": "07831865-b9d7-4605-898f-5c3d5bd82ad7",
"metadata": {},
"source": [
"the discrepancy is tolerable, i think."
]
},
{
"cell_type": "markdown",
"id": "1ff7326e-81b0-4b31-908e-a936f0ad6edc",
"metadata": {},
"source": [
"get the force matrices. the ml_forces should be identical to the mm forces"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "0dedb9a5-814d-4996-bb67-1efe3261d099",
"metadata": {},
"outputs": [],
"source": [
"mm_forces = mm_context.getState(getForces=True).getForces(asNumpy=True)\n",
"ml_forces = ml_context.getState(getForces=True).getForces(asNumpy=True)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "163f60be-ebcb-4795-bbc7-d5718346cea4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=array([[ -691.92553075, -2685.24385669, -959.2384907 ],\n",
" [ 1235.77752176, -1925.81224399, -1589.20924309],\n",
" [-1047.95250458, 871.6749267 , 2459.26792072],\n",
" ...,\n",
" [ -173.27987671, -38.89228821, -813.84509277],\n",
" [ 95.71780396, 153.92840576, 484.84619141],\n",
" [ 306.42907715, -62.71115494, 457.38729858]]), unit=kilojoule/(nanometer*mole))"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mm_forces"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "98cbaddf-cdcf-4620-b413-3426525479e4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=array([[ -77.9934845 , 91.68309021, -259.86657715],\n",
" [ 200.26818848, 212.14093018, -407.83255005],\n",
" [ -88.87184143, 433.3135376 , 50.88692474],\n",
" ...,\n",
" [-173.27993774, -38.89231873, -813.84539795],\n",
" [ 95.71783447, 153.92840576, 484.84622192],\n",
" [ 306.42907715, -62.71115875, 457.3873291 ]]), unit=kilojoule/(nanometer*mole))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_forces"
]
},
{
"cell_type": "markdown",
"id": "b8b55bda-08ec-4e1c-b2be-238542e2c746",
"metadata": {},
"source": [
"these are very discrepant! what is the component-wise discrepancy?"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "75352a45-662e-4e09-94e8-730ef5da7ef7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Quantity(value=array([[ 6.13932046e+02, 2.77692695e+03, 6.99371914e+02],\n",
" [-1.03550933e+03, 2.13795317e+03, 1.18137669e+03],\n",
" [ 9.59080663e+02, -4.38361389e+02, -2.40838100e+03],\n",
" ...,\n",
" [-6.10351562e-05, -3.05175781e-05, -3.05175781e-04],\n",
" [ 3.05175781e-05, 0.00000000e+00, 3.05175781e-05],\n",
" [ 0.00000000e+00, -3.81469727e-06, 3.05175781e-05]]), unit=kilojoule/(nanometer*mole))"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ml_forces - mm_forces"
]
},
{
"cell_type": "markdown",
"id": "f8451d06-c521-43a0-9b02-6ef1b8dca9fe",
"metadata": {},
"source": [
"this is a big problem."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42a2f3a0-5564-4a65-a4fa-a8e504739f6b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment