# molmap - slow
#
# Source: GitHub gist alisterburt/a7908fe6675b659c2ccd75ba03653417
# Created: February 22, 2024 01:30
#
# NOTE: GitHub flagged this file as containing bidirectional Unicode text that
# may be interpreted or compiled differently than it appears. Open it in an
# editor that reveals hidden Unicode characters before running it.
import einops
import mmdf
import numpy as np
import torch
from libtilt.grids import coordinate_grid
from libtilt.grids._patch_grid_utils import patch_grid_centers
# Standard deviation assigned to every atom's Gaussian.
# NOTE(review): named in Angstroms but used directly in voxel-index units
# further down, i.e. an implicit 1 Angstrom/voxel sampling - confirm.
RESOLUTION_ANGSTROMS = 3.4

# Shape (depth, height, width) of the simulated density volume, in voxels.
VOLUME_SIZE = (512, 512, 512)

# Atoms further than this many standard deviations (plus a patch-size margin)
# from a patch centre are skipped when that patch is evaluated.
N_STDS_TO_INCLUDE = 3

# Maximum number of atoms evaluated simultaneously; bounds peak memory use.
N_ATOMS_PER_BATCH = 8096
# Pick the compute device: Apple's Metal (MPS) backend when it is usable,
# otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    # Explain why MPS cannot be used before falling back.
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    # Without a fallback `device` would be undefined here and the first
    # `.to(device)` below would raise NameError.
    device = torch.device("cpu")
# Read the atomic model and shift it so its centroid sits at the origin.
df = mmdf.read('4v6x.cif')
xyz = df[['x', 'y', 'z']].to_numpy()
mean = einops.reduce(xyz, pattern='b xyz -> xyz', reduction='mean')
df[['x', 'y', 'z']] = xyz - mean

# get coords, amplitudes and sdevs
# Columns are reordered to zyx to match the (depth, height, width) axis order
# of the volume, then offset so the centroid lands at the volume centre.
# NOTE(review): model coordinates are used directly as voxel indices, which
# presumes a sampling of one coordinate unit per voxel - confirm.
atom_zyx = torch.tensor(
    df[['z', 'y', 'x']].to_numpy() + np.array(VOLUME_SIZE) / 2,
    dtype=torch.float32,
).to(device)
b = len(atom_zyx)

# Every atom gets the same Gaussian width.
atom_stdev = torch.full(
    size=(b,),
    fill_value=RESOLUTION_ANGSTROMS,
    dtype=torch.float32,
    device=device,
)

# Each Gaussian is weighted by the atom's atomic number. Stored as float32 so
# it has the same dtype as the other per-atom tensors and the density map.
atom_amplitude = torch.tensor(
    df['atomic_number'].to_numpy(),
    dtype=torch.float32,
).to(device)
# Rasterise the atoms into `volume` as a sum of isotropic Gaussians, one cubic
# patch at a time so that only nearby atoms are evaluated for each patch.

# Shape (depth, height, width) of the patches the volume is processed in.
PATCH_SHAPE = (32, 32, 32)

# Output density map. Zero-initialised so that any voxel not covered by a
# patch holds a defined value rather than uninitialised memory.
volume = torch.zeros(size=VOLUME_SIZE, dtype=torch.float32, device=device)

# Patch centres for a step equal to the patch size, flattened to one zyx
# position per patch.
centers = patch_grid_centers(
    image_shape=VOLUME_SIZE,
    patch_shape=PATCH_SHAPE,
    patch_step=PATCH_SHAPE,
    device=device,
)
centers = einops.rearrange(centers, pattern='pd ph pw zyx -> (pd ph pw) zyx')
n_patches = len(centers)

# Voxel offsets of a patch relative to its centre; identical for every patch,
# so built once.
# NOTE(review): assumes coordinate_grid returns index coordinates minus
# `center` (offsets -n//2 .. n - n//2 - 1 here) - confirm against libtilt.
patch_half = torch.tensor(
    [n // 2 for n in PATCH_SHAPE], dtype=torch.float32, device=device
)
local_grid = coordinate_grid(
    image_shape=PATCH_SHAPE, center=patch_half, device=device
)

# Per-atom culling radius, measured from a patch centre. It does not change
# between patches, so it is computed once.
# NOTE(review): the full patch diagonal is a conservative margin; half the
# diagonal already reaches every voxel of the patch.
patch_diagonal = float(np.linalg.norm(PATCH_SHAPE))
max_distances = N_STDS_TO_INCLUDE * atom_stdev + patch_diagonal

for subvolume_idx, subvolume_center in enumerate(centers):
    print(subvolume_idx, '/', n_patches)

    # Voxel coordinates of this patch in the same frame as `atom_zyx`.
    grid = local_grid + subvolume_center

    # check distances from center
    difference = atom_zyx - subvolume_center
    distance_from_subvolume_center = (difference ** 2).sum(dim=-1) ** 0.5
    is_near = distance_from_subvolume_center < max_distances
    idx = torch.nonzero(is_near).squeeze(-1)
    n_atoms = len(idx)
    num_batches = (n_atoms + N_ATOMS_PER_BATCH - 1) // N_ATOMS_PER_BATCH

    # Density accumulated for this patch; stays all-zero when no atom is near.
    out = torch.zeros(size=PATCH_SHAPE, dtype=torch.float32, device=device)
    for i in range(num_batches):
        # get atoms in the batch (slicing clamps at the end of `idx`)
        start_index = i * N_ATOMS_PER_BATCH
        _idx = idx[start_index:start_index + N_ATOMS_PER_BATCH]

        # get zyx, stdev and amplitude of each gaussian in the batch, shaped
        # to broadcast against the (d, h, w, zyx) grid
        _atom_zyx = einops.rearrange(atom_zyx[_idx], 'b zyx -> b 1 1 1 zyx')
        _atom_stdev = einops.rearrange(atom_stdev[_idx], 'b -> b 1 1 1')
        _atom_amplitude = einops.rearrange(atom_amplitude[_idx], 'b -> b 1 1 1')

        # evaluate gaussian for each atom: the squared distance must be taken
        # from each voxel to the atom, otherwise atom positions have no effect
        # NOTE: the (b, d, h, w, zyx) intermediate is about 3.2 GB in float32
        # for a full batch of 8096 atoms and a 32^3 patch.
        squared_distance = einops.reduce(
            (grid - _atom_zyx) ** 2,
            pattern='b d h w zyx -> b d h w',
            reduction='sum',
        )
        gaussians = _atom_amplitude * torch.exp(
            -squared_distance / (2 * _atom_stdev ** 2)
        )

        # sum the gaussians
        out = out + einops.reduce(
            gaussians, pattern='b d h w -> d h w', reduction='sum'
        )
        print(i, '/', num_batches)

    # Store the finished patch; previously `out` was dropped every iteration
    # and `volume` was never written.
    # NOTE(review): assumes each centre is patch_start + patch_size // 2 and
    # that every patch lies fully inside the volume - confirm against
    # patch_grid_centers.
    z0, y0, x0 = (
        int(round(v)) for v in (subvolume_center - patch_half).tolist()
    )
    depth, height, width = PATCH_SHAPE
    volume[z0:z0 + depth, y0:y0 + height, x0:x0 + width] = out
# (GitHub page footer - sign-up / sign-in prompt for commenting on the gist;
# not part of the script.)