Skip to content

Instantly share code, notes, and snippets.

@proteneer
Last active July 30, 2019 14:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save proteneer/09ef52530f61d009219613fd3037b06c to your computer and use it in GitHub Desktop.
Save proteneer/09ef52530f61d009219613fd3037b06c to your computer and use it in GitHub Desktop.
Comparison of various ensemble stabilities
# processed from https://raw.githubusercontent.com/proteneer/timemachine/8ff69ee3c7bf248a7bdc5caf9fcae2e37fdb86ac/jax/harmonic_dij.py
import functools
import time
from tqdm import tqdm
import numpy as onp
from jax.config import config; config.update("jax_enable_x64", True)
import jax
import jax.numpy as np
import scipy.stats as stats
# --- Physical constants -------------------------------------------------
BOLTZMANN = 1.380658e-23  # Boltzmann constant (SI value, J/K)
AVOGADRO = 6.0221367e23  # Avogadro's number (1/mol)
RGAS = AVOGADRO * BOLTZMANN  # molar gas constant, J/(mol*K)
BOLTZ = RGAS / 1e3  # same constant in kJ/(mol*K); multiplied by a temperature to get kT
ONE_4PI_EPS0 = 138.935456  # Coulomb prefactor 1/(4*pi*eps0); presumably kJ*nm/(mol*e^2) -- confirm
VIBRATIONAL_CONSTANT = 1302.79  # http://openmopac.net/manual/Hessian_Matrix.html

# --- Distance bounds ----------------------------------------------------
# NOTE(review): neither name is referenced elsewhere in this file; the values
# look like min/max sentinels (a large minimum and a zero maximum).
min_dij = np.array(999.9)
max_dij = np.array(0.0)
@jax.jit
def harmonic_bond_nrg(coords, params):
    """Return the harmonic bond energy of the bond between atoms 0 and 1.

    E = kb * (|r_0 - r_1| - b0)**2 / 2

    Args:
        coords: atom positions, one row per atom with Cartesian components on
            the last axis (assumed shape (N, 3) with N >= 2).
        params: ``params[0]`` is the force constant ``kb`` and ``params[1]``
            is the equilibrium length ``b0``.

    Returns:
        Scalar total energy (sum over the listed bonds; here a single bond).
    """
    kb = params[0]
    b0 = params[1]
    # Bond endpoints. Index with arrays, not Python lists: recent JAX versions
    # reject ``arr[list]`` as a non-tuple sequence index, while array indexing
    # gathers the same rows.
    src_idxs = np.array([0])
    dst_idxs = np.array([1])
    ci = coords[src_idxs]
    cj = coords[dst_idxs]
    dij = np.linalg.norm(ci - cj, axis=1)  # one bond length per (src, dst) pair
    energy = kb * np.power(dij - b0, 2) / 2
    return np.sum(energy)
@jax.jit
def harmonic_bond_grad(coords, params):
    """Return the gradient of ``harmonic_bond_nrg`` with respect to ``coords``.

    Args:
        coords: atom positions, as accepted by ``harmonic_bond_nrg``.
        params: bond parameters ``(kb, b0)``, as accepted by ``harmonic_bond_nrg``.

    Returns:
        A 1-tuple ``(dE/dcoords,)``. It is a tuple because ``argnums=(0,)``,
        which matches the ``g[0]`` access pattern used by the integrator.
    """
    # Build the reverse-mode Jacobian transform, then evaluate it at the given
    # point. A bare transform is not a valid return value of a jitted function.
    return jax.jacrev(harmonic_bond_nrg, argnums=(0,))(coords, params)
def set_velocities_to_temperature(n_atoms, temperature, masses):
    """Draw random initial velocities scaled per atom by sqrt(BOLTZ * T / m).

    Each of the ``n_atoms`` x 3 components is a standard-normal sample from
    NumPy's global RNG, multiplied by ``sqrt(BOLTZ * temperature / m_i)`` for
    the mass ``m_i`` of its row.

    Args:
        n_atoms: number of velocity rows to draw; expected to equal
            ``len(masses)`` (not checked).
        temperature: temperature used for the scale (presumably kelvin).
        masses: per-atom masses, one per row.

    Returns:
        Array of shape (n_atoms, 3).
    """
    gaussian = onp.random.normal(size=(n_atoms, 3))
    # Column vector of kT/m so it broadcasts across the three components.
    kt_over_m = BOLTZ * temperature / np.expand_dims(masses, -1)
    return gaussian * onp.sqrt(kt_over_m)
def langevin_integrator(params, dt=0.0025, friction=1.0, temp=300.0,
                        initial_temp=300.0, num_steps=10000, x0=None,
                        masses=None):
    """Integrate a single harmonic bond with Langevin dynamics.

    Each step applies

        v <- a * v - (b / m) * dE/dx + (c / sqrt(m)) * xi
        x <- x + v * dt

    where ``a = exp(-dt * friction)``, ``b = (1 - a) / friction`` (or ``dt``
    when ``friction == 0``), ``c = sqrt(BOLTZ * temp * (1 - a**2))``, and
    ``xi`` is standard-normal noise drawn from NumPy's global RNG every step.
    ``E`` is ``harmonic_bond_nrg``.

    Args:
        params: bond parameters ``(kb, b0)`` forwarded to ``harmonic_bond_nrg``.
            This is the argument the driver differentiates with respect to.
        dt: integration time step.
        friction: collision rate. With ``0`` the velocity is not damped.
        temp: bath temperature that sets the noise amplitude. With ``0`` the
            dynamics are noise-free.
        initial_temp: temperature used to draw the starting velocities.
        num_steps: number of integration steps (default 10000, the value
            previously hard-coded).
        x0: optional starting coordinates, one row per atom. Defaults to the
            built-in two-atom geometry. That geometry is divided by 10,
            presumably converting angstrom to nm. ``x0`` is never modified
            in place.
        masses: optional per-atom masses, one per row of ``x0``. Defaults to
            ``[12.0107, 1.0]``.

    Returns:
        Coordinates after ``num_steps`` steps.
    """
    if x0 is None:
        x0 = np.array([
            [-0.0036, 0.0222, 0.0912],
            [-0.0162, -0.8092, 0.7960],
        ], dtype=np.float64) / 10
    if masses is None:
        masses = np.array([12.0107, 1.0], dtype=np.float64)
    masses = np.asarray(masses)
    # Work on a fresh array so the ``x = x + dx`` updates never touch a
    # caller-supplied buffer.
    x = np.array(x0)
    num_atoms = len(masses)
    num_dims = x.shape[1]

    v_t = set_velocities_to_temperature(x.shape[0], initial_temp, masses)

    # Per-step coefficients; they are constant for the whole run.
    vscale = np.exp(-dt * friction)  # velocity retention ("how fast we forget")
    fscale = dt if friction == 0 else (1 - vscale) / friction
    kT = BOLTZ * temp
    nscale = np.sqrt(kT * (1 - vscale * vscale))  # noise amplitude
    inv_masses = (1.0 / masses).reshape((-1, 1))
    # Fold the mass factors in once. The per-step products keep the original
    # left-to-right association, so the results are bit-identical.
    force_coeff = fscale * inv_masses
    noise_coeff = nscale * np.sqrt(inv_masses)

    grad_fn = jax.jit(jax.jacrev(harmonic_bond_nrg, argnums=(0,)))
    for _ in tqdm(range(num_steps)):
        (grad,) = grad_fn(x, params)
        noise = onp.random.normal(size=(num_atoms, num_dims)).astype(x.dtype)
        v_t = vscale * v_t - force_coeff * grad + noise_coeff * noise
        x = x + v_t * dt
    return x
if __name__ == "__main__":
    # Bond parameters (kb, b0); the final coordinates are differentiated
    # with respect to these.
    theta = np.array([25000.0, 0.129], dtype=np.float64)
    # One integrator, three protocols:
    #   NVT - damped and thermostatted at 300
    #   NVE - undamped (friction=0) and noise-free (temp=0)
    #   MIN - damped and noise-free (temp=0)
    NVT = functools.partial(langevin_integrator, friction=1.0, temp=300.0, initial_temp=300.0)
    NVE = functools.partial(langevin_integrator, friction=0.0, temp=0.0, initial_temp=300.0)
    MIN = functools.partial(langevin_integrator, friction=1.0, temp=0.0, initial_temp=300.0)
    for name, integrator in (("NVT", NVT), ("NVE", NVE), ("MIN", MIN)):
        # Forward-mode Jacobian of the final coordinates w.r.t. theta. Its
        # magnitude is presumably what the stability comparison looks at.
        dxdp = jax.jacfwd(integrator, argnums=(0,))
        res = dxdp(theta)[0]
        # Label each line so the three ensembles' results can be told apart.
        print(name, res, np.amax(res), np.amin(res))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment