Last active
May 5, 2024 12:53
-
-
Save Nico-Curti/c2dd24642e5e92707ba779a3ce9ecfb8 to your computer and use it in GitHub Desktop.
Image patch-coordinates generator without internal copies
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import product
from typing import Iterator, Tuple

import numpy as np
class GridTiler (object):
    '''
    Image patch-coordinates generator.

    A simple generator class for the extraction of image
    patches, given a binary mask on which the amount of
    non-null pixels is evaluated. The object scans a complete
    grid of coordinates and uses the input mask to decide which
    patches to keep, based on the fraction of non-null pixels
    each patch contains.

    The patch coordinates are yielded lazily as pairs of points
    in the (x, y) = (column, row) order used by OpenCV
    (e.g. ``cv2.rectangle`` / Rect2D corners).

    Notes
    -----
    A faster implementation could be obtained using NumPy
    as_strided function, BUT it requires a copy of the
    entire image inside the object class. Instead, the
    current implementation provides a lighter solution
    to address the same task: each patch is read through
    basic slicing (a view of the mask) and nothing is stored.

    Parameters
    ----------
    grid_size : tuple of int
        Dimension of the patch to extract, as (rows, columns):
        the first entry spans axis 0 of the mask, the second axis 1.

    stride : tuple of int
        Stride to apply for the patch extraction on the full
        grid, as (rows, columns).

    Raises
    ------
    ValueError
        If any entry of ``grid_size`` or ``stride`` is not positive.
    '''

    def __init__ (self, grid_size : tuple,
                        stride : tuple):
        self._kx, self._ky = grid_size
        self._sx, self._sy = stride

        # A zero stride would only fail once extraction starts (inside
        # ``range``), and negative values silently produce no patches:
        # reject both up front with an explicit message.
        if self._kx <= 0 or self._ky <= 0:
            raise ValueError(
                'grid_size entries must be positive, got {!r}'.format(grid_size))
        if self._sx <= 0 or self._sy <= 0:
            raise ValueError(
                'stride entries must be positive, got {!r}'.format(stride))

    def extract (self, mask : np.ndarray,
                       toll : float = 1.0
                ) -> Iterator[Tuple[Tuple[int, int], Tuple[int, int]]]:
        '''
        Generate the patch-coordinates whose number of non-null
        pixels reaches the provided tolerance.

        Parameters
        ----------
        mask : np.ndarray
            Mask used for the patch extraction. The first two axes
            are spatial (rows, columns); any trailing axes (e.g.
            channels) are collapsed, so a pixel counts as non-null
            when any of its channels is non-zero.

        toll : float (default=1.0)
            Tolerance for the patch acceptance: minimum fraction of
            the full patch area (rows * columns of ``grid_size``) that
            must be non-null for the patch to be kept.

        Yields
        ------
        coords : tuple
            ((col, row), (col + grid_cols, row + grid_rows)): the two
            opposite corners of the patch as OpenCV (x, y) points.
            Patches on the bottom/right borders may extend past the
            mask; they are evaluated on their truncated area but the
            threshold still refers to the full patch area.
        '''
        # Minimum count of non-null pixels, relative to the full patch area.
        threshold = (self._kx * self._ky) * toll
        # Only the first two axes are spatial.
        h, w = mask.shape[:2]

        # Every start position, including the last row/column.
        coords = product(range(0, h, self._sx),
                         range(0, w, self._sy))

        for row, col in coords:
            # Basic slicing: a view of the mask, no data copied.
            roi = mask[row : row + self._kx,
                       col : col + self._ky]
            if roi.ndim > 2:
                # A pixel is non-null if any of its channels is.
                roi = np.any(roi, axis=tuple(range(2, roi.ndim)))

            if np.count_nonzero(roi) >= threshold:
                # OpenCV points are (x, y) = (column, row).
                yield ((col, row), (col + self._ky, row + self._kx))
if __name__ == '__main__':
    # Demo: tile a synthetic image and draw the accepted patches on it.
    import cv2
    import matplotlib.pyplot as plt

    # Synthetic example so the script runs standalone: a filled disc as the
    # foreground mask, rendered as a 3-channel image for drawing.
    mask = np.zeros((1024, 1536), dtype=np.uint8)
    cv2.circle(mask, (768, 512), 450, 255, -1)
    img = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)

    tiler = GridTiler(grid_size=(256, 256), stride=(256, 256))

    canvas = img.copy()
    # Keep patches with at least 10% non-null pixels. The yielded points are
    # already in OpenCV (x, y) order, so they go straight into cv2.rectangle.
    for pt1, pt2 in tiler.extract(mask, toll=0.1):
        cv2.rectangle(canvas, pt1, pt2, (0, 255, 0), 2)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(30, 10))
    ax.imshow(canvas)
    ax.axis('off')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment