Skip to content

Instantly share code, notes, and snippets.

@alisterburt
Created February 22, 2024 01:30
Show Gist options
  • Save alisterburt/a7908fe6675b659c2ccd75ba03653417 to your computer and use it in GitHub Desktop.
molmap - slow
import numpy as np
import torch
import einops
import mmdf
from libtilt.grids._patch_grid_utils import patch_grid_centers
from libtilt.grids import coordinate_grid
# Used below as the standard deviation of every atom's Gaussian.
RESOLUTION_ANGSTROMS = 3.4
# (depth, height, width); half of this is added to the centred atom
# coordinates when they are converted to tensors.
VOLUME_SIZE = (512,) * 3
# Multiplies the per-atom standard deviation when deciding which atoms are
# close enough to a subvolume to be evaluated.
N_STDS_TO_INCLUDE = 3
# Upper bound on the number of atoms evaluated at once in the inner loop.
N_ATOMS_PER_BATCH = 8096
# Pick the compute device: Apple's Metal backend (MPS) when it is usable,
# otherwise the CPU, so that `device` is always defined for the code below.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    # Explain why MPS cannot be used before falling back.
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    # Without this assignment the first use of `device` raises NameError.
    device = torch.device("cpu")
# Load the atomic model and shift it so the mean atom position is the origin.
df = mmdf.read('4v6x.cif')
xyz = df[['x', 'y', 'z']].to_numpy()
mean = einops.reduce(xyz, pattern='b xyz -> xyz', reduction='mean')
df[['x', 'y', 'z']] = xyz - mean

# get coords, amplitudes and sdevs
# Coordinates are reordered to (z, y, x) to match the volume's axis order and
# shifted by half the volume size so the model sits in the middle of the
# volume. The dtype is set before the device transfer, as in the original.
volume_center = np.array(VOLUME_SIZE) / 2
atom_zyx = torch.tensor(
    df[['z', 'y', 'x']].to_numpy() + volume_center,
    dtype=torch.float32,
).to(device)
b = len(atom_zyx)
# Every atom gets the same Gaussian width.
atom_stdev = torch.ones(size=(b,), dtype=torch.float32, device=device) * RESOLUTION_ANGSTROMS
# Gaussian height per atom; float32 to match the other per-atom tensors.
atom_amplitude = torch.tensor(
    df['atomic_number'].to_numpy(),
    dtype=torch.float32,
).to(device)

# Output volume. Zero-initialised (not torch.empty) so voxels that are never
# written hold a defined value, and sized from VOLUME_SIZE so it cannot
# disagree with the shift applied to the atom coordinates above.
volume = torch.zeros(size=VOLUME_SIZE, dtype=torch.float32, device=device)

# Centres of the 32-voxel cubic patches tiling the volume, flattened into a
# single list of (z, y, x) positions.
centers = patch_grid_centers(
    image_shape=VOLUME_SIZE,
    patch_shape=(32, 32, 32),
    patch_step=(32, 32, 32),
    device=device,
)
centers = einops.rearrange(centers, pattern='pd ph pw zyx -> (pd ph pw) zyx')
# Render the volume one cubic subvolume at a time: for each subvolume, select
# the atoms close enough to matter, sum their Gaussians in batches, then copy
# the result into `volume`.

# Must match the patch_shape passed to patch_grid_centers.
subvolume_shape = (32, 32, 32)

# (z, y, x) index of every voxel inside one subvolume, shape (d, h, w, 3).
# Identical for every subvolume, so built once outside the loop.
local_zyx = torch.stack(
    torch.meshgrid(
        *(torch.arange(n, dtype=torch.float32, device=device) for n in subvolume_shape),
        indexing='ij',
    ),
    dim=-1,
)

# Atoms farther than this from a subvolume centre are skipped. The full
# subvolume diagonal is a deliberately generous margin (kept from the
# original); it does not depend on the subvolume, so it is computed once.
max_distances = N_STDS_TO_INCLUDE * atom_stdev + float(np.linalg.norm(subvolume_shape))

for subvolume_idx, subvolume_center in enumerate(centers):
    print(subvolume_idx, '/', len(centers))

    # Treat the subvolume as the voxels [center - 16, center + 16) along each
    # axis and express its grid in the same absolute frame as atom_zyx.
    # NOTE(review): assumes centers are integer voxel positions and that each
    # subvolume lies fully inside `volume` - confirm against patch_grid_centers.
    origin = [int(c) - n // 2 for c, n in zip(subvolume_center.tolist(), subvolume_shape)]
    grid = local_zyx + torch.tensor(origin, dtype=torch.float32, device=device)

    # check distances from center
    difference = atom_zyx - subvolume_center
    distance_from_subvolume_center = (difference ** 2).sum(dim=-1) ** 0.5
    idx = distance_from_subvolume_center < max_distances
    idx = einops.rearrange(torch.nonzero(idx), pattern='b 1 -> b')

    n_atoms = len(idx)
    num_batches = (n_atoms + N_ATOMS_PER_BATCH - 1) // N_ATOMS_PER_BATCH
    out = torch.zeros(size=subvolume_shape, dtype=torch.float32, device=device)

    for i in range(num_batches):
        # get atoms in the batch
        start_index = i * N_ATOMS_PER_BATCH
        end_index = min(n_atoms, start_index + N_ATOMS_PER_BATCH)
        _idx = idx[start_index:end_index]

        # get zyx, stdev and amplitude of each gaussian in the batch
        _atom_zyx = einops.rearrange(atom_zyx[_idx], 'b zyx -> b 1 1 1 zyx')
        _atom_stdev = einops.rearrange(atom_stdev[_idx], 'b -> b 1 1 1 1')
        # One fewer trailing axis than the others: the amplitude is applied
        # after the zyx axis has been summed away, to a (b, d, h, w) tensor.
        _atom_amplitude = einops.rearrange(atom_amplitude[_idx], 'b -> b 1 1 1')

        # evaluate gaussian for each atom, centred on that atom's position
        _grid = ((grid - _atom_zyx) ** 2) / (2 * _atom_stdev ** 2)
        _grid = einops.reduce(_grid, pattern='... d h w zyx -> ... d h w', reduction='sum')
        _grid = _atom_amplitude * torch.exp(-1 * _grid)

        # sum the gaussians
        _grid = einops.reduce(_grid, pattern='... d h w -> d h w', reduction='sum')
        out = out + _grid
        print(i, '/', num_batches)

    # Store the finished subvolume; previously `out` was discarded.
    z0, y0, x0 = origin
    d, h, w = subvolume_shape
    volume[z0:z0 + d, y0:y0 + h, x0:x0 + w] = out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment