Created February 14, 2024 19:09
-
-
Save alisterburt/a767cc07ad9de9d2ffafbe83018fc173 to your computer and use it in GitHub Desktop.
PyTorch sliding-window view of a 3D image
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import einops | |
def sliding_window_3d(tensor, window_size):
    """Return a sliding-window view over every valid position of a 3D tensor.

    Element ``[i, j, k, a, b, c]`` of the result is
    ``tensor[i + a, j + b, k + c]``. Windows are taken with a step of 1 along
    each axis and only at positions where they fit entirely inside the tensor
    (no padding).

    Args:
        tensor: A 3D tensor of shape ``(d, h, w)``.
        window_size: A sequence of three ints ``(wd, wh, ww)``.

    Returns:
        A view (no data is copied) of shape
        ``(d - wd + 1, h - wh + 1, w - ww + 1, wd, wh, ww)``. Neighbouring
        windows overlap and share memory with ``tensor``, so writing to the
        view in place also changes ``tensor`` and every window containing
        that element.

    Raises:
        ValueError: If ``tensor`` is not 3D, ``window_size`` does not have
            exactly three entries, or any window extent is < 1 or larger than
            the matching tensor dimension.
    """
    if tensor.dim() != 3:
        raise ValueError(
            f"expected a 3D tensor, got shape {tuple(tensor.shape)}"
        )
    if len(window_size) != 3:
        raise ValueError(
            f"window_size must have 3 entries, got {window_size!r}"
        )
    for axis, (size, extent) in enumerate(zip(tensor.shape, window_size)):
        if not 1 <= extent <= size:
            raise ValueError(
                f"window extent {extent} on axis {axis} must be in [1, {size}]"
            )

    window_depth, window_height, window_width = window_size
    # Tensor.unfold(dim, size, step) appends each window as a new trailing
    # dimension and is bounds-checked. The three calls give the layout
    # (new_d, new_h, new_w, wd, wh, ww). That matches the hand-built
    # as_strided view, but without manual stride arithmetic.
    return (
        tensor.unfold(0, window_depth, 1)
        .unfold(1, window_height, 1)
        .unfold(2, window_width, 1)
    )
# Example: a 3x3x3 median over every fully-contained window of a 32^3 volume.
tensor = torch.arange(32 * 32 * 32).reshape(32, 32, 32)
window_size = (3, 3, 3)
window_view = sliding_window_3d(tensor, window_size)
print(window_view.shape)  # torch.Size([30, 30, 30, 3, 3, 3])

# Merge the three trailing window axes into a single axis of 27 voxels, in
# row-major (wd, wh, ww) order, so the median can be taken along one dim.
window_view_1d = window_view.flatten(start_dim=3)

median = window_view_1d.median(dim=-1)
print(median.values.shape)  # torch.Size([30, 30, 30])
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.