Skip to content

Instantly share code, notes, and snippets.

@sschoenholz
Created February 11, 2021 23:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sschoenholz/b999533dda27a33b90e059e4353155a2 to your computer and use it in GitHub Desktop.
Save sschoenholz/b999533dda27a33b90e059e4353155a2 to your computer and use it in GitHub Desktop.
import jax.numpy as np
from jax import custom_jvp
from jax import lax
from jax_md import space
# Module-level shorthands for the jax_md.space helpers used below.
periodic_displacement, pairwise_displacement, periodic_shift = (
    space.periodic_displacement,
    space.pairwise_displacement,
    space.periodic_shift,
)

# Float constructor used for the unit side length passed to the periodic
# helpers (f32(1.0)).
f32 = np.float32
def inverse(box):
    """Return the inverse of a box.

    Args:
      box: a scalar (or any size-1 array), a 1-D array of per-axis factors,
        or a 2-D matrix.

    Returns:
      ``1 / box`` (element-wise) for scalar and 1-D boxes, and
      ``np.linalg.inv(box)`` for 2-D boxes.

    Raises:
      ValueError: if ``box`` has more than two dimensions.
    """
    # Scalars and per-axis vectors are inverted element-wise; the check on
    # ``np.isscalar`` comes first so plain Python numbers never reach
    # ``.size`` / ``.ndim``.
    if np.isscalar(box) or box.size == 1 or box.ndim == 1:
        return 1 / box
    if box.ndim == 2:
        return np.linalg.inv(box)
    raise ValueError(
        'Box must be a scalar, a vector, or a matrix; '
        f'found an array with ndim={box.ndim}.')
def get_free_indices(n):
    """Return the first ``n`` consecutive characters starting at ``'a'``.

    Suitable for spelling out einsum subscripts, e.g.
    ``get_free_indices(3) == 'abc'``.  A non-positive ``n`` yields ``''``.
    """
    start = ord('a')
    return ''.join(map(chr, range(start, start + n)))
def base_transform(box, R):
    """Apply the linear map ``box`` along the last axis of ``R``.

    Args:
      box: a scalar (or any size-1 array), a 1-D array of per-axis factors,
        or a 2-D matrix.
      R: array whose last axis holds the coordinates; all leading axes are
        carried through unchanged.

    Returns:
      ``R * box`` for scalar boxes; ``R`` scaled per-axis along its last
      axis for 1-D boxes; for 2-D boxes, ``box @ r`` for every vector ``r``
      along the last axis of ``R``.

    Raises:
      ValueError: if ``box`` has more than two dimensions.
    """
    if np.isscalar(box) or box.size == 1:
        return R * box
    # '...' stands for every leading axis of R.  Spelling those axes out with
    # generated letters 'a', 'b', ... collides with the contraction labels
    # 'i'/'j' once R has 10 or more dimensions.
    if box.ndim == 1:
        return np.einsum('i,...i->...i', box, R)
    if box.ndim == 2:
        return np.einsum('ij,...j->...i', box, R)
    raise ValueError(
        'Box must be a scalar, a vector, or a matrix; '
        f'found an array with ndim={box.ndim}.')
@custom_jvp
def transform_without_tangents(box, R):
    """Compute ``base_transform(box, R)`` under a custom JVP rule.

    The forward value is exactly ``base_transform(box, R)``.  Under forward
    differentiation the rule below propagates the tangent of ``R`` unchanged
    (it is *not* multiplied by ``box``), and adds the tangent of ``box``
    applied to ``R``.
    """
    return base_transform(box, R)


@transform_without_tangents.defjvp
def transform_without_tangents_jvp(primals, tangents):
    """JVP rule: output tangent is ``dR + transform_without_tangents(dbox, R)``."""
    (box, positions), (box_dot, positions_dot) = primals, tangents
    primal_out = transform_without_tangents(box, positions)
    # Identity w.r.t. positions; the box tangent goes through the same map.
    tangent_out = positions_dot + transform_without_tangents(box_dot, positions)
    return primal_out, tangent_out
def transform(box, R, fractional_coordinates=True):
    """Apply ``box`` to ``R``, selecting how derivatives are propagated.

    Args:
      box: scalar, 1-D, or 2-D box accepted by ``base_transform``.
      R: array with coordinates along its last axis.
      fractional_coordinates: if truthy (the default), dispatch to
        ``transform_without_tangents``, whose JVP passes the tangent of ``R``
        through unchanged; otherwise dispatch to ``base_transform`` and let
        JAX differentiate it normally.

    Returns:
      The transformed array; the forward value is the same on either path.
    """
    impl = transform_without_tangents if fractional_coordinates else base_transform
    return impl(box, R)
def periodic_general(box, fractional_coordinates=True, wrapped=True):
    """Build displacement and shift functions for a general periodic box.

    Args:
      box: scalar, 1-D, or 2-D box accepted by ``inverse`` and
        ``base_transform``.
      fractional_coordinates: if True, positions given to the returned
        functions are used directly with a periodic side length of 1, and
        ``transform``'s custom-JVP path maps displacements to real space.
        If False, positions are real-space; they are mapped into the unit
        cell with the inverse box and back with the box, and every one of
        those maps is differentiated exactly.
      wrapped: if True, ``shift_fn`` wraps positions with ``periodic_shift``;
        if False it simply adds the (transformed) displacement.

    Returns:
      ``(displacement_fn, shift_fn)``.
    """
    inv_box = inverse(box)

    def displacement_fn(Ra, Rb, **kwargs):
        """Periodic displacement between ``Ra`` and ``Rb``, in real space.

        Optional kwargs: ``box`` overrides the captured box; ``new_box`` is
        the box used to map the displacement back (it overrides ``box`` for
        that final step only).
        """
        _box, _inv_box = box, inv_box
        if 'box' in kwargs:
            _box = kwargs['box']
            if not fractional_coordinates:
                _inv_box = inverse(_box)
        if 'new_box' in kwargs:
            _box = kwargs['new_box']

        if not fractional_coordinates:
            # Use the same differentiation mode as the map back to real space
            # below. Using the custom JVP here but the exact map there would
            # make d(displacement)/dRa equal to ``box`` instead of the
            # identity.
            Ra = transform(_inv_box, Ra,
                           fractional_coordinates=fractional_coordinates)
            Rb = transform(_inv_box, Rb,
                           fractional_coordinates=fractional_coordinates)

        dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
        return transform(_box, dR,
                         fractional_coordinates=fractional_coordinates)

    def u(R, dR):
        # Add dR in unit-cell coordinates, wrapping when requested.
        if wrapped:
            return periodic_shift(f32(1.0), R, dR)
        return R + dR

    def shift_fn(R, dR, **kwargs):
        """Move ``R`` by the real-space displacement ``dR``.

        Optional kwargs: ``box`` overrides the captured box (and its
        inverse); ``new_box`` is the box used to map real-space results back
        out of the unit cell.
        """
        if not fractional_coordinates and not wrapped:
            return R + dR
        _box, _inv_box = box, inv_box
        if 'box' in kwargs:
            _box = kwargs['box']
            _inv_box = inverse(_box)
        if 'new_box' in kwargs:
            _box = kwargs['new_box']

        # Every box map below uses one differentiation mode. Mixing modes
        # made d(out)/d(dR) equal ``inv_box`` rather than the identity in
        # real-space mode.
        dR = transform(_inv_box, dR,
                       fractional_coordinates=fractional_coordinates)
        if not fractional_coordinates:
            R = transform(_inv_box, R,
                          fractional_coordinates=fractional_coordinates)
        R = u(R, dR)
        if not fractional_coordinates:
            R = transform(_box, R,
                          fractional_coordinates=fractional_coordinates)
        return R

    return displacement_fn, shift_fn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment