Last active June 2, 2020 08:30
Save yohann84L/d2bc341fa0157bb5123e0f34cb95c013 to your computer and use it in GitHub Desktop.
GridMask augmentation for imgaug
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## GridMask augmentation for imgaug | |
## | |
## Code based on this kernel https://www.kaggle.com/shivyshiv/efficientnet-gridmask-training-pytorch | |
import imgaug.augmenters as iaa
import numpy as np
from imgaug import parameters as iap
from imgaug.augmenters import meta
class Rotate(iaa.Affine):
    """Apply a random affine rotation (in the image plane) to input data.

    This is a thin wrapper around :class:`Affine`. It is the same as
    ``Affine(rotate=<value>)``; every other argument is forwarded unchanged.

    NOTE(review): imgaug >= 0.4.0 ships an equivalent ``iaa.Rotate``; this
    local copy presumably exists for convenience/compatibility — confirm
    before replacing it with the library class.

    Added in 0.4.0.

    **Supported dtypes**:

        See :class:`~imgaug.augmenters.geometric.Affine`.

    Parameters
    ----------
    rotate : number or tuple of number or list of number or imgaug.parameters.StochasticParameter, optional
        Rotation angle(s) in degrees, passed as ``Affine(rotate=...)``.
        See :class:`Affine`.

    order : int or iterable of int or imgaug.ALL or imgaug.parameters.StochasticParameter, optional
        See :class:`Affine`.

    cval : number or tuple of number or list of number or imgaug.ALL or imgaug.parameters.StochasticParameter, optional
        See :class:`Affine`.

    mode : str or list of str or imgaug.ALL or imgaug.parameters.StochasticParameter, optional
        See :class:`Affine`.

    fit_output : bool, optional
        See :class:`Affine`.

    backend : str, optional
        See :class:`Affine`.

    seed : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional
        See :func:`~imgaug.augmenters.meta.Augmenter.__init__`.

    name : None or str, optional
        See :func:`~imgaug.augmenters.meta.Augmenter.__init__`.

    random_state : None or int or imgaug.random.RNG or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState, optional
        Old name for parameter `seed`.
        Its usage will not yet cause a deprecation warning,
        but it is still recommended to use `seed` now.
        Outdated since 0.4.0.

    deterministic : bool, optional
        Deprecated since 0.4.0.
        See method ``to_deterministic()`` for an alternative and for
        details about what the "deterministic mode" actually does.

    """

    # Added in 0.4.0.
    def __init__(self, rotate=(-30, 30), order=1, cval=0, mode="constant",
                 fit_output=False, backend="auto",
                 seed=None, name=None,
                 random_state="deprecated", deterministic="deprecated"):
        # Only ``rotate`` is set; translation/scale/shear keep Affine's
        # defaults, so this is a pure rotation.
        super(Rotate, self).__init__(
            rotate=rotate,
            order=order,
            cval=cval,
            mode=mode,
            fit_output=fit_output,
            backend=backend,
            seed=seed, name=name,
            random_state=random_state, deterministic=deterministic)
class GridMask(meta.Augmenter):
    """GridMask augmentation for image classification and object detection.

    A regular grid of rectangular patches is dropped from each selected
    image: dropped pixels are replaced by ``fill_value`` and all other pixels
    are left untouched. Per image, the grid is randomly shifted and, if
    ``rotate`` is non-zero, randomly rotated.

    Args:
        num_grid (int or (int, int)): number of grid cells in a row or
            column. A pair ``(a, b)`` builds one mask per grid count in
            ``[a, b]`` and picks one of them uniformly per image.
        fill_value (int, float, list of int, list of float): value written
            into dropped pixels. A list is broadcast along the last
            (channel) axis, so it needs one entry per channel.
        rotate ((int, int) or int or float): range (degrees) from which a
            random mask rotation angle is picked. A single number ``r``
            means ``(-r, r)``. Default: 0 (no rotation).
        mode (int):
            0 - cropout a quarter of the square of each grid (left top)
            1 - reserve a quarter of the square of each grid (left top)
            2 - cropout 2 quarters of the square of each grid (left top & right bottom)
        p (float): probability of applying the augmentation to an image
            (parsed by ``iap.handle_probability_param``).
        name, deterministic, random_state: forwarded unchanged to
            ``meta.Augmenter.__init__``.

    Targets:
        images only (only ``_augment_images`` is implemented here).

    Image types:
        uint8, float32

    Reference:
        | https://arxiv.org/abs/2001.04086
        | https://github.com/akuxcw/GridMask
    """

    def __init__(self, num_grid=3, fill_value=0, rotate=0, mode=0, p=1,
                 name=None, deterministic=False, random_state=None):
        super(GridMask, self).__init__(
            name=name, deterministic=deterministic, random_state=random_state)
        if isinstance(num_grid, int):
            num_grid = (num_grid, num_grid)
        if isinstance(rotate, (int, float)):
            rotate = (-rotate, rotate)
        self.num_grid = num_grid
        self.fill_value = fill_value
        self.rotate = rotate
        self.mode = mode
        # Binary masks (1 = keep, 0 = drop) and the exclusive upper bounds of
        # the random row/column offsets, one entry per grid count. They are
        # valid for the image size stored in ``self._mask_size``.
        self.masks = None
        self.rand_h_max = []
        self.rand_w_max = []
        self._mask_size = None
        self.p = iap.handle_probability_param(p, "p")

    def _build_mask(self, height, width, n_g):
        """Return ``(mask, grid_h, grid_w)`` for an ``n_g x n_g`` grid.

        ``mask`` is a uint8 array with 1 for kept and 0 for dropped pixels.
        It has one extra grid cell per axis so that any offset in
        ``[0, grid)`` still leaves a full ``height x width`` window inside it.
        """
        grid_h = height / n_g
        grid_w = width / n_g
        mask = np.ones(
            (int((n_g + 1) * grid_h), int((n_g + 1) * grid_w)), dtype=np.uint8)
        for i in range(n_g + 1):
            top = i * grid_h
            for j in range(n_g + 1):
                left = j * grid_w
                # Top-left quarter of every cell.
                mask[int(top):int(top + grid_h / 2),
                     int(left):int(left + grid_w / 2)] = 0
                if self.mode == 2:
                    # Additionally the bottom-right quarter.
                    mask[int(top + grid_h / 2):int(top + grid_h),
                         int(left + grid_w / 2):int(left + grid_w)] = 0
        if self.mode == 1:
            # Keep only the top-left quarters instead of dropping them.
            mask = 1 - mask
        return mask, grid_h, grid_w

    def init_masks(self, height, width):
        """Build the grid masks for a ``height x width`` image.

        This sets ``self.masks``, ``self.rand_h_max`` and ``self.rand_w_max``.
        The masks are rebuilt whenever the image size differs from the
        previous call; otherwise the existing ones are reused.
        """
        if self.masks is not None and self._mask_size == (height, width):
            return
        masks, rand_h_max, rand_w_max = [], [], []
        for n_g in range(self.num_grid[0], self.num_grid[1] + 1):
            mask, grid_h, grid_w = self._build_mask(height, width, n_g)
            masks.append(mask)
            # Offsets are drawn from [0, grid); the bound must be a positive
            # int, even when a grid cell is smaller than one pixel.
            rand_h_max.append(max(1, int(grid_h)))
            rand_w_max.append(max(1, int(grid_w)))
        self.masks = masks
        self.rand_h_max = rand_h_max
        self.rand_w_max = rand_w_max
        self._mask_size = (height, width)

    def _augment_images(self, images, random_state, parents, hooks):
        nb_images = len(images)
        samples = self.p.draw_samples((nb_images,), random_state=random_state)
        for i, (image, sample) in enumerate(zip(images, samples)):
            if sample <= 0.5:
                continue
            height, width = image.shape[:2]
            self.init_masks(height, width)
            # Use the augmenter's RNG (not the global numpy RNG) so that
            # seeding and deterministic mode make the result reproducible.
            mid = random_state.integers(0, len(self.masks))
            mask = self.masks[mid]
            rand_h = random_state.integers(0, self.rand_h_max[mid])
            rand_w = random_state.integers(0, self.rand_w_max[mid])
            if self.rotate[0] != 0 or self.rotate[1] != 0:
                rot = Rotate(self.rotate, seed=random_state.derive_rng_())
                # Regions rotated in from outside get cval=0, i.e. dropped.
                mask = rot.augment_image(mask)
            keep = mask[rand_h:rand_h + height, rand_w:rand_w + width] > 0
            if image.ndim == 3:
                keep = keep[:, :, np.newaxis]
            fill = np.asarray(self.fill_value, dtype=image.dtype)
            images[i] = np.where(keep, image, fill)
        return images

    def get_parameters(self):
        return [self.p]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.