Skip to content

Instantly share code, notes, and snippets.

@mjhong0708
Last active March 17, 2023 04:48
Show Gist options
  • Save mjhong0708/db400c8d105ae1bdf76c889c69c3e365 to your computer and use it in GitHub Desktop.
Save mjhong0708/db400c8d105ae1bdf76c889c69c3e365 to your computer and use it in GitHub Desktop.
Datapipes for handling ASE Atoms objects with PyTorch Geometric (PyG)
import copy
import os
from collections.abc import Sequence

import ase.data
import ase.io
import ase.neighborlist
import numpy as np
import torch
from torch.utils.data import IterDataPipe, functional_datapipe
from torch_geometric.data import Data
from torch_geometric.nn import radius_graph
@functional_datapipe("read_ase")
class ASEFileReader(IterDataPipe):
    """Read structures with ASE and yield them one ``ase.Atoms`` at a time.

    Each item of the upstream datapipe is either a filename (``str`` or
    ``os.PathLike``), in which case every frame is read (index ``":"``), or a
    ``(filename, index)`` pair whose ``index`` is passed unchanged to
    :func:`ase.io.iread` (e.g. ``":10"``).

    Args:
        dp: Upstream datapipe yielding filenames or ``(filename, index)`` pairs.

    Raises:
        ValueError: If a sequence item does not have exactly two elements.
        TypeError: If an item, or the filename inside a pair, has an
            unsupported type.
    """

    def __init__(self, dp):
        self.dp = dp

    def __iter__(self):
        for item in self.dp:
            # Test str/PathLike first: a str is itself a Sequence.
            if isinstance(item, (str, os.PathLike)):
                filename, index = item, ":"
            elif isinstance(item, Sequence):
                if len(item) != 2:
                    raise ValueError("Input datapipe must yield str or (str, str), but got {}".format(item))
                filename, index = item
            else:
                # Any other type cannot name a file; fail with a clear message
                # instead of reaching the code below with `filename` unbound.
                raise TypeError(
                    "Input datapipe must yield str, os.PathLike or (str, str), but got {}".format(type(item))
                )
            if not isinstance(filename, (str, os.PathLike)):
                raise TypeError("Input datapipe must yield str or os.PathLike, but got {}".format(type(filename)))
            yield from ase.io.iread(filename, index)
@functional_datapipe("atoms_to_graph")
class AtomsGraphParser(IterDataPipe):
    """Convert each ``ase.Atoms`` from the upstream datapipe into a PyG ``Data``.

    The resulting ``Data`` carries:

    * ``elems`` -- atomic numbers (``atoms.numbers``) as ``torch.long``;
    * ``pos`` -- ``atoms.positions``;
    * ``cell`` -- ``atoms.cell.array`` with a leading axis of size 1, or a
      ``(1, 3, 3)`` zero tensor for non-periodic structures;
    * ``n_atoms`` -- 0-d ``torch.long`` tensor holding ``len(atoms)``;
    * ``batch`` -- ``torch.long`` zeros, one per element of ``elems``;
    * ``pbc`` -- ``True`` if periodic along all axes, ``False`` if along none;
    * ``energy`` / ``force`` -- only when the attached calculator provides them.

    Floating-point tensors use ``default_float_dtype``, i.e. torch's default
    dtype at the moment this class was defined.

    Raises:
        TypeError: If the upstream datapipe yields something other than ``ase.Atoms``.
        ValueError: If a structure is periodic along only some axes.
    """

    default_float_dtype = torch.get_default_dtype()

    def __init__(self, dp):
        self.dp = dp

    @staticmethod
    def _atoms_to_graph(atoms):
        """Build a single-structure ``Data`` graph from one ``ase.Atoms``."""
        dtype = AtomsGraphParser.default_float_dtype
        if atoms.pbc.any() and not atoms.pbc.all():
            raise ValueError("Does not support partial pbc")
        # Plain Python bool instead of numpy.bool_ so the attribute is a simple scalar.
        pbc = bool(atoms.pbc.all())

        elems = torch.tensor(atoms.numbers, dtype=torch.long)
        pos = torch.tensor(atoms.positions, dtype=dtype)
        if pbc:
            cell = torch.tensor(atoms.cell.array, dtype=dtype).unsqueeze(0)
        else:
            cell = torch.zeros((1, 3, 3), dtype=dtype)
        n_atoms = torch.tensor(len(atoms), dtype=torch.long)

        # Optional labels. ASE raises RuntimeError when no calculator is attached,
        # and PropertyNotImplementedError (a NotImplementedError subclass) when the
        # calculator lacks the property, e.g. a file with energies but no forces.
        kwargs = {}
        try:
            kwargs["energy"] = torch.tensor(atoms.get_potential_energy(), dtype=dtype)
        except (RuntimeError, NotImplementedError):
            pass
        try:
            kwargs["force"] = torch.tensor(atoms.get_forces(), dtype=dtype)
        except (RuntimeError, NotImplementedError):
            pass

        # Single structure: every atom is assigned to graph index 0.
        batch = torch.zeros_like(elems, dtype=torch.long)
        return Data(
            elems=elems,
            pos=pos,
            cell=cell,
            n_atoms=n_atoms,
            batch=batch,
            pbc=pbc,
            **kwargs,
        )

    def __iter__(self):
        for atoms in self.dp:
            if not isinstance(atoms, ase.Atoms):
                raise TypeError("Input datapipe must yield ase.Atoms, but got {}".format(type(atoms)))
            yield self._atoms_to_graph(atoms)
@functional_datapipe("build_neighbor_list")
class NeighborListBuilder(IterDataPipe):
    """Attach a radius-cutoff neighbor list to each unbatched PyG ``Data``.

    Sets ``edge_index`` (row 0: neighbor index, row 1: center index) and
    ``edge_shift`` on a shallow copy of every incoming ``Data``. The copy
    means the upstream object (e.g. one held by a cache) is never modified.

    ``edge_shift`` is the ``S`` (cell shift) output of ASE's neighbor list
    for the ``"ase"`` backend, and all zeros for the ``"torch_geometric"``
    backend. It uses the same dtype and device as ``pos``.

    Args:
        dp: Upstream datapipe yielding unbatched ``torch_geometric.data.Data``.
        cutoff: Neighbor cutoff radius.
        self_interaction: Whether an atom may be listed as its own neighbor.
        backend: ``"ase"`` (supports periodic structures) or
            ``"torch_geometric"`` (non-periodic structures only).
        **kwargs: Backend options; ``max_num_neighbors`` (default 50) is read
            by the ``"torch_geometric"`` backend.

    Raises:
        TypeError: If an item is not a ``Data`` object.
        ValueError: If an item is batched or ``backend`` is unknown.
        NotImplementedError: If a periodic item reaches the torch_geometric backend.
    """

    def __init__(self, dp, cutoff, self_interaction=False, backend="ase", **kwargs):
        self.dp = dp
        self.cutoff = cutoff
        self.self_interaction = self_interaction
        self.backend = backend
        self.kwargs = kwargs

    def _build_with_ase(self, data):
        """Set ``edge_index``/``edge_shift`` via ``ase.neighborlist.primitive_neighbor_list``."""
        pbc = np.array([data.pbc] * 3)
        # ASE works on NumPy arrays: detach and move to CPU so tensors that
        # require grad or live on an accelerator are accepted too.
        pos = data.pos.detach().cpu().numpy().astype(np.float64)
        cell = data.cell.squeeze().detach().cpu().numpy().astype(np.float64)
        elems = data.elems.detach().cpu().numpy().astype(np.int32)
        center_idx, neighbor_idx, offset = ase.neighborlist.primitive_neighbor_list(
            "ijS", pbc, cell, pos, self.cutoff, elems, self_interaction=self.self_interaction
        )
        device = data.pos.device
        center_idx = torch.as_tensor(center_idx, dtype=torch.long, device=device)
        neighbor_idx = torch.as_tensor(neighbor_idx, dtype=torch.long, device=device)
        data.edge_index = torch.stack([neighbor_idx, center_idx], dim=0)
        # Match pos/cell precision (previously hard-coded float32) so shift/cell
        # arithmetic downstream does not mix dtypes when the default is float64.
        data.edge_shift = torch.as_tensor(offset, dtype=data.pos.dtype, device=device)
        return data

    def _build_with_torch_geometric(self, data):
        """Set ``edge_index`` via PyG's ``radius_graph`` and zero ``edge_shift``."""
        data.edge_index = radius_graph(
            data.pos,
            self.cutoff,
            data.batch,
            loop=self.self_interaction,
            max_num_neighbors=self.kwargs.get("max_num_neighbors", 50),
        )
        data.edge_shift = torch.zeros(
            (data.edge_index.size(1), 3), dtype=data.pos.dtype, device=data.pos.device
        )
        return data

    def __iter__(self):
        for data in self.dp:
            if not isinstance(data, Data):
                raise TypeError("Input datapipe must yield torch_geometric.data.Data, but got {}".format(type(data)))
            if data.cell.shape[0] != 1:
                raise ValueError("Does not support batched data")
            # New attributes go on a shallow copy; the upstream object stays intact.
            data = copy.copy(data)
            if self.backend == "ase":
                transformed = self._build_with_ase(data)
            elif self.backend == "torch_geometric":
                if data.pbc:
                    raise NotImplementedError("torch_geometric does not support periodic boundary condition.")
                transformed = self._build_with_torch_geometric(data)
            else:
                raise ValueError("Unknown backend {}".format(self.backend))
            yield transformed
