
@Nico-Curti
Last active May 5, 2024 12:53
Image patch-coordinates generator without internal copies
import numpy as np
from itertools import product
from typing import Iterator, Tuple


class GridTiler (object):
    '''
    Image patch generator

    A simple generator class for the extraction of image
    patches, given a binary mask on which to evaluate the
    amount of non-null pixels. The object implements the
    extraction of the patches on a complete grid of
    coordinates. The input image mask is used for the
    evaluation of the percentage of pixels to keep during
    the extraction.

    The patch coordinates are yielded by a generator,
    following the two-point (top-left, bottom-right) form
    of the OpenCV Rect2d object, with points in (x, y) order.

    Notes
    -----
    A faster implementation could be obtained using the NumPy
    as_strided function, BUT it requires a copy of the
    entire image inside the object class. Instead, the
    current implementation provides a lighter solution
    to address the same task.

    Parameters
    ----------
    grid_size : tuple
        Dimension of the patch to extract, as (rows, cols).

    stride : tuple
        Stride to apply for the patch extraction
        on the full grid, as (rows, cols).
    '''

    def __init__ (self, grid_size : tuple,
                  stride : tuple):
        # patch extent along rows (kx) and columns (ky)
        self._kx, self._ky = grid_size
        # grid step along rows (sx) and columns (sy)
        self._sx, self._sy = stride

    def extract (self, mask : np.ndarray,
                 toll : float = 1.0
                 ) -> Iterator[Tuple[Tuple[int, int], Tuple[int, int]]]:
        '''
        Generate the coordinates of every patch on the grid whose
        number of non-null pixels is at least the provided
        tolerance, expressed as a fraction of the patch area
        (grid_size[0] * grid_size[1]).

        Parameters
        ----------
        mask : np.ndarray
            2D binary mask used for the patch extraction.

        toll : float (default=1.0)
            Tolerance for the patch acceptance, in [0, 1];
            1.0 keeps only fully non-null patches.

        Returns
        -------
        coords : generator of tuple
            Corner points of each accepted patch, given as
            ((x, y), (x + grid_size[1], y + grid_size[0])),
            where x is the column and y the row of the
            top-left corner.
        '''
        # minimum number of non-null pixels required to keep a patch
        percentage = (self._kx * self._ky) * toll
        h, w = mask.shape

        # top-left corners on the full grid: rows first, then columns
        coords = product(range(0, h - 1, self._sx),
                         range(0, w - 1, self._sy)
                         )

        for row, col in coords:
            # view (not a copy) of the patch; patches at the border are
            # clipped by the slicing but still compared with the
            # full-patch threshold
            roi = mask[row : row + self._kx,
                       col : col + self._ky,
                       ...
                       ]

            if np.count_nonzero(roi) >= percentage:
                yield ((col, row), (col + self._ky, row + self._kx))
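

if __name__ == '__main__':

    import cv2
    import matplotlib.pyplot as plt

    # NOTE: hypothetical placeholder file names; replace them with
    # your own image and its binary mask
    img = cv2.imread('image.png')
    mask = cv2.imread('mask.png', cv2.IMREAD_GRAYSCALE)

    tiler = GridTiler(grid_size=(256, 256), stride=(256, 256))

    canvas = img.copy()

    # draw every accepted patch as a green rectangle
    for (x0, y0), (x1, y1) in tiler.extract(mask, toll=0.1):
        cv2.rectangle(canvas, (x0, y0), (x1, y1), (0, 255, 0), 2)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(30, 10))
    # cv2.imread returns BGR channels, matplotlib expects RGB
    ax.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
    ax.axis('off')
    plt.show()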
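
The Notes in the class docstring mention NumPy's as_strided as the faster alternative. The following is a minimal sketch of that idea, assuming NumPy >= 1.20 for `numpy.lib.stride_tricks.sliding_window_view` (a safer wrapper around `as_strided`); `strided_patch_coords` is a hypothetical name, not part of the gist. It yields the same ((x0, y0), (x1, y1)) corner pairs as `GridTiler.extract`, but only for windows lying fully inside the mask.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def strided_patch_coords(mask, grid_size, stride, toll=1.0):
    """Hypothetical strided alternative to GridTiler.extract.

    Builds a read-only strided view of every (kh, kw) window of the 2D
    mask, subsamples it with the stride and counts the non-null pixels
    of all kept windows in a single vectorized call. Unlike
    GridTiler.extract, windows crossing the border are skipped instead
    of being clipped.
    """
    kh, kw = grid_size
    sh, sw = stride
    # shape (h - kh + 1, w - kw + 1, kh, kw) before subsampling;
    # the view itself does not copy pixel data
    windows = sliding_window_view(mask, (kh, kw))[::sh, ::sw]
    # for a non-boolean mask this converts the view to a boolean
    # array first, i.e. it materializes one byte per window pixel
    counts = np.count_nonzero(windows, axis=(2, 3))
    rows, cols = np.nonzero(counts >= kh * kw * toll)
    for r, c in zip(rows * sh, cols * sw):
        r, c = int(r), int(c)
        yield ((c, r), (c + kw, r + kh))
```

The speed-up comes from replacing the Python loop with one vectorized reduction; the price is the temporary boolean array, which is the memory cost the Notes trade against speed. When the stride equals a grid_size larger than one pixel and the mask sides are multiples of it, no border patch is clipped, so both generators should visit the same top-left corners.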
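
The acceptance rule in `extract` compares the non-null count with `grid_size[0] * grid_size[1] * toll`, i.e. with the full patch area, even for patches clipped at the image border. The helpers below are hypothetical (not part of the gist) and only reproduce that arithmetic, to make the consequence concrete: a clipped border patch can never pass a tolerance larger than the fraction of the patch that lies inside the mask.

```python
def acceptance_threshold(grid_size, toll):
    """Minimum non-null pixel count for a patch to be kept.

    Hypothetical helper: same formula as `percentage` in GridTiler.extract.
    """
    kx, ky = grid_size
    return (kx * ky) * toll


def corner_patch_max_fraction(shape, grid_size, stride):
    """Largest `toll` the bottom-right grid patch can ever satisfy.

    Hypothetical helper: it enumerates the same top-left corners as
    GridTiler.extract (range(0, h - 1, sx) and range(0, w - 1, sy),
    so h and w must be at least 2) and measures how much of the last
    patch survives the slicing mask[row:row + kx, col:col + ky].
    """
    h, w = shape
    kx, ky = grid_size
    sx, sy = stride
    last_row = range(0, h - 1, sx)[-1]
    last_col = range(0, w - 1, sy)[-1]
    # rows and columns of the clipped patch that lie inside the mask
    rows = min(kx, h - last_row)
    cols = min(ky, w - last_col)
    return (rows * cols) / (kx * ky)
```

With the demo settings (grid_size=(256, 256), toll=0.1) the threshold is 6553.6 pixels. For a hypothetical 300x300 mask with a 256-pixel stride, the bottom-right patch keeps only 44x44 of its 256x256 pixels, so `corner_patch_max_fraction((300, 300), (256, 256), (256, 256))` returns about 0.0295, and that patch is rejected for any larger `toll` even when it is entirely non-null.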