-
-
Save sirmarcel/7b7f4ddb617c10db4a26dd7e75f1aea5 to your computer and use it in GitHub Desktop.
Toy example: computing the Lennard-Jones (LJ) stress in several ways with SchNetPack and comparing the results against ASE
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, Optional, List | |
import torch | |
import torch.nn as nn | |
from ase.neighborlist import neighbor_list | |
from ase.calculators.lj import LennardJones | |
from ase.build import bulk | |
from schnetpack import properties | |
from schnetpack.data.loader import _atoms_collate_fn | |
from schnetpack.atomistic import Forces, Strain, PairwiseDistances | |
from schnetpack.transform import ASENeighborList | |
from schnetpack.model import AtomisticModel | |
import schnetpack.nn as snn | |
class LJ(AtomisticModel):
    """Mirror of ase LJ calculator with unsmooth cutoff.

    The pair energy is ``4 * epsilon * ((sigma / r)**12 - (sigma / r)**6)``
    shifted by its value at ``rc`` so that it vanishes at the cutoff (no
    smoothing, analogous to ``ase.calculators.lj.LennardJones(smooth=False)``).
    The extra factor 0.5 assumes every pair appears in both directions
    (i->j and j->i) in the neighbor list, as produced by ``convert`` in this
    file -- TODO confirm for other input pipelines.

    The stress is obtained in one of four ways, selected by ``stress_mode``:

    * ``"virials"``: differentiate the energy w.r.t. the pair vectors ``Rij``
      and sum ``Rij (x) dE/dRij`` over pairs and atoms, divided by the cell
      volume.
    * ``"strain_rij"``: apply a zero strain directly to ``Rij``
      (``StrainPairwise``) and let SchNetPack's ``Forces`` module
      (``calc_stress=True``) produce the stress.
    * ``"strain"``: apply SchNetPack's ``Strain`` transform before computing
      the pair vectors, then use ``Forces``.
    * ``"strain_compute_offsets"``: like ``"strain"``, but rebuild the
      periodic offsets from the integer cell shifts ``inputs["S"]`` and the
      cell before computing the pair vectors (single structure only).

    Args:
        epsilon: depth of the LJ potential well.
        sigma: LJ length scale.
        rc: cutoff radius used for the energy shift.
        stress_mode: one of ``LJ.STRESS_MODES``.
        **kwargs: forwarded to ``AtomisticModel``.

    Raises:
        ValueError: if ``stress_mode`` is not one of ``LJ.STRESS_MODES``.
    """

    STRESS_MODES = ("virials", "strain_rij", "strain", "strain_compute_offsets")

    def __init__(self, epsilon, sigma, rc, stress_mode="virials", **kwargs):
        # Fail early: an unknown mode used to surface only later, as a
        # KeyError on Rij inside forward().
        if stress_mode not in self.STRESS_MODES:
            raise ValueError(
                f"unknown stress_mode {stress_mode!r}, "
                f"expected one of {self.STRESS_MODES}"
            )
        super().__init__(**kwargs)
        # Non-persistent buffers follow .to(device) / .double() but do not
        # add entries to the state_dict.
        self.register_buffer("epsilon", torch.tensor(epsilon), persistent=False)
        self.register_buffer("sigma", torch.tensor(sigma), persistent=False)
        self.register_buffer("rc", torch.tensor(rc), persistent=False)
        self.stress_mode = stress_mode

    def forward(self, inputs):
        """Compute the LJ energy and the stress for a collated input batch.

        Args:
            inputs: collated SchNetPack input dict (see ``convert``). It is
                modified in place and returned.

        Returns:
            ``inputs`` with ``properties.energy`` and ``"stress"`` set. The
            strain-based modes go through ``Forces(calc_forces=True)`` and
            therefore also get forces.

        Raises:
            ValueError: if ``self.stress_mode`` is not a known mode.
        """
        mode = self.stress_mode
        if mode == "virials":
            inputs = PairwiseDistances()(inputs)
        elif mode == "strain_rij":
            inputs = PairwiseDistances()(inputs)
            inputs = StrainPairwise()(inputs)
        elif mode == "strain":
            inputs = Strain()(inputs)
            inputs = PairwiseDistances()(inputs)
        elif mode == "strain_compute_offsets":
            inputs = Strain()(inputs)
            # Rebuild Cartesian offsets from the integer shifts and the
            # (strained) cell; squeeze() restricts this to one structure.
            inputs[properties.offsets] = torch.mm(
                inputs["S"], inputs[properties.cell].squeeze()
            )
            inputs = PairwiseDistances()(inputs)
        else:
            raise ValueError(
                f"unknown stress_mode {mode!r}, expected one of {self.STRESS_MODES}"
            )
        # Rij must be part of the autograd graph for the "virials" branch
        # below; for the other modes this call is harmless.
        inputs[properties.Rij].requires_grad_()

        vec_ij = inputs[properties.Rij]
        r_ij = torch.norm(vec_ij, dim=1, keepdim=True)
        power_6 = torch.pow(self.sigma / r_ij, 6)
        power_12 = power_6 * power_6
        # Bracket of the bare potential evaluated at rc; subtracting it makes
        # the pair energy vanish at the cutoff.
        shift = torch.pow(self.sigma / self.rc, 12) - torch.pow(self.sigma / self.rc, 6)
        # 0.5: each pair is counted once per direction in the neighbor list.
        u_ij = 0.5 * 4 * self.epsilon * (power_12 - power_6 - shift)

        idx_i = inputs[properties.idx_i]
        idx_m = inputs[properties.idx_m]
        n_atoms = inputs[properties.Z].shape[0]
        maxm = int(idx_m[-1]) + 1
        u_i = snn.scatter_add(u_ij, idx_i, dim_size=n_atoms)  # sum pairs
        energy = snn.scatter_add(u_i, idx_m, dim_size=maxm)  # sum atoms
        inputs[properties.energy] = energy

        if mode == "virials":
            # grad_outputs of ones lets this work for batches with several
            # energies, not only a single-element energy tensor.
            (du_drij,) = torch.autograd.grad(
                energy, vec_ij, grad_outputs=torch.ones_like(energy)
            )
            virials = vec_ij.unsqueeze(-2) * du_drij.unsqueeze(-1)
            virials = snn.scatter_add(virials, idx_i, dim_size=n_atoms)  # sum pairs
            virials = snn.scatter_add(virials, idx_m, dim_size=maxm)  # sum atoms
            cell = inputs[properties.cell]
            # Cell volume as the triple product a . (b x c), one per structure.
            volume = torch.sum(
                cell[:, 0, :] * torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
                dim=1,
                keepdim=True,
            )[:, :, None]
            inputs["stress"] = virials / volume
        else:
            inputs = Forces(calc_stress=True, calc_forces=True)(inputs)
        return inputs
class StrainPairwise(nn.Module):
    """
    This is required to calculate the stress as a response property.
    Adds strain-dependence to relative atomic positions Rij.

    A zero strain tensor with one 3x3 matrix per structure (requiring
    gradients) is stored under ``properties.strain``. Every pair vector is
    deformed as ``Rij' = Rij + Rij @ strain[m]`` (row-vector convention),
    where ``m`` is the structure that atom ``i`` of the pair belongs to.
    Works for batches containing several structures.
    """

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Strain ``inputs[properties.Rij]`` and return the updated ``inputs``.

        Reads ``properties.cell``, ``properties.Rij``, ``properties.idx_i`` and
        ``properties.idx_m``. Assumes row k of Rij belongs to the pair whose
        first atom is ``idx_i[k]`` and that the cell holds 3x3 blocks, one per
        structure -- TODO confirm for non-SchNetPack inputs.
        """
        cell = inputs[properties.cell].reshape(-1, 3, 3)
        strain = torch.zeros_like(cell)
        strain.requires_grad_()
        inputs[properties.strain] = strain

        # For every pair, select the strain matrix of the structure it lives in.
        pair_structure = inputs[properties.idx_m][inputs[properties.idx_i]]
        strain_ij = strain[pair_structure]

        rij = inputs[properties.Rij]
        inputs[properties.Rij] = rij + torch.matmul(
            rij.unsqueeze(1), strain_ij
        ).squeeze(1)
        return inputs
def convert(atoms, cutoff, dtype=torch.float32):
    """Convert an ASE ``Atoms`` object into a collated SchNetPack batch of one.

    The neighbor list is built with ASE's ``neighbor_list`` (quantities
    ``"ijS"``, no self interaction). The integer cell shifts are stored under
    ``"S"`` next to the Cartesian ``offsets = S @ cell``, so that the offsets
    can later be recomputed from a strained cell. Positions are marked as
    requiring gradients.

    Args:
        atoms: ASE ``Atoms`` object to convert.
        cutoff: neighbor-list cutoff radius.
        dtype: floating-point dtype for positions, cell, shifts and offsets.
            Defaults to ``torch.float32``; use ``torch.float64`` for tighter
            agreement with ASE.

    Returns:
        The dict produced by ``_atoms_collate_fn`` for this single structure.
    """
    cell = torch.from_numpy(atoms.get_cell().array).to(dtype=dtype)
    # Key insertion order matters: the collate function iterates over it.
    inputs = {
        properties.n_atoms: torch.tensor([atoms.get_global_number_of_atoms()]),
        properties.Z: torch.from_numpy(atoms.get_atomic_numbers()),
        properties.R: torch.from_numpy(atoms.get_positions()).to(dtype=dtype),
        properties.cell: cell,
        properties.pbc: torch.from_numpy(atoms.get_pbc()),
    }
    idx_i, idx_j, S = neighbor_list("ijS", atoms, cutoff, self_interaction=False)
    shifts = torch.from_numpy(S).to(dtype=dtype)
    inputs[properties.idx_i] = torch.from_numpy(idx_i)
    inputs[properties.idx_j] = torch.from_numpy(idx_j)
    inputs[properties.offsets] = torch.mm(shifts, cell)
    inputs["S"] = shifts
    # Add a leading structure dimension before collation.
    inputs[properties.cell] = cell.unsqueeze(0)
    inputs[properties.R].requires_grad_()
    return _atoms_collate_fn([inputs])
import numpy as np

# Lennard-Jones parameters for argon, presumably in ASE units (Angstrom, eV).
rc = 8.0
sigma = 3.405
epsilon = 0.010325

# 4x4x4 supercell of the cubic argon cell. The positions (not the cell) are
# moved by a random linear map so the structure is out of equilibrium and the
# stress is non-trivial; the fixed seed keeps this reproducible.
np.random.seed(123)
atoms = bulk("Ar", cubic=True) * [4, 4, 4]
strain = np.random.random((3, 3))
atoms.positions += np.einsum("ab,ib->ia", strain, atoms.positions)

# One model per stress variant, keyed by the label that gets printed.
variations = {
    label: LJ(sigma=sigma, epsilon=epsilon, rc=rc, stress_mode=mode)
    for label, mode in (
        ("strain", "strain"),
        ("strain_compute_offsets", "strain_compute_offsets"),
        ("strain_rij", "strain_rij"),
        ("rij", "virials"),
    )
}

for name, variation in variations.items():
    print(name)
    # Fresh inputs for every variant: the forward pass writes into the dict.
    inputs = convert(atoms, rc)
    stress = variation(inputs)["stress"].detach().numpy()
    print(stress, end="\n\n")

# Reference value from ASE's own calculator with the same unsmoothed cutoff.
calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, smooth=False)
atoms.calc = calc
print("ase")
print(atoms.get_stress(voigt=False))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.