@functional_datapipe("standardize_property")
class PropertyStandardizer(IterDataPipe):
    """Standardize one property of every graph as ``(value - mean) / std``.

    The result is written to a shallow copy of each incoming ``Data``, and the
    upstream object is left unchanged. This matters when an upstream datapipe
    caches its outputs (e.g. ``in_memory_cache``). Writing into the cached
    object would re-apply the transform on every pass over the data.

    Args:
        dp: Upstream datapipe yielding ``torch_geometric.data.Data``.
        mean: Value subtracted from the property.
        std: Value the centered property is divided by.
        target: Name of the property, e.g. ``"energy"``.

    Raises:
        TypeError: If an item is not a ``Data`` object.
        ValueError: If an item has no ``target`` property.
    """

    def __init__(self, dp, mean, std, target):
        self.dp = dp
        self.mean = mean
        self.std = std
        self.target = target

    def __iter__(self):
        for data in self.dp:
            if not isinstance(data, Data):
                raise TypeError("Input datapipe must yield torch_geometric.data.Data, but got {}".format(type(data)))
            if self.target not in data:
                raise ValueError("Property {} not found in data".format(self.target))
            # Shallow copy so the assignment below does not alter the upstream object.
            data = copy.copy(data)
            data[self.target] = (data[self.target] - self.mean) / self.std
            yield data
@functional_datapipe("unstandardize_property")
class PropertyUnStandardizer(IterDataPipe):
    """Undo standardization of one property: ``value * std + mean``.

    This is the inverse of ``standardize_property`` for the same ``mean``/``std``.
    The result is written to a shallow copy of each incoming ``Data``. An
    upstream (possibly cached) object is therefore never modified, and the
    transform is not re-applied on later passes.

    Args:
        dp: Upstream datapipe yielding ``torch_geometric.data.Data``.
        mean: Value added back after scaling.
        std: Factor the property is multiplied by.
        target: Name of the property, e.g. ``"energy"``.

    Raises:
        TypeError: If an item is not a ``Data`` object.
        ValueError: If an item has no ``target`` property.
    """

    def __init__(self, dp, mean, std, target):
        self.dp = dp
        self.mean = mean
        self.std = std
        self.target = target

    def __iter__(self):
        for data in self.dp:
            if not isinstance(data, Data):
                raise TypeError("Input datapipe must yield torch_geometric.data.Data, but got {}".format(type(data)))
            if self.target not in data:
                raise ValueError("Property {} not found in data".format(self.target))
            # Shallow copy so the assignment below does not alter the upstream object.
            data = copy.copy(data)
            data[self.target] = data[self.target] * self.std + self.mean
            yield data
@functional_datapipe("subtract_atomref")
class SubtractAtomref(IterDataPipe):
    """Subtract per-element reference energies from each graph's ``energy``.

    The entries of ``atomic_energies`` for every atom's chemical symbol are
    summed, and the total is subtracted from ``energy``. Every element is
    validated before anything is changed. The result goes to a shallow copy
    of each ``Data``, so neither the upstream object nor its tensors are
    modified. Without the copy, a cached upstream would accumulate the
    subtraction on every pass.

    Args:
        dp: Upstream datapipe yielding unbatched ``torch_geometric.data.Data``.
        atomic_energies: Mapping from chemical symbol (e.g. ``"H"``) to its
            reference energy, in the same unit as ``energy``.

    Raises:
        TypeError: If an item is not a ``Data`` object.
        ValueError: If an item is batched, has no ``energy``, or contains an
            element missing from ``atomic_energies``.
    """

    def __init__(self, dp, atomic_energies):
        self.dp = dp
        self.atomic_energies = atomic_energies

    def __iter__(self):
        for data in self.dp:
            if not isinstance(data, Data):
                raise TypeError("Input datapipe must yield torch_geometric.data.Data, but got {}".format(type(data)))
            if data.cell.shape[0] != 1:
                raise ValueError("Does not support batched data")
            if "energy" not in data:
                raise ValueError("Property {} not found in data".format("energy"))
            symbols = [ase.data.chemical_symbols[z] for z in data.elems.tolist()]
            # Validate all elements first so a missing reference never leaves a
            # partially corrected energy behind.
            for symbol in symbols:
                if symbol not in self.atomic_energies:
                    raise ValueError("Atomic energy for {} not found".format(symbol))
            atomref = sum(self.atomic_energies[symbol] for symbol in symbols)
            data = copy.copy(data)
            energy = data["energy"]
            # Out-of-place subtraction (`-=` would modify the tensor shared with the
            # upstream object); the cast keeps energy's original dtype and device.
            data["energy"] = energy - torch.as_tensor(atomref, dtype=energy.dtype, device=energy.device)
            yield data
@mjhong0708
Copy link
Author

Example

from torchdata.datapipes.iter import IterableWrapper

dp = (
    IterableWrapper(["path/to/data"]) # paths to files
    .zip(IterableWrapper([":10"])) # index string for each file (passed to ase.io.iread)
    .read_ase() # read frames lazily with ase.io.iread
    .atoms_to_graph() # construct graph as torch_geometric.data.Data
    .build_neighbor_list(5.0, backend="torch_geometric") # construct neighbor graph
    .in_memory_cache() # cache results
    .subtract_atomref({"H": -0.500273, "C": -37.846772, "N": -54.583863, "O": -75.064579})
    .standardize_property(0.0, 1.0, "energy")
)

# As iterator
for data in dp:
    print(data)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment