# Last active July 30, 2019 18:46.
# Save gist wiederm/03b2ef6a9c8eb7e27b4211005090a9c8 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
"""Benchmark ANI-1ccx energy-evaluation time against water-box size.

Builds openmmtools ``WaterBox`` systems of increasing edge length, evaluates
each one with TorchANI's ANI-1ccx model on the CPU, and plots the wall-clock
time of the energy call against the number of atoms.
"""
from openmmtools.testsystems import DHFRExplicit
from openmmtools.testsystems import WaterBox
import matplotlib.pyplot as plt
from simtk import unit
import numpy as np
import torchani
import torch
import time
import sys

# Run the ANI model on the CPU.
platform = 'cpu'
device = torch.device(platform)
model = torchani.models.ANI1ccx()
model = model.to(device)
# Limit PyTorch to 2 intra-op CPU threads (presumably so benchmark timings are
# reproducible across machines -- confirm intent).
torch.set_num_threads(2)

# OpenMM (simtk.unit) units
mass_unit = unit.dalton
distance_unit = unit.nanometer
time_unit = unit.femtosecond
energy_unit = unit.kilojoule_per_mole
speed_unit = distance_unit / time_unit
force_unit = unit.kilojoule_per_mole / unit.nanometer

# ANI-1 units and conversion factors
ani_distance_unit = unit.angstrom
# 1 hartree expressed in kJ/mol.
hartree_to_kJ_mol = 2625.499638
# Hartree is built by hand from kJ/mol (original author's note: simtk.unit may
# not provide a hartree unit).
ani_energy_unit = hartree_to_kJ_mol * unit.kilojoule_per_mole
# Ratio of two compatible lengths. simtk.unit reduces a dimensionless quantity
# to a plain number (10.0 here), so it can scale torch tensors directly.
nm_to_angstroms = (1.0 * distance_unit) / (1.0 * ani_distance_unit)
# Per-system results collected for the size-vs-time plot.
nr_of_atoms = []
times = []
# Test different water-box edge lengths, in angstrom.
# NOTE(review): 20 A is absent from this list -- confirm that is intentional.
for n in [5, 10, 15, 25, 30, 35, 40]:
    testsystem = WaterBox(box_edge=n * unit.angstrom)
    # Element symbols in topology order.
    elements = [atom.element.symbol for atom in testsystem.topology.atoms()]
    print('Nr of atoms in the system: {}'.format(len(elements)))
    # The joined string has no separators, so this assumes every element
    # symbol is a single letter (true for O/H in a water box). The S -> C
    # substitution is carried over from the original script (presumably
    # because ANI-1ccx has no sulfur parameters); a water box has no sulfur.
    element_string = ''.join(elements).replace('S', 'C')
    species = model.species_to_tensor(element_string).to(device).unsqueeze(0)
    # Build the (1, n_atoms, 3) batch with a single array conversion rather
    # than wrapping an array in a Python list (torch's slow path).
    positions_nm = np.asarray(testsystem.positions.value_in_unit(unit.nanometer))
    # requires_grad=True makes the forward pass record an autograd graph,
    # which is part of what is timed below.
    coordinates = torch.tensor(positions_nm[np.newaxis], requires_grad=True,
                               device=device, dtype=torch.float32)
    # Time only the model call. The clock stops before the energy is printed,
    # so console I/O and tensor formatting are not counted. No cell/PBC is
    # passed to the model, so the box is evaluated without periodic images.
    t0 = time.perf_counter()
    _, energy_in_hartree = model((species, coordinates * nm_to_angstroms))
    t = time.perf_counter() - t0
    print('Energy in Hartree: {}'.format(energy_in_hartree))
    print('time: {}'.format(t))
    nr_of_atoms.append(len(elements))
    times.append(t)

# Plot evaluation time against system size; markers show the few data points.
plt.plot(nr_of_atoms, times, marker='o')
plt.title('Nr of atoms vs time')
plt.xlabel('Nr of atoms')
plt.ylabel('Time [sec]')
plt.show()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.