Skip to content

Instantly share code, notes, and snippets.

@wiederm
Last active July 30, 2019 18:46
Show Gist options
  • Save wiederm/03b2ef6a9c8eb7e27b4211005090a9c8 to your computer and use it in GitHub Desktop.
Save wiederm/03b2ef6a9c8eb7e27b4211005090a9c8 to your computer and use it in GitHub Desktop.
from openmmtools.testsystems import DHFRExplicit
from openmmtools.testsystems import WaterBox
import matplotlib.pyplot as plt
from simtk import unit
import numpy as np
import torchani
import torch
import time
from simtk import unit
import sys
# --- Compute backend: TorchANI's ANI-1ccx ensemble on the CPU ---------------
platform = 'cpu'
device = torch.device(platform)
# nn.Module.to() returns the module itself, so chaining is equivalent to
# constructing the model and then moving it in a separate statement.
model = torchani.models.ANI1ccx().to(device)
# Pin intra-op parallelism to two threads so timings are comparable.
torch.set_num_threads(2)

# --- OpenMM (simtk.unit) unit system -----------------------------------------
mass_unit, distance_unit, time_unit, energy_unit = (
    unit.dalton,
    unit.nanometer,
    unit.femtosecond,
    unit.kilojoule_per_mole,
)
speed_unit = distance_unit / time_unit
force_unit = energy_unit / distance_unit

# --- ANI-1 unit system and conversion factors ---------------------------------
ani_distance_unit = unit.angstrom
# Hartree expressed in kJ/mol; written out by hand because the author was
# unsure whether simtk.unit provides a hartree unit.
hartree_to_kJ_mol = 2625.499638
ani_energy_unit = hartree_to_kJ_mol * energy_unit
# Scale factor turning a coordinate in nanometers into one in angstroms.
nm_to_angstroms = (1.0 * distance_unit) / (1.0 * ani_distance_unit)
# Benchmark results: atom count of each system and the wall-clock seconds
# spent in one ANI energy evaluation of that system.
nr_of_atoms = []
times = []
# Water boxes of increasing edge length (in angstrom).
for n in [5, 10, 15, 25, 30, 35, 40]:
    testsystem = WaterBox(box_edge=n * unit.angstrom)
    x = testsystem.positions
    # Element symbol of every atom in topology order.
    elements = [atom.element.symbol for atom in testsystem.topology.atoms()]
    print('Nr of atoms in the system: {}'.format(len(elements)))
    # Symbols are joined with no separator, so this assumes one character per
    # element; two-letter symbols (e.g. 'Cl') would not survive the join.
    element_string = ''.join(elements)
    # Map sulfur to carbon, presumably because ANI-1ccx has no S parameters
    # (holdover from the DHFR benchmark; a no-op for pure water).
    element_string = element_string.replace('S', 'C')
    species = model.species_to_tensor(element_string).to(device).unsqueeze(0)
    # Batch of one conformation: shape (1, n_atoms, 3), in nanometers.
    # np.asarray + newaxis avoids torch's slow list-of-ndarrays conversion.
    coordinates = torch.tensor(
        np.asarray(x.value_in_unit(unit.nanometer))[np.newaxis],
        requires_grad=True, device=device, dtype=torch.float32)
    # NOTE(review): the first (smallest) system may absorb one-time PyTorch /
    # TorchANI start-up cost; consider an untimed warm-up call.
    # Time only the model call (input converted nm -> angstrom); printing
    # happens after t1 so stdout I/O is not measured.
    t0 = time.perf_counter()
    _, energy_in_hartree = model((species, coordinates * nm_to_angstroms))
    t1 = time.perf_counter()
    t = t1 - t0
    print('Energy in Hartree: {}'.format(energy_in_hartree))
    print('time: {}'.format(t))
    nr_of_atoms.append(len(elements))
    times.append(t)
# Scaling curve: one point per water box (atom count vs. seconds per call).
plt.plot(nr_of_atoms, times)
axes = plt.gca()
axes.set(title='Nr of atoms vs time',
         xlabel='Nr of atoms',
         ylabel='Time [sec]')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment