Last active
March 17, 2023 04:48
-
-
Save mjhong0708/db400c8d105ae1bdf76c889c69c3e365 to your computer and use it in GitHub Desktop.
Datapipes for dealing with ASE Atoms with pyg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from collections.abc import Sequence | |
import ase.io | |
import ase.neighborlist | |
import numpy as np | |
import torch | |
import ase.data | |
from torch_geometric.data import Data | |
from torch_geometric.nn import radius_graph | |
from torch.utils.data import IterDataPipe, functional_datapipe | |
@functional_datapipe("read_ase")
class ASEFileReader(IterDataPipe):
    """Datapipe that reads structure files through ``ase.io.iread``.

    Each item of the upstream datapipe must be one of:

    * a path (``str`` or ``os.PathLike``) -- read with the index ``":"``;
    * a two-element sequence ``(path, index)`` -- ``index`` is forwarded to
      ``ase.io.iread`` unchanged.

    Every object produced by ``ase.io.iread`` is yielded individually.

    Raises:
        ValueError: if a sequence item does not have exactly two elements.
        TypeError: if an item is neither a path nor a sequence, or if the
            path element of a pair is not ``str``/``os.PathLike``.
    """

    def __init__(self, dp):
        self.dp = dp

    def __iter__(self):
        for item in self.dp:
            # Bare paths (including pathlib.Path, which is not a Sequence)
            # are handled first so they never fall through unassigned.
            if isinstance(item, (str, os.PathLike)):
                filename, index = item, ":"
            elif isinstance(item, Sequence):
                if len(item) != 2:
                    raise ValueError("Input datapipe must yield str or (str, str), but got {}".format(item))
                filename, index = item
            else:
                # Without this branch `filename`/`index` would be unbound (or
                # silently reused from the previous item).
                raise TypeError("Input datapipe must yield str or (str, str), but got {}".format(type(item)))
            if not isinstance(filename, (str, os.PathLike)):
                raise TypeError("Input datapipe must yield str or os.PathLike, but got {}".format(type(filename)))
            yield from ase.io.iread(filename, index)
@functional_datapipe("atoms_to_graph")
class AtomsGraphParser(IterDataPipe):
    """Datapipe that converts ``ase.Atoms`` objects into PyG ``Data`` objects.

    The produced ``Data`` carries:

    * ``elems``   -- atomic numbers (``torch.long``);
    * ``pos``     -- atomic positions (``default_float_dtype``);
    * ``cell``    -- tensor of shape ``(1, 3, 3)``; the ASE cell when fully
      periodic, all zeros otherwise;
    * ``n_atoms`` -- number of atoms as a ``torch.long`` scalar tensor;
    * ``batch``   -- zeros with the same shape as ``elems``;
    * ``pbc``     -- result of ``atoms.pbc.all()``;
    * ``energy`` / ``force`` -- only when the corresponding ASE getter does
      not raise ``RuntimeError``.
    """

    # Evaluated once, when the class body runs; later changes to torch's
    # default dtype are not picked up automatically.
    default_float_dtype = torch.get_default_dtype()

    def __init__(self, dp):
        self.dp = dp

    @staticmethod
    def _atoms_to_graph(atoms):
        """Build a single ``Data`` object from ``atoms``.

        Raises:
            ValueError: if only some of the periodic directions are enabled.
        """
        float_dtype = AtomsGraphParser.default_float_dtype

        elems = torch.tensor(atoms.numbers, dtype=torch.long)
        # The original built this tensor twice in a row; once is enough.
        pos = torch.tensor(atoms.positions, dtype=float_dtype)

        if atoms.pbc.any() and not atoms.pbc.all():
            raise ValueError("Does not support partial pbc")
        pbc = atoms.pbc.all()
        if pbc:
            cell = torch.tensor(atoms.cell.array, dtype=float_dtype).unsqueeze(0)
        else:
            cell = torch.zeros((1, 3, 3), dtype=float_dtype)
        n_atoms = torch.tensor(len(atoms), dtype=torch.long)

        # Labels are optional: a getter that raises RuntimeError simply means
        # the property is left out of the graph.
        kwargs = {}
        try:
            kwargs["energy"] = torch.tensor(atoms.get_potential_energy(), dtype=float_dtype)
        except RuntimeError:
            pass
        try:
            kwargs["force"] = torch.tensor(atoms.get_forces(), dtype=float_dtype)
        except RuntimeError:
            pass

        batch = torch.zeros_like(elems, dtype=torch.long)
        return Data(
            elems=elems,
            pos=pos,
            cell=cell,
            n_atoms=n_atoms,
            batch=batch,
            pbc=pbc,
            **kwargs,
        )

    def __iter__(self):
        for atoms in self.dp:
            if not isinstance(atoms, ase.Atoms):
                raise TypeError("Input datapipe must yield ase.Atoms, but got {}".format(type(atoms)))
            yield self._atoms_to_graph(atoms)
@functional_datapipe("build_neighbor_list")
class NeighborListBuilder(IterDataPipe):
    """Datapipe that attaches ``edge_index`` and ``edge_shift`` to ``Data``.

    Args:
        dp: upstream datapipe yielding ``torch_geometric.data.Data``.
        cutoff: cutoff passed to the neighbor search.
        self_interaction: forwarded as ``self_interaction`` (ASE backend) or
            ``loop`` (torch_geometric backend).
        backend: ``"ase"`` or ``"torch_geometric"``.
        **kwargs: extra options; only ``max_num_neighbors`` (default 50) is
            read, and only by the torch_geometric backend.

    The incoming ``Data`` object is modified in place and then yielded.

    Raises (while iterating):
        TypeError: if an item is not a ``Data`` instance.
        ValueError: if ``data.cell.shape[0] != 1`` or the backend is unknown.
        NotImplementedError: if the torch_geometric backend is selected and
            ``data.pbc`` is truthy.
    """

    def __init__(self, dp, cutoff, self_interaction=False, backend="ase", **kwargs):
        self.dp = dp
        self.cutoff = cutoff
        self.self_interaction = self_interaction
        self.backend = backend
        self.kwargs = kwargs

    def _build_with_ase(self, data):
        """Compute neighbors with ``ase.neighborlist.primitive_neighbor_list``."""
        pbc = np.array([data.pbc] * 3)
        # detach()/cpu() keep the numpy conversion valid for tensors that
        # track gradients or live on another device.
        pos = data.pos.detach().cpu().numpy().astype(np.float64)
        cell = data.cell.detach().cpu().squeeze().numpy().astype(np.float64)
        elems = data.elems.detach().cpu().numpy().astype(np.int32)
        center_idx, neighbor_idx, offset = ase.neighborlist.primitive_neighbor_list(
            "ijS", pbc, cell, pos, self.cutoff, elems, self_interaction=self.self_interaction
        )
        device = data.pos.device
        # as_tensor with an explicit dtype converts whatever integer width
        # the neighbor search returned.
        center_idx = torch.as_tensor(center_idx, dtype=torch.long, device=device)
        neighbor_idx = torch.as_tensor(neighbor_idx, dtype=torch.long, device=device)
        # Match the dtype of the positions so the shifts can be combined with
        # pos/cell without a dtype mismatch.
        offset = torch.as_tensor(offset, dtype=data.pos.dtype, device=device)
        data.edge_index = torch.stack([neighbor_idx, center_idx], dim=0)
        data.edge_shift = offset
        return data

    def _build_with_torch_geometric(self, data):
        """Compute neighbors with ``radius_graph``; all shifts are zero."""
        data.edge_index = radius_graph(
            data.pos,
            self.cutoff,
            data.batch,
            loop=self.self_interaction,
            max_num_neighbors=self.kwargs.get("max_num_neighbors", 50),
        )
        data.edge_shift = torch.zeros(
            (data.edge_index.size(1), 3), dtype=data.pos.dtype, device=data.pos.device
        )
        return data

    def __iter__(self):
        for data in self.dp:
            if not isinstance(data, Data):
                raise TypeError("Input datapipe must yield torch_geometric.data.Data, but got {}".format(type(data)))
            if data.cell.shape[0] != 1:
                raise ValueError("Does not support batched data")
            if self.backend == "ase":
                transformed = self._build_with_ase(data)
            elif self.backend == "torch_geometric":
                if data.pbc:
                    raise NotImplementedError("torch_geometric does not support periodic boundary condition.")
                transformed = self._build_with_torch_geometric(data)
            else:
                raise ValueError("Unknown backend {}".format(self.backend))
            yield transformed
@functional_datapipe("standardize_property")
class PropertyStandardizer(IterDataPipe):
    """Datapipe that rewrites one property as ``(value - mean) / std``.

    Args:
        dp: upstream datapipe yielding ``torch_geometric.data.Data``.
        mean: value subtracted from the property.
        std: value the shifted property is divided by.
        target: key of the property inside each ``Data`` object.

    The ``Data`` object is updated in place and then yielded.

    Raises (while iterating):
        TypeError: if an item is not a ``Data`` instance.
        ValueError: if ``target`` is not present in the item.
    """

    def __init__(self, dp, mean, std, target):
        self.dp, self.target = dp, target
        self.mean, self.std = mean, std

    def __iter__(self):
        for sample in self.dp:
            if not isinstance(sample, Data):
                raise TypeError("Input datapipe must yield torch_geometric.data.Data, but got {}".format(type(sample)))
            key = self.target
            if key not in sample:
                raise ValueError("Property {} not found in data".format(key))
            shifted = sample[key] - self.mean
            sample[key] = shifted / self.std
            yield sample
@functional_datapipe("unstandardize_property")
class PropertyUnStandardizer(IterDataPipe):
    """Datapipe that rewrites one property as ``value * std + mean``.

    Args:
        dp: upstream datapipe yielding ``torch_geometric.data.Data``.
        mean: value added after scaling.
        std: factor the property is multiplied by.
        target: key of the property inside each ``Data`` object.

    The ``Data`` object is updated in place and then yielded.

    Raises (while iterating):
        TypeError: if an item is not a ``Data`` instance.
        ValueError: if ``target`` is not present in the item.
    """

    def __init__(self, dp, mean, std, target):
        self.dp, self.target = dp, target
        self.mean, self.std = mean, std

    def __iter__(self):
        for sample in self.dp:
            if not isinstance(sample, Data):
                raise TypeError("Input datapipe must yield torch_geometric.data.Data, but got {}".format(type(sample)))
            key = self.target
            if key not in sample:
                raise ValueError("Property {} not found in data".format(key))
            scaled = sample[key] * self.std
            sample[key] = scaled + self.mean
            yield sample
@functional_datapipe("subtract_atomref")
class SubtractAtomref(IterDataPipe):
    """Datapipe that subtracts per-atom reference energies from ``energy``.

    Args:
        dp: upstream datapipe yielding ``torch_geometric.data.Data``.
        atomic_energies: mapping from chemical symbol to the reference energy
            subtracted once for every atom of that element.

    The ``Data`` object is updated in place and then yielded.

    Raises (while iterating):
        TypeError: if an item is not a ``Data`` instance.
        ValueError: if ``data.cell.shape[0] != 1`` or if a reference energy
            is missing for an element present in the item.
    """

    def __init__(self, dp, atomic_energies):
        self.dp = dp
        self.atomic_energies = atomic_energies

    def __iter__(self):
        for data in self.dp:
            if not isinstance(data, Data):
                raise TypeError("Input datapipe must yield torch_geometric.data.Data, but got {}".format(type(data)))
            if data.cell.shape[0] != 1:
                raise ValueError("Does not support batched data")
            # tolist() yields plain ints regardless of the tensor's device.
            symbols = [ase.data.chemical_symbols[number] for number in data.elems.tolist()]
            # Accumulate the full reference first, so a missing element
            # raises before the stored energy has been touched at all.
            reference_energy = 0.0
            for symbol in symbols:
                if symbol not in self.atomic_energies:
                    raise ValueError("Atomic energy for {} not found".format(symbol))
                reference_energy += self.atomic_energies[symbol]
            data["energy"] -= reference_energy
            yield data
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example