a working toolbox for retraining espaloma against free energy data
import os, sys
import numpy as np
import random
import glob
import torch
import espaloma as esp
import dgl
import logging
import typing

logger = logging.getLogger(__name__)
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
# Constants
HARTREE_TO_KCALPERMOL = 627.5
RANDOM_SEED = 2666
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.8, 0.1, 0.1
# openmm stuff
import openmm
from openmm import unit as u
from openmmtools.utils import get_fastest_platform
from openmmtools.constants import kB

PLATFORM = get_fastest_platform(minimum_precision='double')  # or mixed?
TEMPERATURE = 300. * u.kelvin
FRICTION_COEFF = 1. / u.picoseconds
STEPSIZE = 4. * u.femtoseconds  # too aggressive?
HARTREE_TO_REDUCED_UNITS = HARTREE_TO_KCALPERMOL * u.kilocalorie_per_mole / (kB * TEMPERATURE)
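# Sanity check on the conversion above: at 300 K, kB*T ~= 0.596 kcal/mol, so one hartree
# (627.5 kcal/mol) is roughly 627.5 / 0.596 ~= 1.05e3 kT. Because openmm.unit collapses
# fully-cancelled units, HARTREE_TO_REDUCED_UNITS should reduce to a plain dimensionless
# float that can scale torch tensors directly.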
def extract_electrostatic_potential(context: openmm.Context,
                                    indices_to_query: typing.Iterable[int]) -> np.ndarray:
    """Return the electrostatic potential at each site in `indices_to_query`."""
    sys = context.getSystem()
    nbf = [f for f in sys.getForces() if f.__class__.__name__ == 'NonbondedForce'][0]
    pe = context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(u.md_unit_system)
    out_elec_p = []
    for _idx in indices_to_query:
        _c, _s, _e = nbf.getParticleParameters(_idx)  # query particle params
        nbf.setParticleParameters(_idx, _c * 0., _s, _e)  # zero the charge
        nbf.updateParametersInContext(context)  # push modified params to the context
        mod_pe = context.getState(getEnergy=True).getPotentialEnergy().value_in_unit_system(u.md_unit_system)
        # finite-difference estimate: phi_i = (U(q_i) - U(0)) / q_i
        elec_potential = (pe - mod_pe) / _c.value_in_unit_system(u.md_unit_system)
        out_elec_p.append(elec_potential)
        nbf.setParticleParameters(_idx, _c, _s, _e)  # restore the original params
        nbf.updateParametersInContext(context)  # and update the context accordingly
    return np.array(out_elec_p)
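# Example usage (hypothetical `context` and atom indices; assumes the System contains a
# NonbondedForce and that the queried sites carry nonzero charge, since the estimate
# divides by the site charge):
#   phi = extract_electrostatic_potential(context, indices_to_query=[0, 1, 2])
#   # phi[i] is the potential at site i in md_unit_system (kJ/mol per elementary charge)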
def correction_extract_sampler(
        system: openmm.System,
        init_positions: u.Quantity,  # positions with length units (e.g. nanometers)
        indices_to_query: typing.Iterable[int],
        num_iters: int = 1000,
        steps_per_iter: int = 2500,
        integrator_kwargs: typing.Dict[str, typing.Any] = {'temperature': TEMPERATURE,
                                                           'frictionCoeff': FRICTION_COEFF,
                                                           'stepSize': STEPSIZE},
        minimizer_kwargs: typing.Dict[str, typing.Any] = {'tolerance': 10, 'maxIterations': 1000},  # sufficient, idk
        ) -> typing.Dict[str, np.ndarray]:
    """A simple minimizer/Langevin (BAOAB) sampler that runs dynamics on an `openmm.System` and collects:
    1. the potential energy of each snapshot,
    2. the positions of `indices_to_query`,
    3. the electrostatic potential at each of the `indices_to_query`.
    # NOTES
    1. deterministic forces are hardcoded into the `openmm.Context`
    A dictionary containing the above is returned; all values are in `openmm.unit.md_unit_system`.
    """
    from openmm import LangevinMiddleIntegrator, LocalEnergyMinimizer
    integrator = LangevinMiddleIntegrator(**integrator_kwargs)
    # platform properties are string-valued, hence 'true' rather than True
    context = openmm.Context(system, integrator, PLATFORM, {'DeterministicForces': 'true'})
    context.setPositions(init_positions)
    # minimize
    LocalEnergyMinimizer.minimize(context, **minimizer_kwargs)
    # thermalize velocities
    _temp = integrator_kwargs.get('temperature', TEMPERATURE)
    context.setVelocitiesToTemperature(_temp)
    # collectors
    pes, posits, elec_pot = [], [], []
    for _iter in range(num_iters):
        integrator.step(steps_per_iter)
        state = context.getState(getPositions=True, getEnergy=True)
        pes.append(state.getPotentialEnergy().value_in_unit_system(u.md_unit_system))
        posits.append(state.getPositions(asNumpy=True).value_in_unit_system(u.md_unit_system)[indices_to_query, :])
        elec_pot.append(extract_electrostatic_potential(context, indices_to_query))
    del context
    return {'potential_energy': np.array(pes), 'positions': np.array(posits), 'electrostatic_potential': np.array(elec_pot)}
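# Example usage (hypothetical `system`, `positions`, and `ligand_atom_indices`):
#   out = correction_extract_sampler(system, positions, ligand_atom_indices,
#                                    num_iters=10, steps_per_iter=2500)
#   out['potential_energy'].shape          # (10,)
#   out['positions'].shape                 # (10, len(ligand_atom_indices), 3)
#   out['electrostatic_potential'].shape   # (10, len(ligand_atom_indices))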
def _load_datasets(
        datasets: typing.Iterable[str],
        input_prefix: str) -> esp.data.dataset.GraphDataset:
    """
    Load unique molecules (nonisomeric smiles) and merge their training splits.
    """
    logging.debug("# LOAD UNIQUE MOLECULES")
    for i, dataset in enumerate(datasets):
        path = os.path.join(input_prefix, dataset)
        if dataset == "rna-nucleoside":
            # rna-nucleoside only contains 4 entries, so it is used entirely for training.
            # TODO: consider a kwarg to handle this differently if it should be excluded.
            _ds_tr = esp.data.dataset.GraphDataset.load(path)
        else:
            ds = esp.data.dataset.GraphDataset.load(path).shuffle(RANDOM_SEED)
            _ds_tr, _ds_vl, _ds_te = ds.split([TRAIN_RATIO, VAL_RATIO, TEST_RATIO])
            del _ds_vl, _ds_te  # only the training split is used here
        # Merge training splits across datasets
        if i == 0:
            ds_tr = _ds_tr
        else:
            ds_tr += _ds_tr
        logging.debug(f"{dataset}: {len(_ds_tr)} entries (total: {len(ds_tr)})")
        del _ds_tr
    return ds_tr
def _load_duplicate_datasets(ds_tr, input_prefix):
    """
    Load duplicated molecules (isomeric smiles) from different datasets
    to avoid overlapping molecules across the train, validate, and test datasets.
    """
    entries = glob.glob(os.path.join(input_prefix, "duplicated-isomeric-smiles-merge", "*"))
    random.seed(RANDOM_SEED)
    random.shuffle(entries)
    n_entries = len(entries)
    entries_tr = entries[:int(n_entries*TRAIN_RATIO)]
    entries_vl = entries[int(n_entries*TRAIN_RATIO):int(n_entries*TRAIN_RATIO)+int(n_entries*VAL_RATIO)]
    entries_te = entries[int(n_entries*TRAIN_RATIO)+int(n_entries*VAL_RATIO):]
    logging.debug(f"Found {n_entries} entries. Split data into {len(entries_tr)}:{len(entries_vl)}:{len(entries_te)} entries.")
    assert n_entries == len(entries_tr) + len(entries_vl) + len(entries_te)
    for entry in entries_tr:
        _datasets = os.listdir(entry)
        for _dataset in _datasets:
            _ds_tr = esp.data.dataset.GraphDataset.load(os.path.join(entry, _dataset))
            ds_tr += _ds_tr
            del _ds_tr
    return ds_tr
def _fn(g):
    """
    Remove unnecessary per-snapshot data from the graph (modified in place and returned)
    to reduce the memory footprint.
    """
    g.nodes['g'].data.pop('u_qm')
    g.nodes['g'].data.pop('u_gaff-1.81')
    g.nodes['g'].data.pop('u_gaff-2.11')
    g.nodes['g'].data.pop('u_openff-1.2.0')
    g.nodes['g'].data.pop('u_openff-2.0.0')
    g.nodes['n1'].data.pop('u_qm_prime')
    g.nodes['n1'].data.pop('u_gaff-1.81_prime')
    g.nodes['n1'].data.pop('u_gaff-2.11_prime')
    g.nodes['n1'].data.pop('u_openff-1.2.0_prime')
    g.nodes['n1'].data.pop('u_openff-2.0.0_prime')
    try:
        # not every dataset carries amber14 energies, so pop these defensively
        g.nodes['g'].data.pop('u_amber14')
        g.nodes['n1'].data.pop('u_amber14_prime')
    except KeyError:
        pass
    # Remove u_ref_relative; it will be recalculated after handling heterographs of different sizes.
    g.nodes['g'].data.pop('u_ref_relative')
    g.nodes['g'].data['u_ref'] = g.nodes['g'].data['u_ref'].double()
    g.nodes['n1'].data['q_ref'] = g.nodes['n1'].data['q_ref'].float()
    return g
def _augment_conformations(ds_tr, n_max_confs):
    """
    Augment conformations to handle heterographs.
    This is a workaround for graphs of different size (shape): DGL requires at least one
    dimension of the same size, so we modify the graphs such that each graph carries the same
    number of conformations instead of concatenating graphs into heterogeneous graphs with
    differing numbers of conformations. This allows batching and shuffling during training.
    """
    import copy
    _ds_tr = []
    for i, g in enumerate(ds_tr):
        n = g.nodes['n1'].data['xyz'].shape[1]
        #logging.debug(f">{i}: {n} conformations")
        if n == n_max_confs:
            # Calculate u_ref_relative
            g.nodes['g'].data['u_ref_relative'] = g.nodes['g'].data['u_ref'].detach().clone()
            g.nodes['g'].data['u_ref_relative'] = g.nodes['g'].data['u_ref_relative'] - g.nodes['g'].data['u_ref_relative'].mean(dim=-1, keepdims=True)
            g.nodes['g'].data['u_ref_relative'] = g.nodes['g'].data['u_ref_relative'].float()
            g.nodes['g'].data.pop('u_ref')
            _ds_tr.append(g.heterograph)
        elif n < n_max_confs:
            # Pad with randomly resampled conformations up to n_max_confs
            random.seed(RANDOM_SEED)
            index = random.choices(range(0, n), k=n_max_confs-n)
            #logging.debug(f"Randomly select {len(index)} conformers")
            _g = copy.deepcopy(g)
            a = torch.cat((_g.nodes['g'].data['u_ref'], _g.nodes['g'].data['u_ref'][:, index]), dim=-1)
            b = torch.cat((_g.nodes['n1'].data['xyz'], _g.nodes['n1'].data['xyz'][:, index, :]), dim=1)
            c = torch.cat((_g.nodes['n1'].data['u_ref_prime'], _g.nodes['n1'].data['u_ref_prime'][:, index, :]), dim=1)
            # Update in place
            _g.nodes["g"].data["u_ref"] = a
            _g.nodes["n1"].data["xyz"] = b
            _g.nodes['n1'].data['u_ref_prime'] = c
            # Calculate u_ref_relative
            _g.nodes['g'].data['u_ref_relative'] = _g.nodes['g'].data['u_ref'].detach().clone()
            _g.nodes['g'].data['u_ref_relative'] = _g.nodes['g'].data['u_ref_relative'] - _g.nodes['g'].data['u_ref_relative'].mean(dim=-1, keepdims=True)
            _g.nodes['g'].data['u_ref_relative'] = _g.nodes['g'].data['u_ref_relative'].float()
            _g.nodes['g'].data.pop('u_ref')
            _ds_tr.append(_g.heterograph)
        else:
            # Split into chunks of n_max_confs, padding the last chunk with resampled conformations
            random.seed(RANDOM_SEED)
            idx_range = random.sample(range(n), k=n)
            for j in range(n // n_max_confs + 1):
                _g = copy.deepcopy(g)
                if (j+1)*n_max_confs > n:
                    # use the shuffled idx_range for the tail as well, so every conformation
                    # appears exactly once across the chunks
                    _index = idx_range[j*n_max_confs:]
                    random.seed(RANDOM_SEED)
                    index = random.choices(range(0, n), k=(j+1)*n_max_confs-n)
                    #logging.debug(f"Iteration {j}: Randomly select {len(index)} conformers")
                    a = torch.cat((_g.nodes['g'].data['u_ref'][:, index], _g.nodes['g'].data['u_ref'][:, _index]), dim=-1)
                    b = torch.cat((_g.nodes['n1'].data['xyz'][:, index, :], _g.nodes['n1'].data['xyz'][:, _index, :]), dim=1)
                    c = torch.cat((_g.nodes['n1'].data['u_ref_prime'][:, index, :], _g.nodes['n1'].data['u_ref_prime'][:, _index, :]), dim=1)
                else:
                    idx1 = j*n_max_confs
                    idx2 = (j+1)*n_max_confs
                    _index = idx_range[idx1:idx2]
                    #logging.debug(f"Iteration {j}: Extract indices from {idx1} to {idx2}")
                    a = _g.nodes['g'].data['u_ref'][:, _index]
                    b = _g.nodes['n1'].data['xyz'][:, _index, :]
                    c = _g.nodes['n1'].data['u_ref_prime'][:, _index, :]
                # Update in place
                _g.nodes["g"].data["u_ref"] = a
                _g.nodes["n1"].data["xyz"] = b
                _g.nodes["n1"].data["u_ref_prime"] = c
                # Calculate u_ref_relative
                _g.nodes['g'].data['u_ref_relative'] = _g.nodes['g'].data['u_ref'].detach().clone()
                _g.nodes['g'].data['u_ref_relative'] = _g.nodes['g'].data['u_ref_relative'] - _g.nodes['g'].data['u_ref_relative'].mean(dim=-1, keepdims=True)
                _g.nodes['g'].data['u_ref_relative'] = _g.nodes['g'].data['u_ref_relative'].float()
                _g.nodes['g'].data.pop('u_ref')
                _ds_tr.append(_g.heterograph)
    return _ds_tr
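# Illustration of the padding branch above with hypothetical shapes: for a graph with
# n = 3 conformations and n_max_confs = 5,
#   index = random.choices(range(3), k=2)                  # e.g. [1, 0]
#   u_ref = torch.cat((u_ref, u_ref[:, index]), dim=-1)    # (1, 3) -> (1, 5)
# so every returned heterograph carries exactly n_max_confs conformations and can be batched.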
def conv_config_to_list(config: str) -> typing.List[typing.Union[int, str]]:
    """Convert the whitespace-separated `config` string into a list of ints (layer widths)
    and strs (activation names), as expected by the espaloma network constructors."""
    _config = []
    for _ in config.split():
        try:
            _config.append(int(_))
        except ValueError:
            _config.append(str(_))
    return _config
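# e.g. conv_config_to_list("128 relu 128 relu") == [128, 'relu', 128, 'relu']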
def prep_qm_dataset(
        datasets: str,
        prefix: str,
        n_max_confs: int) -> typing.List:  # list of augmented dgl heterographs
    """Do some miscellaneous dataset parsing/prepping; the exact modifications are listed below."""
    from espaloma.graphs.utils.regenerate_impropers import regenerate_impropers
    # parse the whitespace-separated dataset names and load duplicates
    datasets = [str(_) for _ in datasets.split()]
    ds_tr = _load_datasets(datasets, prefix)
    ds_tr = _load_duplicate_datasets(ds_tr, prefix)
    # remove unnecessary data from the graphs
    ds_tr.apply(_fn, in_place=True)
    ds_tr.apply(regenerate_impropers, in_place=True)
    # pad/split conformations so every graph carries exactly n_max_confs
    ds_tr_augment = _augment_conformations(ds_tr, n_max_confs)
    del ds_tr
    return ds_tr_augment
def prep_exp_sim_dataset(
        datasets: str,
        prefix: str) -> esp.data.dataset.GraphDataset:
    """Load experimental/simulation free energy data (no conformation augmentation)."""
    from espaloma.graphs.utils.regenerate_impropers import regenerate_impropers
    # parse the whitespace-separated dataset names and load duplicates
    datasets = [str(_) for _ in datasets.split()]
    ds_tr = _load_datasets(datasets, prefix)
    ds_tr = _load_duplicate_datasets(ds_tr, prefix)
    ds_tr.apply(regenerate_impropers, in_place=True)
    return ds_tr
class GetLoss(torch.nn.Module):
    """Defines the total loss function."""
    def energy_loss(self, g, **kwargs):
        return torch.nn.MSELoss()(
            g.nodes['g'].data['u'] - g.nodes['g'].data['u'].mean(dim=-1, keepdims=True),
            g.nodes['g'].data['u_ref_relative'],
        )
    def charge_loss(self, g, **kwargs):
        return torch.nn.MSELoss()(
            g.nodes['n1'].data['q'],
            g.nodes['n1'].data['q_ref'],
        )
    def force_loss(self, g, **kwargs):
        du_dx_hat = torch.autograd.grad(
            g.nodes['g'].data['u'].sum(),
            g.nodes['n1'].data['xyz'],
            create_graph=True,
            retain_graph=True,
            allow_unused=True,
        )[0]
        du_dx = g.nodes["n1"].data["u_ref_prime"]
        return torch.nn.MSELoss()(
            du_dx,
            du_dx_hat,
        )
    def fep_loss(self, g, **kwargs):
        """
        The FEP loss is the squared difference between the corrected free energy and the
        experimental free energy; the data structure is tentative.
        """
        # compute the experimental free energy difference (complex minus solvent, reduced units)
        exptl_reduced_free_energy = g.nodes['g'].data['exp_fe_complex'] - g.nodes['g'].data['exp_fe_solvent']
        # compute complex/solvent energy differences between the current model and the reference
        complex_energy_diffs = HARTREE_TO_REDUCED_UNITS * g.nodes['g'].data['u_complex'] - g.nodes['g'].data['u_complex_ref']
        solvent_energy_diffs = HARTREE_TO_REDUCED_UNITS * g.nodes['g'].data['u_solvent'] - g.nodes['g'].data['u_solvent_ref']
        # one-sided exponential (Zwanzig) corrections: -ln < exp(-du) >
        correction_complex = -torch.log(torch.exp(-complex_energy_diffs).mean())
        correction_solvent = -torch.log(torch.exp(-solvent_energy_diffs).mean())
        # apply the corrections to the FEP estimates
        corrected_complex_fep_energy = g.nodes['g'].data['fep_fe_complex'] + correction_complex
        corrected_solvent_fep_energy = g.nodes['g'].data['fep_fe_solvent'] + correction_solvent
        # squared residual against experiment
        loss = (exptl_reduced_free_energy - (corrected_complex_fep_energy - corrected_solvent_fep_energy))**2
        return loss
    def forward(self, g, charge_weight=1.0, energy_weight=1.0, force_weight=1.0, fep_weight=1.0):
        # default weights of 1.0 let this module run as the last stage of an `nn.Sequential`
        # (which only passes `g` through); override the weights when calling directly.
        loss = (self.charge_loss(g) * charge_weight
                + self.energy_loss(g) * energy_weight
                + self.force_loss(g) * force_weight)
        # regularize (improper) torsion force constants
        if g.number_of_nodes('n4_improper') > 0:
            loss = loss + g.nodes['n4_improper'].data['k'].pow(2).mean()
        if g.number_of_nodes('n4') > 0:
            loss = loss + g.nodes['n4'].data['k'].pow(2).mean()
        # fep term; only applied when the graph carries free energy data (key matches `fep_loss`)
        if 'exp_fe_complex' in g.nodes['g'].data.keys():
            loss += (self.fep_loss(g) * fep_weight)
        return loss
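# A minimal, self-contained sketch (synthetic data; not called anywhere in this script) of the
# one-sided exponential-averaging (Zwanzig) correction used by `GetLoss.fep_loss` above:
def _zwanzig_correction_demo() -> torch.Tensor:
    torch.manual_seed(0)
    # synthetic reduced (dimensionless) energy differences, u_new - u_ref, over sampled snapshots
    du = torch.randn(1000) * 0.5 + 0.1
    # Zwanzig estimator: delta_f = -ln < exp(-du) >_ref
    return -torch.log(torch.exp(-du).mean())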
def run(kwargs):
    # TODO: replace this manual dict unpacking with a proper CLI
    epochs = kwargs['epochs']
    batch_size = kwargs['batch_size']
    layer = kwargs['layer']
    units = kwargs['units']
    config = kwargs['config']
    janossy_config = kwargs['janossy_config']
    learning_rate = kwargs['learning_rate']
    output_prefix = kwargs['output_prefix']
    input_prefix = kwargs['input_prefix']
    qm_datasets = kwargs['qm_datasets']
    exp_sim_datasets = kwargs['exp_sim_datasets']
    n_max_confs = kwargs['n_max_confs']
    force_weight = kwargs['force_weight']
    # Convert config and janossy_config into lists
    config = conv_config_to_list(config)
    janossy_config = conv_config_to_list(janossy_config)
    qm_ds_tr_augment = prep_qm_dataset(qm_datasets, input_prefix, n_max_confs)
    # NOTE: `exp_sim_ds_tr` is loaded here but not yet wired into the training loop below.
    exp_sim_ds_tr = prep_exp_sim_dataset(exp_sim_datasets, input_prefix)
    #
    # Define espaloma model
    #
    # Representation
    #layer = esp.nn.layers.dgl_legacy.gn(layer)
    layer = esp.nn.layers.dgl_legacy.gn(layer, {"aggregator_type": "mean", "feat_drop": 0.1})  # should hardcode?
    representation = esp.nn.Sequential(layer, config=config)
    # out_features: Define modular MM parameters Espaloma will assign
    # 1: atom hardness and electronegativity
    # 2: bond linear combination, enforce positive
    # 3: angle linear combination, enforce positive
    # 4: torsion barrier heights (can be positive or negative)
    readout = esp.nn.readout.janossy.JanossyPooling(
        in_features=units, config=janossy_config,
        out_features={
            1: {'s': 1, 'e': 1},
            2: {'log_coefficients': 2},
            3: {'log_coefficients': 2},
            4: {'k': 6},
        },
    )
    readout_improper = esp.nn.readout.janossy.JanossyPoolingWithSmirnoffImproper(in_features=units, config=janossy_config, out_features={"k": 2})
    class ExpCoeff(torch.nn.Module):
        def forward(self, g):
            g.nodes['n2'].data['coefficients'] = g.nodes['n2'].data['log_coefficients'].exp()
            g.nodes['n3'].data['coefficients'] = g.nodes['n3'].data['log_coefficients'].exp()
            return g
    net = torch.nn.Sequential(
        representation,
        readout,
        readout_improper,
        ExpCoeff(),
        esp.nn.readout.charge_equilibrium.ChargeEquilibrium(),
        esp.mm.geometry.GeometryInGraph(),
        esp.mm.energy.EnergyInGraph(terms=["n2", "n3", "n4", "n4_improper"]),
        GetLoss(),
    ).cuda()
    # Check if a checkpoint file exists
    checkpoints = glob.glob("{}/*.th".format(output_prefix))
    if checkpoints:
        n = [int(c.split('net')[1].split('.')[0]) for c in checkpoints]
        n.sort()
        last_step = n[-1]
        last_checkpoint = os.path.join(output_prefix, "net{}.th".format(last_step))
        net.load_state_dict(torch.load(last_checkpoint))
        step = last_step + 1
        print('Found checkpoint file ({}). Restarting from step {}'.format(last_checkpoint, step))
    else:
        step = 1
    # Train
    qm_ds_tr_loader = dgl.dataloading.GraphDataLoader(qm_ds_tr_augment, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    with torch.autograd.set_detect_anomaly(True):
        for idx in range(step, step+epochs):
            for g in qm_ds_tr_loader:
                optimizer.zero_grad()
                g = g.to("cuda:0")
                g.nodes["n1"].data["xyz"].requires_grad = True
                loss = net(g)
                loss.backward()
                optimizer.step()
            if idx % 10 == 0:
                # Note: the returned loss is a joint loss over terms with different units.
                print(idx, HARTREE_TO_KCALPERMOL * loss.pow(0.5).item())
                if not os.path.exists(output_prefix):
                    os.mkdir(output_prefix)
                torch.save(net.state_dict(), output_prefix + "/net%s.th" % idx)
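# A hypothetical invocation; every value below is an illustrative placeholder (dataset names,
# paths, and hyperparameters are assumptions, not the author's settings) and should be adapted.
if __name__ == "__main__":
    run({
        'epochs': 100,
        'batch_size': 128,
        'layer': 'SAGEConv',
        'units': 128,
        'config': '128 relu 128 relu 128 relu',
        'janossy_config': '128 relu 128 relu 128 relu',
        'learning_rate': 1e-4,
        'output_prefix': 'checkpoints',
        'input_prefix': 'datasets',
        'qm_datasets': 'gen2 rna-nucleoside',  # hypothetical dataset names
        'exp_sim_datasets': 'exp-sim-data',    # hypothetical dataset name
        'n_max_confs': 50,
        'force_weight': 1.0,
    })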