Skip to content

Instantly share code, notes, and snippets.

@mjhong0708
Last active August 15, 2023 13:06
Show Gist options
  • Save mjhong0708/a4b92dc5b1a144365f30abe0a622b518 to your computer and use it in GitHub Desktop.
Save mjhong0708/a4b92dc5b1a144365f30abe0a622b518 to your computer and use it in GitHub Desktop.
Fast computation of MSD using jax and jax_md
"""Calculate mean square displacement (MSD) of a given element in a trajectory.
Use jax_md to handle PBC.
"""
from typing import List
import jax
import jax.numpy as jnp
import numpy as np
from ase import Atoms
from jax_md import space
from ase.data import atomic_numbers
def _check_images(images: List[Atoms]) -> None:
    """Check if images is a non-empty list of ase.Atoms objects of same system.

    Images are considered to describe the same system when they have the same
    number of atoms, identical atomic numbers (in the same order) and the same
    cell (compared with ``np.allclose``).

    Args:
        images (List[Atoms]): List of ase.Atoms objects.

    Raises:
        TypeError: If images is not a list of ase.Atoms objects.
        ValueError: If images is empty or do not contain same model system.
    """
    _type_err_msg = "images must be a list of ase.Atoms objects"
    if not isinstance(images, list):
        raise TypeError(_type_err_msg)
    if not all(isinstance(image, Atoms) for image in images):
        raise TypeError(_type_err_msg)
    # An empty list passes the checks above (all() of nothing is True), but
    # there is no reference system to compare against; fail clearly instead
    # of with an IndexError on images[0].
    if not images:
        raise ValueError("images must contain at least one ase.Atoms object")
    atoms_0 = images[0]
    if not all(len(image) == len(atoms_0) for image in images):
        raise ValueError("All images must have the same number of atoms")
    # Atomic numbers are integers: compare exactly, not with a float tolerance.
    same_numbers = all(np.array_equal(a.numbers, atoms_0.numbers) for a in images)
    same_cell = all(np.allclose(a.cell.array, atoms_0.cell.array) for a in images)
    if not (same_numbers and same_cell):
        raise ValueError("All images must contain same model system")
def mean_square_displacement(images: List[Atoms], element: str) -> np.ndarray:
    """Calculate mean square displacement of a given element in a trajectory.

    For every pair of consecutive images, the minimum-image displacement of
    each selected atom is computed under periodic boundary conditions. These
    steps are summed to obtain the unwrapped displacement from the first
    image. This assumes no atom moves more than half a cell length between
    consecutive images.

    Args:
        images (List[Atoms]): List of ase.Atoms objects (the trajectory).
        element (str): Element symbol, e.g. ``"Li"``.

    Raises:
        TypeError: If images is not a list of ase.Atoms objects.
        ValueError: If images is empty, do not contain same model system,
            or contain no atom of ``element``.
        KeyError: If ``element`` is not a known chemical symbol.

    Returns:
        np.ndarray: MSD of the element, shape ``(n_images,)``. Entry ``t`` is
        the mean over the selected atoms of ``|R(t) - R(0)|**2``, so entry 0
        is always 0.
    """
    _check_images(images)

    # Indices of atoms of the selected element; identical in every image
    # because _check_images enforces equal atomic numbers.
    elem_idx = np.where(images[0].numbers == atomic_numbers[element])[0]
    if elem_idx.size == 0:
        # Averaging over zero atoms would silently return NaN for every frame.
        raise ValueError(f"No atoms of element {element!r} found in images")

    # Fractional (scaled) positions of the selected atoms.
    all_pos_elem = jnp.asarray(
        np.stack([atoms.get_scaled_positions()[elem_idx] for atoms in images])
    )  # (n_images, n_elem, 3)

    # jax_md maps fractional to real coordinates as r = box @ s, i.e. the
    # lattice vectors are the *columns* of ``box``. ASE stores them as the
    # *rows* of ``cell.array``, hence the transpose (this matters for
    # non-orthogonal cells).
    displacement_fn, _ = space.periodic_general(
        box=jnp.asarray(images[0].cell.array).T,
        fractional_coordinates=True,
        wrapped=True,
    )
    # Inner vmap maps over atoms, outer vmap over frames.
    pairwise_displacement_fn = jax.jit(jax.vmap(jax.vmap(displacement_fn)))
    # Step displacement R(i) - R(i-1) for i = 1, ..., n_images - 1
    # (displacement_fn(Ra, Rb) returns the periodic displacement Ra - Rb).
    step_disp = pairwise_displacement_fn(all_pos_elem[1:], all_pos_elem[:-1])
    # Disp(i) = sum_{j=1}^{i} (R(j) - R(j-1)), the unwrapped R(i) - R(0)
    disp_t = jnp.cumsum(step_disp, axis=0)
    # Average the squared displacement over the selected atoms.
    msd = jnp.mean(space.square_distance(disp_t), axis=-1)
    # MSD at t = 0 is zero by definition.
    msd = jnp.concatenate([jnp.zeros((1,), dtype=msd.dtype), msd])
    return jax.device_get(msd)
@mjhong0708
Copy link
Author

mjhong0708 commented Aug 15, 2023

Example

import os

# Comment out the following line if you want to use a GPU to accelerate the computation
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from msd_calc import mean_square_displacement

images = ... # example: images = ase.io.read("XDATCAR", ":")
msd = mean_square_displacement(images, "Li")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment