Skip to content

Instantly share code, notes, and snippets.

@alisterburt
Created February 14, 2024 19:09
Show Gist options
  • Save alisterburt/a767cc07ad9de9d2ffafbe83018fc173 to your computer and use it in GitHub Desktop.
Save alisterburt/a767cc07ad9de9d2ffafbe83018fc173 to your computer and use it in GitHub Desktop.
PyTorch sliding-window view of a 3D image
import torch
import einops
def sliding_window_3d(tensor, window_size, step=1):
    """Return a zero-copy view of every (optionally strided) 3D window of ``tensor``.

    Element ``[i, j, k, a, b, c]`` of the result equals
    ``tensor[i * sd + a, j * sh + b, k * sw + c]``, where ``(sd, sh, sw)`` is
    ``step``. With the default ``step=1`` every possible window position is
    enumerated.

    The result is a view sharing storage with ``tensor``, and neighbouring
    windows overlap in memory, so writing to it in place is unsafe; make a
    copy (e.g. ``.contiguous()`` or ``.clone()``) if you need to modify it.

    Args:
        tensor: 3D tensor of shape ``(d, h, w)``. It does not need to be
            contiguous.
        window_size: window extent ``(wd, wh, ww)``, or a single int for a
            cubic window.
        step: distance between successive window origins, either one int used
            for all three dimensions or an ``(sd, sh, sw)`` tuple.
            Defaults to 1.

    Returns:
        View of shape ``(nd, nh, nw, wd, wh, ww)`` where
        ``nd = (d - wd) // sd + 1`` (and likewise ``nh`` and ``nw``).

    Raises:
        ValueError: if ``tensor`` is not 3D, ``window_size`` or ``step`` is
            not an int or a length-3 sequence, a window extent lies outside
            ``[1, size of that dimension]``, or a step is smaller than 1.
    """
    if tensor.ndim != 3:
        raise ValueError(
            f"expected a 3D tensor, got shape {tuple(tensor.shape)}"
        )
    windows = _as_triple(window_size, "window_size")
    steps = _as_triple(step, "step")

    for dim, (size, win, stride) in enumerate(zip(tensor.shape, windows, steps)):
        # Reject out-of-range windows explicitly: otherwise a window larger
        # than the image gives either a silently empty result or an opaque
        # low-level error, depending on by how much it overshoots.
        if not 1 <= win <= size:
            raise ValueError(
                f"window_size[{dim}]={win} must be between 1 and {size}"
            )
        if stride < 1:
            raise ValueError(f"step[{dim}]={stride} must be >= 1")

    # Tensor.unfold(dim, size, step) shrinks `dim` to the number of window
    # positions and appends a trailing axis of length `size`. Three unfolds
    # therefore yield (nd, nh, nw, wd, wh, ww), inheriting the input's real
    # strides and storage offset instead of hand-computed as_strided strides.
    view = tensor
    for dim, (win, stride) in enumerate(zip(windows, steps)):
        view = view.unfold(dim, win, stride)
    return view


def _as_triple(value, name):
    """Normalise an int or a length-3 sequence of ints to a 3-tuple."""
    if isinstance(value, int):
        return (value, value, value)
    triple = tuple(value)
    if len(triple) != 3:
        raise ValueError(
            f"{name} must be an int or a sequence of 3 ints, got {value!r}"
        )
    return triple
# Example usage
def _demo():
    """Median-filter a 32**3 test volume with a 3x3x3 window and print shapes."""
    tensor = torch.arange(32 * 32 * 32).reshape(32, 32, 32)
    window_size = (3, 3, 3)
    window_view = sliding_window_3d(tensor, window_size)
    print(window_view.shape)  # torch.Size([30, 30, 30, 3, 3, 3])

    # Flatten each 3x3x3 window into one trailing axis so median can reduce
    # over it. The overlapping windows cannot be merged into a single axis as
    # a view, so this materialises a copy (30**3 * 27 elements).
    window_view_1d = einops.rearrange(
        window_view, "d h w wd wh ww -> d h w (wd wh ww)"
    )

    # torch.median(..., dim=...) returns a (values, indices) namedtuple; for
    # an even number of elements per window it returns the lower median.
    median = torch.median(window_view_1d, dim=-1)
    print(median.values.shape)  # torch.Size([30, 30, 30])


# Only run the demo when executed as a script, not on import.
if __name__ == "__main__":
    _demo()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment