Last active
August 15, 2023 13:06
-
-
Save mjhong0708/a4b92dc5b1a144365f30abe0a622b518 to your computer and use it in GitHub Desktop.
Fast computation of MSD using jax and jax_md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Calculate mean square displacement (MSD) of a given element in a trajectory. | |
Use jax_md to handle PBC. | |
""" | |
from typing import List | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
from ase import Atoms | |
from jax_md import space | |
from ase.data import atomic_numbers | |
def _check_images(images: List[Atoms]) -> None:
    """Check that images is a non-empty list of ase.Atoms of the same system.

    Two images are considered to describe the same system when they have
    the same number of atoms, identical atomic numbers (in the same order)
    and numerically close cell matrices.

    Args:
        images (List[Atoms]): List of ase.Atoms objects.

    Raises:
        TypeError: If images is not a list of ase.Atoms objects.
        ValueError: If images is empty, or if images do not contain same
            model system.
    """
    _type_err_msg = "images must be a list of ase.Atoms objects"
    if not isinstance(images, list):
        raise TypeError(_type_err_msg)
    if not all(isinstance(image, Atoms) for image in images):
        raise TypeError(_type_err_msg)
    # An empty list passes both checks above; reject it explicitly so the
    # reference-frame lookup below cannot fail with an opaque IndexError.
    if not images:
        raise ValueError("images must contain at least one ase.Atoms object")

    atoms_0 = images[0]
    if not all(len(image) == len(atoms_0) for image in images):
        raise ValueError("All images must have the same number of atoms")
    # Atomic numbers are integers, so compare them exactly; the cell is a
    # float matrix, so compare it with a tolerance.
    same_numbers = all(
        np.array_equal(a.numbers, atoms_0.numbers) for a in images
    )
    same_cell = all(
        np.allclose(a.cell.array, atoms_0.cell.array) for a in images
    )
    if not (same_numbers and same_cell):
        raise ValueError("All images must contain same model system")
def mean_square_displacement(images: List[Atoms], element: str) -> np.ndarray:
    """Calculate mean square displacement of a given element in a trajectory.

    The displacement of every selected atom is accumulated frame by frame
    from minimum-image displacements between consecutive frames, so the
    result is only meaningful when atoms move less than half a box length
    between two consecutive images.

    Args:
        images (List[Atoms]): List of ase.Atoms objects.
        element (str): Element symbol.

    Raises:
        TypeError: If images is not a list of ase.Atoms objects.
        ValueError: If images do not contain same model system, or if
            ``element`` does not occur in the system.
        KeyError: If ``element`` is not a key of ``ase.data.atomic_numbers``.

    Returns:
        np.ndarray: MSD of the element, one value per image; the first
            entry is always 0 (displacement relative to the first image).
    """
    _check_images(images)

    # Indices of the atoms of the selected element (same in every image,
    # as guaranteed by _check_images).
    elem_idx = np.where(images[0].numbers == atomic_numbers[element])[0]
    if elem_idx.size == 0:
        # Averaging over an empty selection would silently produce NaN.
        raise ValueError(f"Element {element!r} is not present in the images")

    # A single frame has no consecutive pairs; its MSD is just [0.].
    if len(images) < 2:
        return jax.device_get(jnp.zeros((1,)))

    # Fractional positions of the selected atoms: (n_images, n_elem, 3)
    all_pos_elem = jnp.asarray(
        np.stack([atoms.get_scaled_positions()[elem_idx] for atoms in images])
    )
    pos_prev = all_pos_elem[:-1]
    pos_next = all_pos_elem[1:]

    # ASE stores lattice vectors as rows (cartesian = fractional @ cell),
    # whereas jax_md applies the box as box @ fractional. Transpose the cell
    # so that non-orthogonal cells are mapped to real space correctly.
    displacement_fn, _ = space.periodic_general(
        box=jnp.asarray(images[0].cell.array.T),
        fractional_coordinates=True,
        wrapped=True,
    )
    # Outer vmap runs over frame pairs, inner vmap over atoms.
    pairwise_displacement_fn = jax.jit(jax.vmap(jax.vmap(displacement_fn)))

    # R(i) - R(i-1) for i = 1, ..., n_images - 1 (minimum image convention)
    step_disp = pairwise_displacement_fn(pos_next, pos_prev)

    # Disp(i) = sum_{j=1}^{i} (R(j) - R(j-1))
    disp_t = jnp.cumsum(step_disp, axis=0)

    # Squared distance per atom, averaged over the atoms of the element.
    msd = jnp.mean(space.distance(disp_t) ** 2, axis=-1)
    # Prepend the zero displacement of the first image.
    msd = jnp.concatenate([jnp.zeros((1,)), msd])
    return jax.device_get(msd)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example