Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save alisterburt/3e66afe30e83ae4a1774f0ff623a407e to your computer and use it in GitHub Desktop.
Save alisterburt/3e66afe30e83ae4a1774f0ff623a407e to your computer and use it in GitHub Desktop.
3D correlations from multiple 2D correlations for Will Wan
import einops
import numpy as np
import torch
import mrcfile
import napari
from torch_fourier_slice import project_3d_to_2d
from torch_image_lerp import sample_image_2d
from torch_grid_utils import coordinate_grid
from scipy.spatial.transform import Rotation as R
# read in the reference volume (array axes in the order mrcfile returns them)
volume_file = '/Users/burta2/data/4v6x_bin4.mrc'
# mrcfile keeps the file's own dtype and byte order; torch cannot wrap
# non-native-endian arrays, and the projections below use float32 rotation
# matrices, so convert to native float32 first (a no-op for float32 files
# apart from the copy torch.tensor makes anyway).
reference_volume = torch.tensor(np.asarray(mrcfile.read(volume_file), dtype=np.float32))
# Experiment: make projections of a shifted copy of the reference volume (our
# 'experimental particle'), cross-correlate them with projections of the
# reference, and recover the 3D shift by summing the 2D correlations sampled at
# the projected position of every candidate 3D shift. The whole experiment is
# repeated N_TRIALS times to compare results.
#
# Everything that does not depend on the random shift (reference projections
# and their DFTs, the grid of candidate shifts and where each one lands in every
# tilt image) is computed once, before the loop.

N_TRIALS = 100
MAX_SHIFT_NORM = 30  # reject random shifts longer than this (pixels)
SEARCH_RADIUS = 35  # radius of the sphere of candidate 3D shifts (pixels)

# simulate a -60 to +60 tilt series of the reference
# rotation matrices rotate internal coordinate system (intrinsic rather than extrinsic)
tilt_angles = np.linspace(-60, 60, 41, endpoint=True)
rotation_matrices = R.from_euler(angles=tilt_angles, seq='y', degrees=True).as_matrix()
rotation_matrices = torch.tensor(rotation_matrices).float()
reference_projections = project_3d_to_2d(reference_volume, rotation_matrices=rotation_matrices)
image_shape_2d = tuple(reference_projections.shape[-2:])
image_center_2d = torch.tensor(image_shape_2d) // 2
# vis
# viewer = napari.Viewer()
# viewer.add_image(reference_projections.numpy())
# napari.run()

# DFT of the reference projections, prepared for cross-correlation.
# ifftshift moves the box centre (index n // 2) to the origin for both even and
# odd box sizes (fftshift is one pixel off for odd sizes), so the correlation
# peak for a 2D shift s lands at image_center_2d + s.
reference_projections_centered = torch.fft.ifftshift(reference_projections, dim=(-2, -1))
reference_projections_dft = torch.fft.rfftn(reference_projections_centered, dim=(-2, -1))
# cross-correlation is conj(F(ref)) * F(exp); without the conjugate the product
# is a convolution, whose peak is diffuse for objects that are not centrosymmetric
reference_projections_dft_conj = torch.conj(reference_projections_dft)

# grid of possible 3D shift values, relative to the volume centre
shift_grid = coordinate_grid(
    image_shape=reference_volume.shape,
    center=np.array(reference_volume.shape) // 2,
)
# define a restricted region of valid correlations, here a sphere of radius SEARCH_RADIUS
correlation_mask_3d = torch.linalg.norm(shift_grid, dim=-1) <= SEARCH_RADIUS  # (d, d, d)
# get the 3D shifts for each point within the correlation mask
valid_shifts_zyx = shift_grid[correlation_mask_3d, :]  # (nshifts, 3) array of zyx shifts
valid_shifts_xyz = torch.flip(valid_shifts_zyx, dims=(-1, ))

# project these 3D shifts into 2D
# shifts are (nshifts, 3), rotation matrices are (ntilts, 3, 3)
# result is an (ntilts, nshifts, 2) array of xy shifts
valid_shifts_xyz = einops.rearrange(valid_shifts_xyz, 'nshifts xyz -> nshifts xyz 1')
rotation_matrices_extrinsic = torch.linalg.inv(rotation_matrices)
rotation_matrices_extrinsic = einops.rearrange(rotation_matrices_extrinsic, 'ntilts i j -> ntilts 1 i j')
projection_matrices = rotation_matrices_extrinsic[..., :2, :]
projected_shifts_xy = projection_matrices @ valid_shifts_xyz  # (ntilts, nshifts, 2, 1)
projected_shifts_xy = einops.rearrange(projected_shifts_xy, 'ntilts nshifts xy 1 -> ntilts nshifts xy')
# positions at which each 2D correlation image is sampled,
# remembering to fix xy -> yx for image sampling
sampling_positions_yx_2d = torch.flip(projected_shifts_xy, dims=(-1,)) + image_center_2d
# visualise sample positions relative to image center
# viewer = napari.Viewer()
# points_napari = einops.rearrange(sampling_positions_yx_2d[:, 0:5], 'ntilts nshifts yx -> (ntilts nshifts) yx')
# viewer.add_points(points_napari, size=0.1, face_color='cornflowerblue')
# viewer.add_points(image_center_2d.numpy(), size=0.5, face_color='red')
# napari.run()

# per-tilt weights for combining the 2D correlations (uniform for now)
weights = torch.ones(size=(rotation_matrices.shape[0], 1))

for _ in range(N_TRIALS):
    # random integer shift in array-axis order (the same order as valid_shifts_zyx,
    # so the two can be compared directly). Every component is >= 1 because the
    # slicing below uses [:-d], which would select nothing for d == 0.
    # With components in [1, 9] the norm is at most ~15.6, so the rejection loop
    # is only a guard in case the sampling range is widened.
    true_shift = np.random.randint(low=1, high=10, size=(3, ))
    while np.linalg.norm(true_shift) > MAX_SHIFT_NORM:
        true_shift = np.random.randint(low=1, high=10, size=(3, ))
    print(f'true shift: {true_shift}')
    d0, d1, d2 = true_shift
    experimental_particle_3d = torch.zeros_like(reference_volume)
    experimental_particle_3d[d0:, d1:, d2:] = reference_volume[:-d0, :-d1, :-d2]
    # vis
    # viewer = napari.Viewer(ndisplay=3)
    # viewer.add_image(reference_volume.numpy())
    # viewer.add_image(experimental_particle_3d.numpy())
    # napari.run()

    # simulate experimental projections from experimental particle
    # (these would be extracted from TS in actual processing)
    experimental_projections = project_3d_to_2d(experimental_particle_3d, rotation_matrices=rotation_matrices)

    # cross-correlate 2D images; s= keeps odd-sized images at their original size
    experimental_projections_dft = torch.fft.rfftn(experimental_projections, dim=(-2, -1))
    correlations_2d_dft = reference_projections_dft_conj * experimental_projections_dft
    correlations_2d = torch.fft.irfftn(correlations_2d_dft, s=image_shape_2d, dim=(-2, -1))
    # vis
    # viewer = napari.Viewer()
    # viewer.add_image(reference_projections.numpy())
    # viewer.add_image(experimental_projections.numpy())
    # viewer.add_image(correlations_2d.numpy())
    # viewer.add_points(image_center_2d.numpy(), face_color='red', size=5)
    # napari.run()

    # sample each tilt's 2D correlation at the projected positions of every candidate shift
    correlation_samples = torch.stack(
        [
            sample_image_2d(tilt_correlation_image, coordinates=per_tilt_samples)
            for tilt_correlation_image, per_tilt_samples
            in zip(correlations_2d, sampling_positions_yx_2d)
        ]
    )  # (ntilts, nshifts) array of per-tilt, per-shift correlation values

    # the 3D correlation value is a weighted sum of the 2D correlations over tilts
    per_shift_correlations = einops.reduce(
        weights * correlation_samples, 'ntilts nshifts -> nshifts', reduction='sum'
    )

    # which shift has the max correlation?
    best_shift_idx = torch.argmax(per_shift_correlations)
    best_shift = np.array(valid_shifts_zyx[best_shift_idx])
    print(f'best shift: {best_shift}')
    print(f'difference: {best_shift - true_shift}')
    print('\n')

# NOTE(review): an earlier version of this script showed very diffuse cc peaks and
# a tendency for the recovered z, y shifts to be 1 away from ideal. That version
# multiplied the DFTs without a complex conjugate (a convolution rather than a
# cross-correlation) and centred the references with fftshift instead of
# ifftshift. Both plausibly contributed; check the printed differences again.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment