Skip to content

Instantly share code, notes, and snippets.

@younesbelkada
Created January 30, 2023 21:28
Show Gist options
  • Save younesbelkada/6352051c254576a8381d63fb7b649ba4 to your computer and use it in GitHub Desktop.
Save younesbelkada/6352051c254576a8381d63fb7b649ba4 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import torch
import torch.nn.functional as F
# adapted from: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/8
def _tf_same_padding(size, patch_size):
    """Return the ``(before, after)`` zero padding TF's ``"SAME"`` mode uses on one axis.

    With stride equal to the patch size, TensorFlow emits
    ``ceil(size / patch_size)`` patches and splits the padding needed to fill
    the last one so that any odd pixel goes *after* the data (bottom/right).
    """
    num_patches = -(-size // patch_size)  # integer ceil division
    total = max(num_patches * patch_size - size, 0)
    before = total // 2
    return before, total - before


def torch_extract_patches(x, patch_height, patch_width, padding=None):
    """Extract non-overlapping patches, matching ``tf.image.extract_patches``.

    Equivalent to ``tf.image.extract_patches`` with
    ``sizes = strides = [1, patch_height, patch_width, 1]`` and
    ``rates = [1, 1, 1, 1]``.

    Args:
        x: Channels-first image tensor ``(C, H, W)``, or a batch ``(B, C, H, W)``.
        patch_height: Patch height; also used as the vertical stride.
        patch_width: Patch width; also used as the horizontal stride.
        padding: ``None`` or ``"VALID"`` drops trailing rows/columns that do not
            fill a whole patch; ``"SAME"`` zero-pads exactly like TensorFlow so
            every pixel lands in some patch.

    Returns:
        Tensor of shape ``(B, rows, cols, patch_height * patch_width * C)``
        (``B == 1`` for a 3-D input). The last axis is flattened in
        ``(patch_height, patch_width, C)`` order, the same layout TensorFlow uses.

    Raises:
        ValueError: if ``x`` is not 3-D or 4-D, or ``padding`` is unrecognised.
    """
    if x.dim() == 3:
        x = x.unsqueeze(0)
    elif x.dim() != 4:
        raise ValueError(
            f"expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(x.shape)}"
        )

    if padding == "SAME":
        top, bottom = _tf_same_padding(x.size(2), patch_height)
        left, right = _tf_same_padding(x.size(3), patch_width)
        # F.pad orders the last two dims as (left, right, top, bottom).
        x = F.pad(x, (left, right, top, bottom))
    elif padding not in (None, "VALID"):
        raise ValueError(f"padding must be None, 'VALID' or 'SAME', got {padding!r}")

    # (B, C, H, W) -> (B, C, rows, cols, patch_height, patch_width). step == size
    # yields non-overlapping patches and drops any incomplete trailing patch.
    patches = x.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width)
    # Move channels last so each patch flattens as (patch_height, patch_width, C).
    patches = patches.permute(0, 2, 3, 4, 5, 1)
    return patches.reshape(*patches.shape[:3], -1)
# TensorFlow images are (H, W, C); the torch helper expects channels-first (C, H, W).
image_tf = tf.random.uniform(shape=(720, 720, 3))
image_torch = torch.from_numpy(image_tf.numpy()).permute(2, 0, 1)
patch_height, patch_width = 16, 16
# Note: 720 is a multiple of 16, so TF's "SAME" mode adds no padding here.
patches_tf = tf.image.extract_patches(
    images=tf.expand_dims(image_tf, 0),
    sizes=[1, patch_height, patch_width, 1],
    strides=[1, patch_height, patch_width, 1],
    rates=[1, 1, 1, 1],
    padding="SAME",
)
patches_torch = torch_extract_patches(
    x=image_torch,
    patch_height=patch_height,
    patch_width=patch_width,
    padding="SAME",
)
# assert_close (unlike `assert torch.allclose(...)`) rejects shape mismatches
# instead of broadcasting, reports which elements differ, and is not stripped
# when Python runs with -O.
torch.testing.assert_close(
    patches_torch.squeeze(0),
    torch.from_numpy(patches_tf.numpy()[0]),
    atol=1e-3,
    rtol=1e-3,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment