Skip to content

Instantly share code, notes, and snippets.

@alisterburt
Created July 23, 2024 04:24
Show Gist options
  • Save alisterburt/c0608d527fa5e423731cf43ee76535c7 to your computer and use it in GitHub Desktop.
Save alisterburt/c0608d527fa5e423731cf43ee76535c7 to your computer and use it in GitHub Desktop.
spencer dask
from pathlib import Path
from typing import Sequence
import dask
import mrcfile
import numpy as np
import torch
import pandas as pd
import dask.array as da
# construct alignment data
# One table row per subtomogram file found under ./data. Every random column is
# sized from the number of files actually found, so the DataFrame can be built
# for any file count (a fixed length would fail whenever it differs from the
# number of paths).
subtomo_paths = sorted(Path('data').glob('*.mrc'))
n_particles = len(subtomo_paths)
pretend_alignment_data = {
    'subtomo_path': subtomo_paths,
    # random placeholder values drawn uniformly from the given ranges
    'x': np.random.uniform(low=0, high=512, size=(n_particles,)),
    'y': np.random.uniform(low=0, high=512, size=(n_particles,)),
    'z': np.random.uniform(low=0, high=256, size=(n_particles,)),
    'dx': np.random.uniform(low=-5, high=5, size=(n_particles,)),
    'dy': np.random.uniform(low=-5, high=5, size=(n_particles,)),
    'dz': np.random.uniform(low=-5, high=5, size=(n_particles,)),
    'rot': np.random.uniform(low=-np.pi, high=np.pi, size=(n_particles,)),
    'tilt': np.random.uniform(low=-np.pi, high=np.pi, size=(n_particles,)),
    'psi': np.random.uniform(low=-np.pi, high=np.pi, size=(n_particles,)),
}
df = pd.DataFrame(pretend_alignment_data)
print(df.head(5))
# construct dask array containing image data
@dask.delayed
def lazy_imread(file: Path) -> np.ndarray:
    """Read the MRC file at ``file`` and return what ``mrcfile.read`` gives back.

    The function is wrapped in ``dask.delayed``, so calling it only records a
    task; the file is read when that task is computed.
    """
    volume = mrcfile.read(file)
    return volume
# Wrap one delayed read per table row as a lazy dask array. Shape and dtype are
# declared up front because dask cannot inspect a file it has not read yet.
lazy_arrays = [
    da.from_delayed(lazy_imread(file), shape=(32, 32, 32), dtype=np.float16)
    for file in df['subtomo_path']
]
# stack along a new leading axis: one (32, 32, 32) entry per row of df
subtomos = da.stack(lazy_arrays, axis=0)
# write a function which preprocesses a single subtomo
def process_image(subvolume, block_id: Sequence[int]) -> np.ndarray:
    """Process one block of the stacked subtomogram array.

    Written to be mapped over a dask array with ``block_id=True``, so dask
    supplies the position of the current block alongside its data.

    Parameters
    ----------
    subvolume
        Image data for the current block (anything ``torch.tensor`` accepts).
    block_id
        Index of the block within the dask array; its first entry is used as
        the row index into the module-level ``df`` metadata table.

    Returns
    -------
    np.ndarray
        The block as a float32 array in host memory, same shape as the input.
    """
    # get idx from block_id which looks like (idx, 0, 0, 0)
    idx = block_id[0]
    # get metadata for particle
    metadata = df.iloc[idx]
    print(f'\nparticle {idx} metadata:\n{metadata}\n')
    # do some image processing..
    # stick it on the gpu and use all your fancy torch pixel pushing skills
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    subvolume = torch.tensor(subvolume).to(dtype=torch.float32, device=device)
    # A CUDA tensor cannot be converted to numpy directly; copy it back to
    # host memory first (a no-op when the tensor is already on the CPU).
    return subvolume.cpu().numpy()
# map this function over the lazy array
# process_image hands back float32 blocks (it converts to torch.float32 before
# returning), so float32 is the dtype to declare for the output array.
processed = da.map_blocks(process_image, subtomos, dtype=np.float32, block_id=True)
# computing runs process_image on every block; the computed values are not kept
processed.compute()
# prints the lazy dask array (shape/dtype/chunks summary), not the pixel data
print(processed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment