Skip to content

Instantly share code, notes, and snippets.

Last active March 28, 2023 19:32
Show Gist options
  • Save alisterburt/88133c823c70ef6f7375ff7db15320ad to your computer and use it in GitHub Desktop.
Save alisterburt/88133c823c70ef6f7375ff7db15320ad to your computer and use it in GitHub Desktop.
Rotational average 3D for Pranav
from pathlib import Path
from typing import List, Sequence, Tuple
import einops
import mrcfile
import numpy as np
import torch
import typer
# Typer application object for the `raps_3d` command-line interface;
# shows help when invoked without arguments and disables shell completion.
cli = typer.Typer(name='raps_3d', no_args_is_help=True, add_completion=False)
def rfft_shape_from_signal_shape(input_shape: Sequence[int]) -> Tuple[int, ...]:
    """Get the output shape of an rfft on an input with input_shape.

    Only the last dimension changes: it becomes ``n // 2 + 1``. All other
    dimensions are passed through untouched.

    Parameters
    ----------
    input_shape
        Shape of the real-valued signal.

    Returns
    -------
    Tuple[int, ...]
        ``input_shape`` with its last entry replaced by ``n // 2 + 1``.
    """
    rfft_shape = list(input_shape)
    # Floor division keeps the arithmetic exact for arbitrarily large sizes
    # (true division would route the value through a float).
    rfft_shape[-1] = int(rfft_shape[-1] // 2 + 1)
    return tuple(rfft_shape)
def fft_center(
    grid_shape: Tuple[int, ...], fftshifted: bool, rfft: bool
) -> torch.Tensor:
    """Return the indices of the DFT center (the DC component).

    Parameters
    ----------
    grid_shape
        Shape of the grid, one entry per dimension.
    fftshifted
        If truthy, the center of each dimension is ``size // 2``;
        otherwise it is index 0 in every dimension.
    rfft
        If truthy, the center along the last dimension is forced to 0.

    Returns
    -------
    torch.Tensor
        1D float32 tensor with one center index per dimension.
    """
    shape = [int(size) for size in grid_shape]
    if fftshifted:
        center = [size // 2 for size in shape]
    else:
        center = [0] * len(shape)
    if rfft:
        # Converting to the rfft shape would only alter the last dimension,
        # and the center along that dimension is always 0 for an rfft, so
        # overwriting the last entry is all that is needed.
        center[-1] = 0
    return torch.tensor(center, dtype=torch.float32)
def _indices_centered_on_dc_for_shifted_rfft(
    rfft_shape: Sequence[int]
) -> torch.Tensor:
    """Return per-element array indices relative to DC for a shifted rfft.

    DC is taken to sit at ``size // 2`` along every dimension except the
    last, where it sits at index 0.

    Parameters
    ----------
    rfft_shape
        Shape of the rfft array itself (not of the original signal).

    Returns
    -------
    torch.Tensor
        Integer tensor of shape ``(*rfft_shape, ndim)``; the last axis holds
        the offset from DC along each dimension.
    """
    # Plain Python ints keep np.indices independent of how torch tensors
    # happen to convert to array dimensions.
    shape = tuple(int(size) for size in rfft_shape)
    dc_idx = torch.div(torch.tensor(shape), 2, rounding_mode='floor')
    dc_idx[-1] = 0
    indices = torch.tensor(np.indices(shape))  # (c, (d), h, w)
    indices = einops.rearrange(indices, 'c ... -> ... c')
    return indices - dc_idx
def _distance_from_dc_for_shifted_rfft(rfft_shape: Sequence[int]) -> torch.Tensor:
    """Return the Euclidean distance from DC for each element of a shifted rfft.

    Parameters
    ----------
    rfft_shape
        Shape of the rfft array itself.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``rfft_shape`` holding the distance (in index units)
        of every element from the DC component.
    """
    centered_indices = _indices_centered_on_dc_for_shifted_rfft(rfft_shape)
    squared_distance = einops.reduce(
        centered_indices ** 2, '... c -> ...', reduction='sum'
    )
    return squared_distance ** 0.5
def _indices_centered_on_dc_for_shifted_dft(
    dft_shape: Sequence[int], rfft: bool
) -> torch.Tensor:
    """Return per-element array indices relative to DC for a shifted DFT.

    Parameters
    ----------
    dft_shape
        Shape of the DFT array.
    rfft
        If truthy, delegate to the rfft variant, which places DC at index 0
        along the last dimension.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(*dft_shape, ndim)`` whose last axis holds the
        offset from DC along each dimension. The full-DFT branch returns
        floats; the rfft branch returns whatever the rfft helper produces.
    """
    if rfft:
        return _indices_centered_on_dc_for_shifted_rfft(dft_shape)
    shape = tuple(int(size) for size in dft_shape)
    dft_indices = torch.tensor(np.indices(shape)).float()
    dft_indices = einops.rearrange(dft_indices, 'c ... -> ... c')
    dc_idx = fft_center(shape, fftshifted=True, rfft=False)
    return dft_indices - dc_idx
def _distance_from_dc_for_shifted_dft(
    dft_shape: Sequence[int], rfft: bool
) -> torch.Tensor:
    """Return the Euclidean distance from DC for each element of a shifted DFT.

    Parameters
    ----------
    dft_shape
        Shape of the DFT array.
    rfft
        Whether the array is laid out as an rfft (DC at index 0 along the
        last dimension).

    Returns
    -------
    torch.Tensor
        Tensor of shape ``dft_shape`` holding each element's distance from
        DC in index units.
    """
    idx = _indices_centered_on_dc_for_shifted_dft(dft_shape, rfft=rfft)
    squared_distance = einops.reduce(idx ** 2, '... c -> ...', reduction='sum')
    return squared_distance ** 0.5
def indices_centered_on_dc_for_dft(
    dft_shape: Sequence[int], rfft: bool, fftshifted: bool
) -> torch.Tensor:
    """Return per-element array indices relative to DC for a DFT.

    Parameters
    ----------
    dft_shape
        Shape of the DFT array.
    rfft
        Whether the array is laid out as an rfft.
    fftshifted
        Whether the array is fftshifted. If falsy, the centered index grid
        is ifftshifted so that it lines up with an unshifted DFT.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(*dft_shape, ndim)`` whose last axis holds the
        offset from DC along each dimension.
    """
    dft_indices = _indices_centered_on_dc_for_shifted_dft(dft_shape, rfft=rfft)
    if fftshifted:
        return dft_indices
    # The grid is laid out as (*dft_shape, ndim), so the spatial dimensions
    # are the leading ones; plain ints are used as the shift dims.
    dims_to_shift = tuple(range(len(dft_shape)))
    if rfft:
        # The last spatial dimension of an rfft is never shifted.
        dims_to_shift = dims_to_shift[:-1]
    return torch.fft.ifftshift(dft_indices, dim=dims_to_shift)
def distance_from_dc_for_dft(
    dft_shape: Sequence[int], rfft: bool, fftshifted: bool
) -> torch.Tensor:
    """Return the Euclidean distance from DC for each element of a DFT.

    Parameters
    ----------
    dft_shape
        Shape of the DFT array.
    rfft
        Whether the array is laid out as an rfft.
    fftshifted
        Whether the array is fftshifted.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``dft_shape`` holding each element's distance from
        DC in index units.
    """
    idx = indices_centered_on_dc_for_dft(
        dft_shape, rfft=rfft, fftshifted=fftshifted
    )
    squared_distance = einops.reduce(idx ** 2, '... c -> ...', reduction='sum')
    return squared_distance ** 0.5
def _find_shell_indices_1d(
distances: torch.Tensor, n_shells: int
) -> List[torch.Tensor]:
"""Find indices into a vector of distances for shells 1 unit apart."""
sorted, sort_idx = torch.sort(distances, descending=False)
split_points = torch.linspace(start=0.5, end=n_shells - 0.5, steps=n_shells)
split_idx = torch.searchsorted(sorted, split_points)
return torch.tensor_split(sort_idx, split_idx)[:-1]
def _split_into_shells_3d(
    image: torch.Tensor, n_shells: int, rfft: bool = False, fftshifted: bool = True
) -> List[torch.Tensor]:
    """Split the last three dimensions of ``image`` into shells around DC.

    Parameters
    ----------
    image
        Tensor whose last three dimensions ``(d, h, w)`` are a 3D DFT;
        any leading dimensions are kept.
    n_shells
        Number of shells to extract.
    rfft
        Whether ``(d, h, w)`` should be treated as an rfft layout. The shape
        of ``image`` itself is what gets passed on as the DFT shape.
    fftshifted
        Whether the DFT is fftshifted.

    Returns
    -------
    List[torch.Tensor]
        One tensor per shell of shape ``(..., n_elements_in_shell)``.
    """
    d, h, w = image.shape[-3:]
    distances = distance_from_dc_for_dft(
        dft_shape=(d, h, w), rfft=rfft, fftshifted=fftshifted
    )
    distances = einops.rearrange(distances, 'd h w -> (d h w)')
    per_shell_indices = _find_shell_indices_1d(distances, n_shells=n_shells)
    # Flatten the volume the same way as the distances so indices line up.
    flat_image = einops.rearrange(image, '... d h w -> ... (d h w)')
    return [flat_image[..., shell_idx] for shell_idx in per_shell_indices]
def rotational_average_3d(
    image: torch.Tensor, rfft: bool = False, fftshifted: bool = True
) -> torch.Tensor:
    """Average the last three dimensions of ``image`` over shells around DC.

    The number of shells is ``image.shape[-3] // 2``.

    Parameters
    ----------
    image
        Tensor whose last three dimensions ``(d, h, w)`` are a 3D DFT.
    rfft
        Whether ``(d, h, w)`` should be treated as an rfft layout.
    fftshifted
        Whether the DFT is fftshifted.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(..., n_shells)`` holding the mean of each shell.
    """
    # NOTE(review): the shell count is derived from the depth dimension
    # only, which presumably assumes cubic volumes - confirm with callers.
    n_shells = image.shape[-3] // 2
    shells = _split_into_shells_3d(
        image, n_shells=n_shells, rfft=rfft, fftshifted=fftshifted
    )
    means = [
        einops.reduce(shell, '... shell -> ...', reduction='mean')
        for shell in shells
    ]
    return einops.rearrange(means, 'shells ... -> ... shells')
@cli.command(no_args_is_help=True)
def main(
    volume_file: Path = typer.Option(..., '--volume-file', '-i'),
    output_file: Path = typer.Option(..., '--output-file', '-o', help='text file'),
):
    """Compute and plot the rotationally averaged power spectrum of a volume.

    Reads the volume and its voxel size from ``volume_file``, writes the
    per-shell mean power to ``output_file`` as text, and plots log(power)
    against spatial frequency.
    """
    volume = torch.tensor(mrcfile.read(volume_file))
    # Only the header is needed to get the voxel size.
    with mrcfile.open(volume_file, permissive=True, header_only=True) as mrc:
        apix = float(mrc.voxel_size.x)

    dft = torch.fft.fftn(volume, dim=(-3, -2, -1)).abs().square()
    raps = rotational_average_3d(dft, rfft=True, fftshifted=False).numpy()

    # Map shell index to spatial frequency, scaled so that the shell at
    # nyquist_idx corresponds to 1 / (2 * apix).
    spectral_idx = np.arange(len(raps))
    nyquist_idx = (volume.shape[-1] // 2) - 1
    fraction_of_nyquist = spectral_idx / nyquist_idx
    freqs = fraction_of_nyquist * (1 / (2 * apix))
    del volume

    np.savetxt(output_file, raps)
    typer.echo(f'file with data saved to {output_file}')

    import matplotlib.pyplot as plt

    def _format_as_reciprocal(x, _):
        # Zero frequency has no finite reciprocal; label it plainly instead
        # of dividing by zero.
        return '0' if x == 0 else f'1/{1 / x:.2f}'

    fig, ax = plt.subplots()
    ax.set(
        title='log(power) vs spatial frequency',
        xlabel='spatial frequency (1/Å)',
    )
    ax.xaxis.set_major_formatter(_format_as_reciprocal)
    ax.plot(freqs, np.log(raps))
    plt.show()
if __name__ == '__main__':
    # Run the Typer application when the file is executed as a script.
    cli()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment