Skip to content

Instantly share code, notes, and snippets.

@yohann84L
Last active June 2, 2020 08:30
Show Gist options
  • Save yohann84L/d2bc341fa0157bb5123e0f34cb95c013 to your computer and use it in GitHub Desktop.
Save yohann84L/d2bc341fa0157bb5123e0f34cb95c013 to your computer and use it in GitHub Desktop.
GridMask augmentation for imgaug
## GridMask augmentation for imgaug
##
## Code based on this kernel https://www.kaggle.com/shivyshiv/efficientnet-gridmask-training-pytorch
import imgaug.augmenters as iaa
import numpy as np
from imgaug import parameters as iap
from imgaug.augmenters import meta
class Rotate(iaa.Affine):
    """Rotation-only convenience wrapper around :class:`Affine`.

    Constructing ``Rotate(rotate=<value>)`` is the same as constructing
    ``Affine(rotate=<value>)``: every argument is forwarded unchanged to the
    parent constructor and no other affine parameter is set here.

    Added in 0.4.0.

    **Supported dtypes**:

    See :class:`~imgaug.augmenters.geometric.Affine`.

    Parameters
    ----------
    rotate : number or tuple of number or list of number or imgaug.parameters.StochasticParameter, optional
        See :class:`Affine`.

    order : int or iterable of int or imgaug.ALL or imgaug.parameters.StochasticParameter, optional
        See :class:`Affine`.

    cval : number or tuple of number or list of number or imgaug.ALL or imgaug.parameters.StochasticParameter, optional
        See :class:`Affine`.

    mode : str or list of str or imgaug.ALL or imgaug.parameters.StochasticParameter, optional
        See :class:`Affine`.

    fit_output : bool, optional
        See :class:`Affine`.

    backend : str, optional
        See :class:`Affine`.

    seed : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional
        See :func:`~imgaug.augmenters.meta.Augmenter.__init__`.

    name : None or str, optional
        See :func:`~imgaug.augmenters.meta.Augmenter.__init__`.

    random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional
        Old name for parameter `seed`.
        Using it does not trigger a deprecation warning yet, but `seed`
        is the recommended spelling.
        Outdated since 0.4.0.

    deterministic : bool, optional
        Deprecated since 0.4.0.
        See method ``to_deterministic()`` for an alternative and for
        details about what the "deterministic mode" actually does.

    """

    # Added in 0.4.0.
    def __init__(self, rotate=(-30, 30), order=1, cval=0, mode="constant",
                 fit_output=False, backend="auto",
                 seed=None, name=None,
                 random_state="deprecated", deterministic="deprecated"):
        # Collect everything first so the parent call stays a one-liner;
        # all values are passed by keyword, exactly as received.
        affine_kwargs = {
            "rotate": rotate,
            "order": order,
            "cval": cval,
            "mode": mode,
            "fit_output": fit_output,
            "backend": backend,
            "seed": seed,
            "name": name,
            "random_state": random_state,
            "deterministic": deterministic,
        }
        super(Rotate, self).__init__(**affine_kwargs)
class GridMask(meta.Augmenter):
    """GridMask augmentation for image classification and object detection.

    A grid of cells is laid over the image and a fixed part of every cell is
    multiplied by ``fill_value``. The grid is shifted by a random offset for
    each image and can optionally be rotated by a random angle.

    Args:
        num_grid (int or (int, int)): number of grid cells in a row or column.
            A tuple ``(lo, hi)`` builds one mask for every count in
            ``lo..hi`` (inclusive) and one of them is picked per image.
        fill_value (int or float): value written into the dropped region of
            the mask. The mask is stored as ``uint8`` and *multiplied* with
            the image, so ``0`` blanks the region. Default: 0.
        rotate ((number, number) or number): range from which a random angle
            is picked. If rotate is a single number, the angle is picked from
            ``(-abs(rotate), abs(rotate))``. Default: 0 (no rotation).
        mode (int):
            0 - cropout a quarter of the square of each grid (left top)
            1 - reserve a quarter of the square of each grid (left top)
            2 - cropout 2 quarter of the square of each grid (left top & right bottom)
        p: probability of applying the augmentation to an image; handed to
            ``imgaug.parameters.handle_probability_param``. Default: 1.
        name, deterministic, random_state: forwarded to
            ``imgaug.augmenters.meta.Augmenter.__init__``.

    Raises:
        ValueError: if ``num_grid`` contains a value below 1 or its lower
            bound exceeds its upper bound.

    Targets:
        image

    Image types:
        uint8, float32

    Reference:
        | https://arxiv.org/abs/2001.04086
        | https://github.com/akuxcw/GridMask
    """

    def __init__(self, num_grid=3, fill_value=0, rotate=0, mode=0, p=1, name=None,
                 deterministic=False, random_state=None):
        super(GridMask, self).__init__(name=name, deterministic=deterministic, random_state=random_state)
        if isinstance(num_grid, int):
            num_grid = (num_grid, num_grid)
        # Fail early: a count of 0 would divide by zero when sizing the cells
        # and an empty range would leave no mask to pick from.
        if num_grid[0] < 1 or num_grid[1] < num_grid[0]:
            raise ValueError(
                "num_grid must satisfy 1 <= num_grid[0] <= num_grid[1], got %r" % (num_grid,))
        if isinstance(rotate, (int, float)):
            # abs() keeps the range ordered when a negative scalar is given.
            rotate = (-abs(rotate), abs(rotate))
        self.num_grid = num_grid
        self.fill_value = fill_value
        self.rotate = rotate
        self.mode = mode
        # Masks and offset bounds for the image size stored in _mask_size.
        self.masks = None
        self._mask_size = None
        self.rand_h_max = []
        self.rand_w_max = []
        self.p = iap.handle_probability_param(p, "p")

    def init_masks(self, height, width):
        """Build the grid masks for images of size ``height`` x ``width``.

        One mask is created per grid count in ``num_grid``. Each mask is one
        grid cell larger than the image in both directions so that a window
        of the image's size can be cut out at a random offset. Masks are
        rebuilt only when the requested size differs from the cached one.
        """
        if self.masks is not None and self._mask_size == (height, width):
            return

        masks = []
        rand_h_max = []
        rand_w_max = []
        for n_g in range(self.num_grid[0], self.num_grid[1] + 1):
            grid_h = height / n_g
            grid_w = width / n_g
            this_mask = np.ones(
                (int((n_g + 1) * grid_h), int((n_g + 1) * grid_w)), dtype=np.uint8)
            for i in range(n_g + 1):
                for j in range(n_g + 1):
                    # Top-left quarter of the cell.
                    this_mask[
                        int(i * grid_h): int(i * grid_h + grid_h / 2),
                        int(j * grid_w): int(j * grid_w + grid_w / 2)
                    ] = self.fill_value
                    if self.mode == 2:
                        # Bottom-right quarter of the cell as well.
                        this_mask[
                            int(i * grid_h + grid_h / 2): int(i * grid_h + grid_h),
                            int(j * grid_w + grid_w / 2): int(j * grid_w + grid_w)
                        ] = self.fill_value
            if self.mode == 1:
                # Invert: keep only the top-left quarter of every cell.
                this_mask = 1 - this_mask
            masks.append(this_mask)
            # Offsets below one cell size cover every distinct grid phase.
            rand_h_max.append(grid_h)
            rand_w_max.append(grid_w)

        self.masks = masks
        self.rand_h_max = rand_h_max
        self.rand_w_max = rand_w_max
        self._mask_size = (height, width)

    def _augment_images(self, images, random_state, parents, hooks):
        """Multiply each selected image in place with a randomly shifted grid mask.

        An image is augmented when its sample drawn from ``p`` exceeds 0.5.
        NOTE(review): the mask index and offsets come from the global
        ``np.random`` state, not from ``random_state``, so they are not
        controlled by the augmenter's seed.
        """
        nb_images = len(images)
        samples = self.p.draw_samples((nb_images,), random_state=random_state)
        # Rotate whenever the range is not exactly (0, 0); one augmenter
        # instance serves the whole batch.
        rot = None
        if self.rotate[0] != 0 or self.rotate[1] != 0:
            rot = Rotate(self.rotate)
        for i, (image, sample) in enumerate(zip(images, samples)):
            if sample <= 0.5:
                continue
            height, width = image.shape[:2]
            self.init_masks(height, width)
            mid = np.random.randint(len(self.masks))
            mask = self.masks[mid]
            # The bounds are float cell sizes; randint needs a positive int,
            # so cells smaller than one pixel fall back to offset 0.
            rand_h = np.random.randint(max(1, int(self.rand_h_max[mid])))
            rand_w = np.random.randint(max(1, int(self.rand_w_max[mid])))
            if rot is not None:
                mask = rot.augment_image(mask)
            window = mask[rand_h:rand_h + height, rand_w:rand_w + width]
            if image.ndim == 3:
                # Broadcast the 2D mask over the channel axis.
                window = window[:, :, np.newaxis]
            image *= window.astype(image.dtype)
            images[i] = image
        return images

    def get_parameters(self):
        """Return the augmenter's parameters (only the probability ``p``)."""
        return [self.p]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment