Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
def extract_patches(imgs, landmarks, patch_shape):
    """Extract fixed-size patches around landmarks from a batch of images.

    Args:
        imgs: numpy array of shape [batch_size, height, width, channels].
            Rows are indexed by the landmark y coordinate, columns by x.
        landmarks: array-like of shape [batch_size, num_patches, 2] holding
            (x, y) pixel coordinates; ``landmarks[i]`` belongs to ``imgs[i]``.
            Fractional coordinates are truncated toward zero.
        patch_shape: (width, height) of each patch in pixels. Even and odd
            sizes are both supported; for even sizes the landmark sits just
            right/below the patch centre.

    Returns:
        float32 numpy array of shape
        [batch_size, num_patches, height, width, channels].
        Sampling positions that fall outside an image are clamped to the
        nearest border pixel (edge replication).
    """
    imgs = np.asarray(imgs)
    landmarks = np.asarray(landmarks)
    patch_shape = np.asarray(patch_shape, dtype=int)

    # Offsets in [-(n // 2), n - n // 2) give exactly n samples for both even
    # and odd n. The previous round(n / 2) window produced n-1 or n+1 samples
    # for odd n because np.round rounds halves to even.
    start = -(patch_shape // 2)
    end = start + patch_shape
    # offsets[i, j] == (start[0] + i, start[1] + j); shape [w, h, 2].
    offsets = np.mgrid[start[0]:end[0], start[1]:end[1]].transpose(1, 2, 0)

    # Absolute integer sampling coordinates: [batch, num_patches, w, h, 2].
    grid = (offsets[None, None] + landmarks[:, :, None, None, :]).astype('int32')

    # x indexes columns (axis 2) and y indexes rows (axis 1), so each must be
    # clipped against its own axis length; this matters for non-square images.
    max_row = imgs.shape[1] - 1
    max_col = imgs.shape[2] - 1
    cols = grid[..., 0].clip(0, max_col)
    rows = grid[..., 1].clip(0, max_row)

    # Gather every patch of every image in one advanced-indexing operation.
    batch_idx = np.arange(imgs.shape[0])[:, None, None, None]
    # imgs[b, rows, cols] -> [batch, num_patches, w, h, channels]; swap the two
    # spatial axes so each patch is laid out as [height, width, channels].
    patches = imgs[batch_idx, rows, cols].swapaxes(2, 3)
    return patches.astype(np.float32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment