Skip to content

Instantly share code, notes, and snippets.

@sirmarcel
Last active May 19, 2023 14:57
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sirmarcel/7b7f4ddb617c10db4a26dd7e75f1aea5 to your computer and use it in GitHub Desktop.
Save sirmarcel/7b7f4ddb617c10db4a26dd7e75f1aea5 to your computer and use it in GitHub Desktop.
Toy example for LJ stress variations with SchNetPack
from typing import Dict, Optional, List
import torch
import torch.nn as nn
from ase.neighborlist import neighbor_list
from ase.calculators.lj import LennardJones
from ase.build import bulk
from schnetpack import properties
from schnetpack.data.loader import _atoms_collate_fn
from schnetpack.atomistic import Forces, Strain, PairwiseDistances
from schnetpack.transform import ASENeighborList
from schnetpack.model import AtomisticModel
import schnetpack.nn as snn
class LJ(AtomisticModel):
    """Mirror of ase LJ calculator with unsmooth cutoff.

    For every pair in the neighbour list the shifted Lennard-Jones energy

        u(r) = 4 * epsilon * [(sigma / r)**12 - (sigma / r)**6] - u_unshifted(rc)

    is evaluated with a factor 0.5 (each pair is listed twice, i->j and j->i,
    by ASE's full neighbour list). It is then summed per atom and per
    structure. The stress is obtained in one of four ways, selected by
    ``stress_mode``:

    - ``"virials"``: differentiate the energy w.r.t. the pair vectors Rij and
      divide the summed outer products dE/dRij (x) Rij by the cell volume.
    - ``"strain_rij"``: apply a zero strain directly to Rij (``StrainPairwise``)
      and let ``Forces`` differentiate w.r.t. that strain.
    - ``"strain"``: apply SchNetPack's ``Strain`` before building Rij.
    - ``"strain_compute_offsets"``: like ``"strain"``, but recompute the
      Cartesian offsets from the integer shifts ``inputs["S"]`` and the cell.

    All modes except ``"virials"`` also compute forces through ``Forces``.

    Args:
        epsilon: depth of the LJ well.
        sigma: LJ length scale (the unshifted potential is zero at r = sigma).
        rc: cutoff radius; energies are shifted by their value at ``rc`` but
            not smoothed.
        stress_mode: one of ``STRESS_MODES``.
        **kwargs: forwarded to ``AtomisticModel``.

    Raises:
        ValueError: if ``stress_mode`` is not one of ``STRESS_MODES``.
    """

    STRESS_MODES = ("virials", "strain_rij", "strain", "strain_compute_offsets")

    def __init__(self, epsilon, sigma, rc, stress_mode="virials", **kwargs):
        super().__init__(**kwargs)
        if stress_mode not in self.STRESS_MODES:
            raise ValueError(
                f"unknown stress_mode {stress_mode!r}; "
                f"expected one of {self.STRESS_MODES}"
            )
        # NOTE: plain tensors, not registered buffers, so they do not follow
        # ``.to(device)`` calls on the module.
        self.epsilon = torch.tensor(epsilon)
        self.sigma = torch.tensor(sigma)
        self.rc = torch.tensor(rc)
        self.stress_mode = stress_mode

    def forward(self, inputs):
        """Compute energy and stress (plus forces in the strain modes).

        Args:
            inputs: SchNetPack input dict (e.g. as built by ``convert``).

        Returns:
            The input dict with ``properties.energy`` and ``"stress"`` set
            (and forces, for the non-virial modes).
        """
        inputs = self._prepare_pairs(inputs)

        idx_i = inputs[properties.idx_i]
        idx_m = inputs[properties.idx_m]
        n_atoms = inputs[properties.Z].shape[0]
        n_structures = int(idx_m[-1]) + 1

        u_ij = self._pair_energy(inputs[properties.Rij])
        energy = self._sum_per_structure(u_ij, idx_i, idx_m, n_atoms, n_structures)
        inputs[properties.energy] = energy

        if self.stress_mode == "virials":
            rij = inputs[properties.Rij]
            # grad_outputs=ones gives dE_m/dRij for every pair: each pair
            # belongs to exactly one structure m. This also works for batches > 1.
            (du_drij,) = torch.autograd.grad(
                energy, rij, grad_outputs=torch.ones_like(energy)
            )
            # virials[p, a, b] = dE/dRij[p, a] * Rij[p, b]
            virials = rij.unsqueeze(-2) * du_drij.unsqueeze(-1)
            virials = self._sum_per_structure(
                virials, idx_i, idx_m, n_atoms, n_structures
            )
            inputs["stress"] = virials / self._cell_volume(inputs[properties.cell])
        else:
            inputs = Forces(calc_stress=True, calc_forces=True)(inputs)
        return inputs

    def _prepare_pairs(self, inputs):
        """Build pair vectors Rij, wired into autograd as the stress mode requires."""
        mode = self.stress_mode
        if mode == "virials":
            inputs = PairwiseDistances()(inputs)
            inputs[properties.Rij].requires_grad_()
        elif mode == "strain_rij":
            inputs = PairwiseDistances()(inputs)
            inputs = StrainPairwise()(inputs)
        elif mode == "strain":
            inputs = Strain()(inputs)
            inputs = PairwiseDistances()(inputs)
        elif mode == "strain_compute_offsets":
            inputs = Strain()(inputs)
            # Recompute offsets = S @ cell from the cell left by Strain.
            # The squeeze limits this to a single structure.
            inputs[properties.offsets] = torch.mm(
                inputs["S"], inputs[properties.cell].squeeze()
            )
            inputs = PairwiseDistances()(inputs)
            inputs[properties.Rij].requires_grad_()
        else:
            raise ValueError(
                f"unknown stress_mode {mode!r}; expected one of {self.STRESS_MODES}"
            )
        return inputs

    def _pair_energy(self, vec_ij):
        """Half the shifted LJ energy per pair; (n_pairs, 1) for Rij of shape (n_pairs, 3)."""
        r_ij = torch.norm(vec_ij, dim=1, keepdim=True)
        # 0.5 because every pair appears twice in the full neighbour list.
        prefactor = 0.5 * 4 * self.epsilon
        power_6 = torch.pow(self.sigma / r_ij, 6)
        u_ij = prefactor * (power_6 * power_6 - power_6)
        # Plain energy shift so u(rc) == 0 (no smoothing of the derivative).
        shift = prefactor * (
            torch.pow(self.sigma / self.rc, 12) - torch.pow(self.sigma / self.rc, 6)
        )
        return u_ij - shift

    @staticmethod
    def _sum_per_structure(per_pair, idx_i, idx_m, n_atoms, n_structures):
        """Scatter-add per-pair values onto atoms i, then onto their structures."""
        per_atom = snn.scatter_add(per_pair, idx_i, dim_size=n_atoms)
        return snn.scatter_add(per_atom, idx_m, dim_size=n_structures)

    @staticmethod
    def _cell_volume(cell):
        """Signed volume a . (b x c) of each cell, shaped (n, 1, 1) for broadcasting."""
        return torch.sum(
            cell[:, 0, :] * torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
            dim=1,
            keepdim=True,
        )[:, :, None]
class StrainPairwise(nn.Module):
    """
    This is required to calculate the stress as a response property.
    Adds strain-dependence to relative atomic positions Rij.

    One zero strain tensor per structure (shape (n_structures, 3, 3),
    requiring grad) is stored under ``properties.strain``. Each pair vector
    is then replaced by ``Rij + Rij @ strain_m``, where ``m`` is the structure
    of the pair's atom i. Differentiating the energy w.r.t. the stored strain
    therefore yields the (unnormalised) stress for every structure in the
    batch, including batches with more than one structure.
    """

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Normalise the cell to (n_structures, 3, 3) so zeros_like gives one
        # strain tensor per structure.
        cell = inputs[properties.cell].reshape(-1, 3, 3)
        strain = torch.zeros_like(cell)
        strain.requires_grad_()
        inputs[properties.strain] = strain

        # Structure index of every pair, taken from its atom i.
        pair_structure = inputs[properties.idx_m][inputs[properties.idx_i]]
        rij = inputs[properties.Rij]
        # Row-wise Rij @ strain_m, i.e. the same convention as rij @ strain
        # for a single structure.
        strained = torch.bmm(rij.unsqueeze(1), strain[pair_structure]).squeeze(1)
        inputs[properties.Rij] = rij + strained
        return inputs
def convert(atoms, cutoff):
    """Turn an ASE ``Atoms`` object into a collated SchNetPack batch of one structure.

    Positions, cell and cell shifts are cast to float32. The neighbour list
    is built with ASE's ``neighbor_list`` (no self-interaction) within
    ``cutoff``. The raw float shift vectors are kept under ``"S"`` so a model
    can recompute offsets later. Positions are marked ``requires_grad`` so
    forces can be obtained by autograd.

    Args:
        atoms: the ASE structure to convert.
        cutoff: neighbour-list cutoff radius.

    Returns:
        The dict produced by ``_atoms_collate_fn`` for a one-element batch.
    """
    cell = torch.from_numpy(atoms.get_cell().array).to(dtype=torch.float32)

    sample = {
        properties.n_atoms: torch.tensor([atoms.get_global_number_of_atoms()]),
        properties.Z: torch.from_numpy(atoms.get_atomic_numbers()),
        properties.R: torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
        properties.cell: cell,
        properties.pbc: torch.from_numpy(atoms.get_pbc()),
    }

    # Pair indices (i, j) plus the lattice shift S of each j image.
    first, second, shift_np = neighbor_list("ijS", atoms, cutoff, self_interaction=False)
    shifts = torch.from_numpy(shift_np).to(dtype=torch.float32)

    sample[properties.idx_i] = torch.from_numpy(first)
    sample[properties.idx_j] = torch.from_numpy(second)
    # Cartesian offsets need the 2-D (3, 3) cell, so compute them before
    # the batch axis is added to the cell below.
    sample[properties.offsets] = shifts @ cell
    sample["S"] = shifts
    sample[properties.cell] = cell[None]  # -> (1, 3, 3)

    sample[properties.R].requires_grad_()
    return _atoms_collate_fn([sample])
# Demo: compare the stress from each LJ variant against ASE's LennardJones.
import numpy as np

rc = 8.0
sigma = 3.405
epsilon = 0.010325

np.random.seed(123)

atoms = bulk("Ar", cubic=True) * [4, 4, 4]
# need to get out of equilibrium
# NOTE(review): only the positions are transformed here; the cell is left
# unchanged, so this is an arbitrary perturbation rather than a homogeneous
# strain of the crystal. Both SchNetPack and ASE see the same structure, so
# the comparison below is still like-for-like.
strain = np.random.random((3, 3))
atoms.positions += np.einsum("ab,ib->ia", strain, atoms.positions)

# label -> LJ model; the "rij" label uses the virial-based stress mode.
variations = {
    label: LJ(sigma=sigma, epsilon=epsilon, rc=rc, stress_mode=mode)
    for label, mode in (
        ("strain", "strain"),
        ("strain_compute_offsets", "strain_compute_offsets"),
        ("strain_rij", "strain_rij"),
        ("rij", "virials"),
    )
}

for name, variation in variations.items():
    print(name)
    # Fresh inputs every time: the models write into the dict they receive.
    inputs = convert(atoms, rc)
    stress = variation(inputs)["stress"].detach().numpy()
    print(stress, end="\n\n")

# Reference: ASE's LJ with the same (unsmoothed) cutoff.
atoms.calc = calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, smooth=False)
print("ase")
print(atoms.get_stress(voigt=False))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment