Skip to content

Instantly share code, notes, and snippets.

@innat
Created January 23, 2024 10:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save innat/b6ede34630e4a2988c968467f6d3facb to your computer and use it in GitHub Desktop.
Save innat/b6ede34630e4a2988c968467f6d3facb to your computer and use it in GitHub Desktop.
import tensorflow as tf
from tensorflow.keras import layers
# Negative indices of the height and width dimensions within an image tensor's
# shape. They are used to read `tf.shape(images)[H_AXIS]` / `[W_AXIS]`, i.e. the
# code assumes a channels-last layout (..., height, width, channels).
H_AXIS = -3
W_AXIS = -2
class RandomCutout(layers.Layer):
    """Randomly cut out one rectangle per image and fill it.

    Inputs are expected to be 4-D, channels-last batches shaped
    `(batch, height, width, channels)`. For every image a rectangle center is
    sampled uniformly over the image; the rectangle is not clipped to the
    image, so the visible cutout can be smaller when the center falls near a
    border.

    Args:
        height_factor: A single float. `height_factor` controls the height of
            the cutouts. `height_factor=0.0` means the rectangle will be of
            size 0% of the image height, `height_factor=0.1` means the
            rectangle will have a size of 10% of the image height, and so
            forth.
        width_factor: A single float. `width_factor` controls the width of the
            cutouts. `width_factor=0.0` means the rectangle will be of size 0%
            of the image width, `width_factor=0.1` means the rectangle will
            have a size of 10% of the image width, and so forth.
        fill_mode: Pixels inside the patches are filled according to the given
            mode (one of `{"constant", "gaussian_noise"}`).
            - *constant*: Pixels are filled with the same constant value.
            - *gaussian_noise*: Pixels are filled with random gaussian noise
              rescaled to the value range of the input batch.
        fill_value: A float, the value to be filled inside the patches when
            `fill_mode="constant"`.
        seed: Integer. Stored on the layer and returned by `get_config()`.
            NOTE(review): the random ops in this class do not currently
            consume it, so it does not make the augmentation deterministic.

    Raises:
        ValueError: If `fill_mode` is not `"constant"` or `"gaussian_noise"`.

    Sample usage:
    ```python
    (images, labels), _ = load_data()
    random_cutout = RandomCutout(0.5, 0.5)
    augmented_images = random_cutout(images)
    ```

    # Disclaimer
    Original implementation: https://github.com/keras-team/keras-cv.
    The original implementation provides a richer interface to apply the
    augmentation to various CV related tasks, i.e. object detection etc. It
    also provides many effective validation checks.
    Derived and modified for simpler usages: M.Innat.
    Ref. https://gist.github.com/innat/b6ede34630e4a2988c968467f6d3facb
    """

    def __init__(
        self,
        height_factor,
        width_factor,
        fill_mode="constant",
        fill_value=0.0,
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if fill_mode not in ("gaussian_noise", "constant"):
            raise ValueError(
                '`fill_mode` should be "gaussian_noise" '
                f'or "constant". Got `fill_mode`={fill_mode}'
            )
        self.height_factor = height_factor
        # Plain scalar (the original trailing comma stored a 1-tuple, which
        # leaked into `get_config()`).
        self.width_factor = width_factor
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.seed = seed

    def get_random_transformation_batch(self, images, **kwargs):
        """Sample one rectangle (center and size) per image in the batch.

        Args:
            images: Batch of images; only its shape is used.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Dict with 1-D int32 tensors `centers_x`, `centers_y`,
            `rectangles_height` and `rectangles_width`, one entry per image.
        """
        centers_x, centers_y = self._compute_rectangle_position(images)
        rectangles_height, rectangles_width = self._compute_rectangle_size(
            images
        )
        return {
            "centers_x": centers_x,
            "centers_y": centers_y,
            "rectangles_height": rectangles_height,
            "rectangles_width": rectangles_width,
        }

    def fill_rectangle(
        self, images, centers_x, centers_y, widths, heights, fill_values
    ):
        """Fill rectangles with fill value into images.

        Args:
            images: Tensor of images to fill rectangles into.
            centers_x: Tensor of positions of the rectangle centers on the
                x-axis.
            centers_y: Tensor of positions of the rectangle centers on the
                y-axis.
            widths: Tensor of widths of the rectangles.
            heights: Tensor of heights of the rectangles.
            fill_values: Tensor with same shape (and dtype) as images to get
                rectangle fill from.

        Returns:
            images with filled rectangles.
        """
        images_shape = tf.shape(images)
        images_height = images_shape[H_AXIS]
        images_width = images_shape[W_AXIS]

        # (batch, 4) boxes in center/size form -> corner form.
        xywh = tf.stack([centers_x, centers_y, widths, heights], axis=1)
        corners = self.convert_format(tf.cast(xywh, tf.float32))

        # `corners_to_mask` expects the shape as (width, height).
        is_rectangle = self.corners_to_mask(
            corners, (images_width, images_height)
        )
        # Add a channel axis so the mask broadcasts over all channels.
        is_rectangle = tf.expand_dims(is_rectangle, -1)
        return tf.where(is_rectangle, fill_values, images)

    def convert_format(self, boxes):
        """Convert `(center_x, center_y, width, height, ...)` boxes to corners.

        Args:
            boxes: Tensor whose last axis starts with center x, center y,
                width and height; any trailing entries are passed through.

        Returns:
            float32 tensor whose last axis is `(x0, y0, x1, y1, ...)`.
        """
        boxes = tf.cast(boxes, dtype=tf.float32)
        x, y, width, height, rest = tf.split(boxes, [1, 1, 1, 1, -1], axis=-1)
        half_width = width / 2.0
        half_height = height / 2.0
        return tf.concat(
            [
                x - half_width,
                y - half_height,
                x + half_width,
                y + half_height,
                rest,
            ],
            axis=-1,
        )

    def call(self, images, **kwargs):
        """Apply random cutout.

        Args:
            images: Batch of images shaped `(batch, height, width, channels)`.
            **kwargs: Forwarded to `get_random_transformation_batch`.

        Returns:
            Tensor of the same shape in the layer's compute dtype, with one
            rectangle per image replaced by the fill.
        """
        # `tf.where` needs images and fill to share a dtype; integer inputs
        # (e.g. uint8) are therefore promoted to the compute dtype. This is a
        # no-op for inputs that already are in the compute dtype.
        images = tf.cast(images, self.compute_dtype)

        transformations = self.get_random_transformation_batch(
            images, **kwargs
        )
        rectangles_fill = self._compute_rectangle_fill(images)
        return self.fill_rectangle(
            images,
            transformations["centers_x"],
            transformations["centers_y"],
            transformations["rectangles_width"],
            transformations["rectangles_height"],
            rectangles_fill,
        )

    def _get_image_shape(self, images):
        """Return per-image heights and widths as 1-D int32 tensors."""
        shape = tf.shape(images)
        batch_size = shape[0]
        # Every image in a dense batch shares the same size; replicate it once
        # per image so downstream ops work on per-image vectors.
        heights = tf.fill([batch_size], shape[H_AXIS])
        widths = tf.fill([batch_size], shape[W_AXIS])
        return tf.cast(heights, dtype=tf.int32), tf.cast(widths, dtype=tf.int32)

    def _compute_rectangle_position(self, inputs):
        """Sample a rectangle center uniformly inside every image.

        Returns:
            Tuple `(center_x, center_y)` of 1-D int32 tensors.
        """
        batch_size = tf.shape(inputs)[0]
        heights, widths = self._get_image_shape(inputs)
        # generate values in float32 and then cast (i.e. truncate) to int32
        # because random.uniform does not support maxval broadcasting for
        # integer types, and maxval is a 1-D tensor here.
        heights = tf.cast(heights, dtype=tf.float32)
        widths = tf.cast(widths, dtype=tf.float32)
        center_x = tf.random.uniform(
            (batch_size,), 0, widths, dtype=tf.float32
        )
        center_y = tf.random.uniform(
            (batch_size,), 0, heights, dtype=tf.float32
        )
        return tf.cast(center_x, tf.int32), tf.cast(center_y, tf.int32)

    def _compute_rectangle_size(self, inputs):
        """Compute the rectangle size in pixels for every image.

        The size is `ceil(factor * image_size)`, clamped so it never exceeds
        the corresponding image dimension.

        Returns:
            Tuple `(height, width)` of 1-D int32 tensors.
        """
        images_heights, images_widths = self._get_image_shape(inputs)
        height = self.height_factor * tf.cast(images_heights, tf.float32)
        width = self.width_factor * tf.cast(images_widths, tf.float32)
        height = tf.cast(tf.math.ceil(height), tf.int32)
        width = tf.cast(tf.math.ceil(width), tf.int32)
        height = tf.minimum(height, images_heights)
        # Clamp against the image *width* (the original compared with height).
        width = tf.minimum(width, images_widths)
        return height, width

    def _compute_rectangle_fill(self, inputs):
        """Build a tensor shaped like `inputs` holding the fill pixels.

        Returns:
            Tensor in the layer's compute dtype: a constant tensor for
            `fill_mode="constant"`, otherwise gaussian noise rescaled to the
            `[min, max]` range of `inputs`.
        """
        input_shape = tf.shape(inputs)
        if self.fill_mode == "constant":
            fill_value = tf.fill(input_shape, self.fill_value)
            return tf.cast(fill_value, dtype=self.compute_dtype)

        # gaussian noise
        fill_value = tf.random.normal(input_shape, dtype=self.compute_dtype)
        # rescale the random noise to the original image range; the cast keeps
        # the arithmetic in a single dtype when called with integer images.
        inputs = tf.cast(inputs, self.compute_dtype)
        image_max = tf.reduce_max(inputs)
        image_min = tf.reduce_min(inputs)
        fill_max = tf.reduce_max(fill_value)
        fill_min = tf.reduce_min(fill_value)
        return (image_max - image_min) * (fill_value - fill_min) / (
            fill_max - fill_min
        ) + image_min

    def _axis_mask(self, starts, ends, mask_len):
        """Return a boolean mask of indices `i` with `starts <= i < ends`.

        Args:
            starts: `(batch, 1)` tensor of inclusive lower bounds.
            ends: `(batch, 1)` tensor of exclusive upper bounds.
            mask_len: Length of the axis being masked.

        Returns:
            Boolean tensor of shape `(batch, mask_len)`.
        """
        # index range of axis, shaped (1, mask_len) so that it broadcasts
        # against the (batch, 1) bounds.
        axis_indices = tf.range(mask_len, dtype=starts.dtype)
        axis_indices = tf.expand_dims(axis_indices, 0)
        # mask of index bounds
        return tf.greater_equal(axis_indices, starts) & tf.less(
            axis_indices, ends
        )

    def corners_to_mask(self, bounding_boxes, mask_shape):
        """Rasterize corner-format boxes into boolean masks.

        Args:
            bounding_boxes: `(batch, 4)` tensor of `(x0, y0, x1, y1)`.
            mask_shape: Tuple `(mask_width, mask_height)`.

        Returns:
            Boolean tensor of shape `(batch, mask_height, mask_width)` that is
            True inside each box.
        """
        mask_width, mask_height = mask_shape
        x0, y0, x1, y1 = tf.split(bounding_boxes, [1, 1, 1, 1], axis=-1)
        w_mask = self._axis_mask(x0, x1, mask_width)
        h_mask = self._axis_mask(y0, y1, mask_height)
        # Outer "and" of the row mask and the column mask.
        w_mask = tf.expand_dims(w_mask, axis=1)
        h_mask = tf.expand_dims(h_mask, axis=2)
        return tf.logical_and(w_mask, h_mask)

    def get_config(self):
        """Return the layer configuration, including the constructor args."""
        config = super().get_config()
        config.update(
            {
                "height_factor": self.height_factor,
                "width_factor": self.width_factor,
                "fill_mode": self.fill_mode,
                "fill_value": self.fill_value,
                "seed": self.seed,
            }
        )
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment