@kklemon
Last active February 1, 2024 08:50
Example implementation of MRI sequence handling with TensorDicts and memory-mapped tensors
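The snippet assumes torch, tensordict, nibabel, and monai are installed. Note that MemmapTensor comes from early tensordict releases; newer versions expose the same functionality as MemoryMappedTensor.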
import torch
import nibabel as nib
import numpy as np
from tensordict import TensorDict, MemmapTensor
files_by_modality = {
    'flair': [...],
}
num_files = ...
# Reference spatial shape, assuming all volumes share the same grid
shape = ...
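# Illustrative sketch (an assumption, not part of the original gist): one way
# to derive these when every modality lists the same subjects and all volumes
# share one grid:
#
#   num_files = len(files_by_modality['flair'])
#   shape = nib.load(files_by_modality['flair'][0]).shape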
# Initialize an empty TensorDict; we fill it modality by modality below
td = TensorDict({}, batch_size=[num_files])
for modality, files in files_by_modality.items():
    # The magic happens here: allocate a disk-backed tensor for the whole stack
    modality_mmap = MemmapTensor(
        (len(files), *shape),
        dtype=torch.float32,
        # Memory-mapped storage is file-backed and lives on CPU; move
        # batches to the GPU later as needed
        device='cpu',
    )
    # Note: the tensor is memory-mapped and thus never fully resident in
    # memory at any given time
    for i, path in enumerate(files):
        arr = nib.load(path).get_fdata().astype(np.float32)
        modality_mmap[i] = torch.from_numpy(arr)
    td[modality] = modality_mmap
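# Hedged aside (not in the original gist): because the storage is file-backed,
# indexing reads only the requested bytes, e.g. a single axial slice of the
# first volume:
#
#   axial_slice = td['flair'][0, ..., 64]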
# Save to disk: each entry becomes a memory-mapped file under the
# 'data.td' prefix directory
td.memmap_('data.td', copy_existing=True)
# Later stage: load from disk
from monai import transforms as T
td = TensorDict.load_memmap('data.td')
# We can treat it as a normal TensorDict
# The memory mapping is handled fully transparently
modalities = list(td.keys())
transform = T.Compose([
    # MONAI expects rotation ranges in radians; np.pi covers ±180 degrees
    T.RandRotateD(modalities, range_x=np.pi, range_y=np.pi, range_z=np.pi,
                  prob=1.0, mode='bilinear'),
    T.CenterScaleCropD(modalities, roi_scale=[0.8, 0.8, 0.8]),
])
# While `td` behaves like a dictionary, iterating over it indexes it
# along the batch dimension, like a tensor
for item in td:
    # Only this item's data is read from disk, so iteration stays fast
    # and memory usage flat
    transformed = transform(item)
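# A sketch of downstream use (an assumption, not from the original gist): the
# memory-mapped TensorDict also works as a map-style dataset, since it supports
# len() and integer indexing; tensordict registers torch.stack, so it can serve
# as the collate function.
from torch.utils.data import DataLoader

loader = DataLoader(td, batch_size=4, collate_fn=torch.stack)
for batch in loader:
    # `batch` is a TensorDict with batch_size [4]; only these volumes are
    # read from disk
    break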