Skip to content

Instantly share code, notes, and snippets.

@Linux-cpp-lisp
Created April 15, 2022 02:45
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Linux-cpp-lisp/692018c74b3906b63529e60619f5a207 to your computer and use it in GitHub Desktop.
Save Linux-cpp-lisp/692018c74b3906b63529e60619f5a207 to your computer and use it in GitHub Desktop.
translating ASE periodic neighborlist into TorchScript compatible PyTorch (MOSTLY UNTESTED)
# Author: https://gist.github.com/Linux-cpp-lisp
from typing import Tuple, List
import numpy as np
import torch
@torch.jit.script
def torch_divmod(a, b) -> Tuple[torch.Tensor, torch.Tensor]:
d = torch.div(a, b, rounding_mode="floor")
m = a % b
return d, m
@torch.jit.script
def primitive_neighbor_list(
quantities: str,
pbc: Tuple[bool, bool, bool],
cell,
positions,
cutoff: float,
self_interaction: bool = False,
use_scaled_positions: bool = False,
max_nbins: int = int(1e6),
) -> List[torch.Tensor]:
"""Compute a neighbor list for an atomic configuration.
Atoms outside periodic boundaries are mapped into the box. Atoms
outside nonperiodic boundaries are included in the neighbor list
but complexity of neighbor list search for those can become n^2.
The neighbor list is sorted by first atom index 'i', but not by second
atom index 'j'.
Parameters:
quantities: str
Quantities to compute by the neighbor list algorithm. Each character
in this string defines a quantity. They are returned in a tuple of
the same order. Possible quantities are
* 'i' : first atom index
* 'j' : second atom index
* 'd' : absolute distance
* 'D' : distance vector
* 'S' : shift vector (number of cell boundaries crossed by the bond
between atom i and j). With the shift vector S, the
distances D between atoms can be computed from:
D = positions[j]-positions[i]+S.dot(cell)
pbc: array_like
3-tuple indicating giving periodic boundaries in the three Cartesian
directions.
cell: 3x3 matrix
Unit cell vectors. Must be completed.
positions: list of xyz-positions
Atomic positions. Anything that can be converted to an ndarray of
shape (n, 3) will do: [(x1,y1,z1), (x2,y2,z2), ...]. If
use_scaled_positions is set to true, this must be scaled positions.
cutoff: float or dict
Cutoff for neighbor search. It can be:
* A single float: This is a global cutoff for all elements.
* A dictionary: This specifies cutoff values for element
pairs. Specification accepts element numbers of symbols.
Example: {(1, 6): 1.1, (1, 1): 1.0, ('C', 'C'): 1.85}
* A list/array with a per atom value: This specifies the radius of
an atomic sphere for each atoms. If spheres overlap, atoms are
within each others neighborhood. See :func:`~ase.neighborlist.natural_cutoffs`
for an example on how to get such a list.
self_interaction: bool
Return the atom itself as its own neighbor if set to true.
Default: False
use_scaled_positions: bool
If set to true, positions are expected to be scaled positions.
max_nbins: int
Maximum number of bins used in neighbor search. This is used to limit
the maximum amount of memory required by the neighbor list.
Returns:
i, j, ... : array
Tuple with arrays for each quantity specified above. Indices in `i`
are returned in ascending order 0..len(a)-1, but the order of (i,j)
pairs is not guaranteed.
"""
# Naming conventions: Suffixes indicate the dimension of an array. The
# following convention is used here:
# c: Cartesian index, can have values 0, 1, 2
# i: Global atom index, can have values 0..len(a)-1
# xyz: Bin index, three values identifying x-, y- and z-component of a
# spatial bin that is used to make neighbor search O(n)
# b: Linearized version of the 'xyz' bin index
# a: Bin-local atom index, i.e. index identifying an atom *within* a
# bin
# p: Pair index, can have value 0 or 1
# n: (Linear) neighbor index
# Return empty neighbor list if no atoms are passed here
if len(positions) == 0:
assert False
# Compute reciprocal lattice vectors.
recip_cell = torch.linalg.pinv(cell).T
b1_c, b2_c, b3_c = recip_cell[0], recip_cell[1], recip_cell[2]
# Compute distances of cell faces.
l1 = torch.linalg.norm(b1_c)
l2 = torch.linalg.norm(b2_c)
l3 = torch.linalg.norm(b3_c)
pytorch_scalar_1 = torch.as_tensor(1.0)
face_dist_c = torch.hstack(
[
1 / l1 if l1 > 0 else pytorch_scalar_1,
1 / l2 if l2 > 0 else pytorch_scalar_1,
1 / l3 if l3 > 0 else pytorch_scalar_1,
]
)
assert face_dist_c.shape == (3,)
# we don't handle other fancier cutoffs
max_cutoff: float = cutoff
# We use a minimum bin size of 3 A
bin_size = max(max_cutoff, 3)
# Compute number of bins such that a sphere of radius cutoff fits into
# eight neighboring bins.
nbins_c = torch.maximum(
(face_dist_c / bin_size).to(dtype=torch.long), torch.ones(3, dtype=torch.long)
)
nbins = torch.prod(nbins_c)
# Make sure we limit the amount of memory used by the explicit bins.
while nbins > max_nbins:
nbins_c = torch.maximum(nbins_c // 2, torch.ones(3, dtype=torch.long))
nbins = torch.prod(nbins_c)
# Compute over how many bins we need to loop in the neighbor list search.
neigh_search = torch.ceil(bin_size * nbins_c / face_dist_c).to(dtype=torch.long)
neigh_search_x, neigh_search_y, neigh_search_z = (
neigh_search[0],
neigh_search[1],
neigh_search[2],
)
# If we only have a single bin and the system is not periodic, then we
# do not need to search neighboring bins
pytorch_scalar_int_0 = torch.as_tensor(0, dtype=torch.long)
neigh_search_x = (
pytorch_scalar_int_0 if nbins_c[0] == 1 and not pbc[0] else neigh_search_x
)
neigh_search_y = (
pytorch_scalar_int_0 if nbins_c[1] == 1 and not pbc[1] else neigh_search_y
)
neigh_search_z = (
pytorch_scalar_int_0 if nbins_c[2] == 1 and not pbc[2] else neigh_search_z
)
# Sort atoms into bins.
if use_scaled_positions:
scaled_positions_ic = positions
positions = torch.dot(scaled_positions_ic, cell)
else:
scaled_positions_ic = torch.linalg.solve(cell.T, positions.T).T
bin_index_ic = torch.floor(scaled_positions_ic * nbins_c).to(dtype=torch.long)
cell_shift_ic = torch.zeros_like(bin_index_ic)
for c in range(3):
if pbc[c]:
# (Note: torch.divmod does not exist in older numpies)
cell_shift_ic[:, c], bin_index_ic[:, c] = torch_divmod(
bin_index_ic[:, c], nbins_c[c]
)
else:
bin_index_ic[:, c] = torch.clip(bin_index_ic[:, c], 0, nbins_c[c] - 1)
# Convert Cartesian bin index to unique scalar bin index.
bin_index_i = bin_index_ic[:, 0] + nbins_c[0] * (
bin_index_ic[:, 1] + nbins_c[1] * bin_index_ic[:, 2]
)
# atom_i contains atom index in new sort order.
atom_i = torch.argsort(bin_index_i)
bin_index_i = bin_index_i[atom_i]
# Find max number of atoms per bin
max_natoms_per_bin = torch.bincount(bin_index_i).max()
# Sort atoms into bins: atoms_in_bin_ba contains for each bin (identified
# by its scalar bin index) a list of atoms inside that bin. This list is
# homogeneous, i.e. has the same size *max_natoms_per_bin* for all bins.
# The list is padded with -1 values.
atoms_in_bin_ba = -torch.ones(nbins, max_natoms_per_bin.item(), dtype=torch.long)
for i in range(int(max_natoms_per_bin.item())):
# Create a mask array that identifies the first atom of each bin.
mask = torch.cat(
(torch.ones(1, dtype=torch.bool), bin_index_i[:-1] != bin_index_i[1:]),
dim=0,
)
# Assign all first atoms.
atoms_in_bin_ba[bin_index_i[mask], i] = atom_i[mask]
# Remove atoms that we just sorted into atoms_in_bin_ba. The next
# "first" atom will be the second and so on.
mask = torch.logical_not(mask)
atom_i = atom_i[mask]
bin_index_i = bin_index_i[mask]
# Make sure that all atoms have been sorted into bins.
assert len(atom_i) == 0
assert len(bin_index_i) == 0
# Now we construct neighbor pairs by pairing up all atoms within a bin or
# between bin and neighboring bin. atom_pairs_pn is a helper buffer that
# contains all potential pairs of atoms between two bins, i.e. it is a list
# of length max_natoms_per_bin**2.
# atom_pairs_pn_np = np.indices(
# (max_natoms_per_bin, max_natoms_per_bin), dtype=int
# ).reshape(2, -1)
atom_pairs_pn = torch.cartesian_prod(
torch.arange(max_natoms_per_bin), torch.arange(max_natoms_per_bin)
)
atom_pairs_pn = atom_pairs_pn.T.reshape(2, -1)
# Initialized empty neighbor list buffers.
first_at_neightuple_nn = []
secnd_at_neightuple_nn = []
cell_shift_vector_x_n = []
cell_shift_vector_y_n = []
cell_shift_vector_z_n = []
# This is the main neighbor list search. We loop over neighboring bins and
# then construct all possible pairs of atoms between two bins, assuming
# that each bin contains exactly max_natoms_per_bin atoms. We then throw
# out pairs involving pad atoms with atom index -1 below.
binz_xyz, biny_xyz, binx_xyz = torch.meshgrid(
torch.arange(nbins_c[2]),
torch.arange(nbins_c[1]),
torch.arange(nbins_c[0]),
indexing="ij",
)
# The memory layout of binx_xyz, biny_xyz, binz_xyz is such that computing
# the respective bin index leads to a linearly increasing consecutive list.
# The following assert statement succeeds:
# b_b = (binx_xyz + nbins_c[0] * (biny_xyz + nbins_c[1] *
# binz_xyz)).ravel()
# assert (b_b == torch.arange(torch.prod(nbins_c))).all()
# First atoms in pair.
_first_at_neightuple_n = atoms_in_bin_ba[:, atom_pairs_pn[0]]
for dz in range(-int(neigh_search_z.item()), int(neigh_search_z.item()) + 1):
for dy in range(-int(neigh_search_y.item()), int(neigh_search_y.item()) + 1):
for dx in range(
-int(neigh_search_x.item()), int(neigh_search_x.item()) + 1
):
# Bin index of neighboring bin and shift vector.
shiftx_xyz, neighbinx_xyz = torch_divmod(binx_xyz + dx, nbins_c[0])
shifty_xyz, neighbiny_xyz = torch_divmod(biny_xyz + dy, nbins_c[1])
shiftz_xyz, neighbinz_xyz = torch_divmod(binz_xyz + dz, nbins_c[2])
neighbin_b = (
neighbinx_xyz
+ nbins_c[0] * (neighbiny_xyz + nbins_c[1] * neighbinz_xyz)
).ravel()
# Second atom in pair.
_secnd_at_neightuple_n = atoms_in_bin_ba[neighbin_b][
:, atom_pairs_pn[1]
]
# Shift vectors.
# TODO: was np.resize:
# _cell_shift_vector_x_n_np = np.resize(
# shiftx_xyz.reshape(-1, 1).numpy(),
# (int(max_natoms_per_bin.item() ** 2), shiftx_xyz.numel()),
# ).T
# _cell_shift_vector_y_n_np = np.resize(
# shifty_xyz.reshape(-1, 1).numpy(),
# (int(max_natoms_per_bin.item() ** 2), shifty_xyz.numel()),
# ).T
# _cell_shift_vector_z_n_np = np.resize(
# shiftz_xyz.reshape(-1, 1).numpy(),
# (int(max_natoms_per_bin.item() ** 2), shiftz_xyz.numel()),
# ).T
# this basically just tiles shiftx_xyz.reshape(-1, 1) n times
_cell_shift_vector_x_n = shiftx_xyz.reshape(-1, 1).repeat(
(1, int(max_natoms_per_bin.item() ** 2))
)
# assert _cell_shift_vector_x_n.shape == _cell_shift_vector_x_n_np.shape
# assert np.allclose(
# _cell_shift_vector_x_n.numpy(), _cell_shift_vector_x_n_np
# )
_cell_shift_vector_y_n = shifty_xyz.reshape(-1, 1).repeat(
(1, int(max_natoms_per_bin.item() ** 2))
)
# assert _cell_shift_vector_y_n.shape == _cell_shift_vector_y_n_np.shape
# assert np.allclose(
# _cell_shift_vector_y_n.numpy(), _cell_shift_vector_y_n_np
# )
_cell_shift_vector_z_n = shiftz_xyz.reshape(-1, 1).repeat(
(1, int(max_natoms_per_bin.item() ** 2))
)
# assert _cell_shift_vector_z_n.shape == _cell_shift_vector_z_n_np.shape
# assert np.allclose(
# _cell_shift_vector_z_n.numpy(), _cell_shift_vector_z_n_np
# )
# We have created too many pairs because we assumed each bin
# has exactly max_natoms_per_bin atoms. Remove all surperfluous
# pairs. Those are pairs that involve an atom with index -1.
mask = torch.logical_and(
_first_at_neightuple_n != -1, _secnd_at_neightuple_n != -1
)
if mask.sum() > 0:
first_at_neightuple_nn += [_first_at_neightuple_n[mask]]
secnd_at_neightuple_nn += [_secnd_at_neightuple_n[mask]]
cell_shift_vector_x_n += [_cell_shift_vector_x_n[mask]]
cell_shift_vector_y_n += [_cell_shift_vector_y_n[mask]]
cell_shift_vector_z_n += [_cell_shift_vector_z_n[mask]]
# Flatten overall neighbor list.
first_at_neightuple_n = torch.cat(first_at_neightuple_nn)
secnd_at_neightuple_n = torch.cat(secnd_at_neightuple_nn)
cell_shift_vector_n = torch.vstack(
[
torch.cat(cell_shift_vector_x_n),
torch.cat(cell_shift_vector_y_n),
torch.cat(cell_shift_vector_z_n),
]
).T
# Add global cell shift to shift vectors
cell_shift_vector_n += (
cell_shift_ic[first_at_neightuple_n] - cell_shift_ic[secnd_at_neightuple_n]
)
# Remove all self-pairs that do not cross the cell boundary.
if not self_interaction:
m = torch.logical_not(
torch.logical_and(
first_at_neightuple_n == secnd_at_neightuple_n,
(cell_shift_vector_n == 0).all(dim=1),
)
)
first_at_neightuple_n = first_at_neightuple_n[m]
secnd_at_neightuple_n = secnd_at_neightuple_n[m]
cell_shift_vector_n = cell_shift_vector_n[m]
# For nonperiodic directions, remove any bonds that cross the domain
# boundary.
for c in range(3):
if not pbc[c]:
m = cell_shift_vector_n[:, c] == 0
first_at_neightuple_n = first_at_neightuple_n[m]
secnd_at_neightuple_n = secnd_at_neightuple_n[m]
cell_shift_vector_n = cell_shift_vector_n[m]
# Sort neighbor list.
i = torch.argsort(first_at_neightuple_n)
first_at_neightuple_n = first_at_neightuple_n[i]
secnd_at_neightuple_n = secnd_at_neightuple_n[i]
cell_shift_vector_n = cell_shift_vector_n[i]
# Compute distance vectors.
distance_vector_nc = (
positions[secnd_at_neightuple_n]
- positions[first_at_neightuple_n]
+ cell_shift_vector_n.to(cell.dtype).matmul(cell) # TODO! .T?
)
abs_distance_vector_n = torch.sqrt(
torch.sum(distance_vector_nc * distance_vector_nc, dim=1)
)
# We have still created too many pairs. Only keep those with distance
# smaller than max_cutoff.
mask = abs_distance_vector_n < max_cutoff
first_at_neightuple_n = first_at_neightuple_n[mask]
secnd_at_neightuple_n = secnd_at_neightuple_n[mask]
cell_shift_vector_n = cell_shift_vector_n[mask]
distance_vector_nc = distance_vector_nc[mask]
abs_distance_vector_n = abs_distance_vector_n[mask]
# Assemble return tuple.
retvals = []
for q in quantities:
if q == "i":
retvals += [first_at_neightuple_n]
elif q == "j":
retvals += [secnd_at_neightuple_n]
elif q == "D":
retvals += [distance_vector_nc]
elif q == "d":
retvals += [abs_distance_vector_n]
elif q == "S":
retvals += [cell_shift_vector_n]
else:
raise ValueError("Unsupported quantity specified.")
return retvals
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment