Skip to content

Instantly share code, notes, and snippets.

@innat
Created May 29, 2023 15:20
Show Gist options
  • Save innat/46a035ebc8997c7ae25c19a57de88a51 to your computer and use it in GitHub Desktop.
Save innat/46a035ebc8997c7ae25c19a57de88a51 to your computer and use it in GitHub Desktop.
import tensorflow as tf
from tensorflow import keras
class ColorJitter(keras.layers.Layer):
    """Randomly jitter brightness, contrast, saturation and hue of images.

    Pixel values are expected on a 0-255 scale (any numeric dtype): the
    brightness delta is scaled by 255 and intermediate results are clipped to
    ``[0, 255]``. The jittered images are cast back to the input dtype. Jitter
    is applied only when ``training`` is truthy; otherwise inputs pass through
    unchanged.

    Args:
        brightness_factor: Non-negative scalar. The brightness delta is drawn
            from ``[-brightness_factor * 255, brightness_factor * 255]``.
        contrast_factor: Non-negative scalar ``f`` (interpreted as the range
            ``(0, f)``) or a 2-element ``(lower, upper)`` tuple/list of
            non-negative numbers with ``lower < upper``.
        saturation_factor: Same format as ``contrast_factor``.
        hue_factor: Scalar in ``[0, 0.5]``, used as the maximum hue delta.
        seed: Optional integer seed. When given, each of the four random ops
            receives a distinct seed derived from it (``seed + 0`` ..
            ``seed + 3``).
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(
        self,
        brightness_factor=0.5,
        contrast_factor=(0.5, 0.9),
        saturation_factor=(0.5, 0.9),
        hue_factor=0.5,
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.seed = seed
        self.brightness_factor = self._check_factor_limit(
            brightness_factor, name="brightness"
        )
        self.contrast_factor = self._check_factor_limit(
            contrast_factor, name="contrast"
        )
        self.saturation_factor = self._check_factor_limit(
            saturation_factor, name="saturation"
        )
        self.hue_factor = self._check_factor_limit(hue_factor, name="hue")

    def _check_factor_limit(self, factor, name):
        """Validate and normalize one jitter factor.

        Args:
            factor: Scalar or 2-element tuple/list supplied by the caller.
            name: One of "brightness", "contrast", "saturation", "hue".

        Returns:
            For brightness and hue, the scalar ``factor``. For contrast and
            saturation, ``(0, factor)`` for a scalar, otherwise the two bounds
            sorted ascending (as a list).

        Raises:
            TypeError: If a scalar factor is negative, or ``factor`` is neither
                a number nor a 2-element tuple/list.
            ValueError: If a range is given for brightness/hue, a range bound
                is not a non-negative number, the range is empty
                (``lower >= upper``), or the hue factor exceeds 0.5.
        """
        scalar_only = name in ("brightness", "hue")
        if isinstance(factor, (int, float)):
            if factor < 0:
                raise TypeError(
                    "The factor value should be non-negative scalar or tuple "
                    f"or list of two upper and lower bound number. Received: {factor}"
                )
            if scalar_only:
                # tf.image.random_hue rejects max_delta > 0.5; fail early at
                # construction instead of on the first training call.
                if name == "hue" and factor > 0.5:
                    raise ValueError(
                        f"The hue factor must be at most 0.5. Received: {factor}"
                    )
                return factor
            bounds = (0, factor)
        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
            if scalar_only:
                raise ValueError(
                    "The factor limit for brightness and hue, it should be a single "
                    f"non-negative scaler. Received: {factor} for {name}"
                )
            # Check element types before sorting so mixed/non-numeric input
            # produces a clear error rather than a comparison TypeError.
            if not all(
                isinstance(bound, (int, float)) and bound >= 0 for bound in factor
            ):
                raise ValueError(
                    f"The factor bounds for {name} should be non-negative "
                    f"numbers. Received: {factor}"
                )
            bounds = sorted(factor)
        else:
            raise TypeError(
                "The factor value should be non-negative scalar or tuple "
                f"or list of two upper and lower bound number. Received: {factor}"
            )
        # random_contrast / random_saturation require lower < upper.
        if bounds[0] >= bounds[1]:
            raise ValueError(
                f"The factor range for {name} must satisfy lower < upper. "
                f"Received: {factor}"
            )
        return bounds

    def _op_seed(self, offset):
        """Return a per-op seed derived from ``self.seed`` (None if unset).

        Inside a traced graph, random ops created with the same op-level seed
        yield identical draws, which would make all jitter factors perfectly
        correlated; offsetting keeps them independent yet reproducible.
        """
        return None if self.seed is None else self.seed + offset

    def _color_jitter(self, images):
        """Apply brightness, contrast, saturation and hue jitter in sequence.

        Assumes ``images`` hold values on a 0-255 scale and are RGB (the
        saturation/hue ops operate on 3-channel images).
        """
        original_dtype = images.dtype
        images = tf.cast(images, dtype=tf.float32)
        images = tf.image.random_brightness(
            images, max_delta=self.brightness_factor * 255.0, seed=self._op_seed(0)
        )
        images = tf.clip_by_value(images, 0.0, 255.0)
        images = tf.image.random_contrast(
            images,
            lower=self.contrast_factor[0],
            upper=self.contrast_factor[1],
            seed=self._op_seed(1),
        )
        # A contrast factor > 1 can push pixels outside [0, 255]; clip so the
        # HSV-based saturation/hue ops below see valid intensities.
        images = tf.clip_by_value(images, 0.0, 255.0)
        images = tf.image.random_saturation(
            images,
            lower=self.saturation_factor[0],
            upper=self.saturation_factor[1],
            seed=self._op_seed(2),
        )
        images = tf.image.random_hue(
            images, max_delta=self.hue_factor, seed=self._op_seed(3)
        )
        # tf.cast does not saturate (see tf.saturate_cast), so clip before
        # casting back to a possibly-integer input dtype.
        images = tf.clip_by_value(images, 0.0, 255.0)
        return tf.cast(images, original_dtype)

    def call(self, images, training=True):
        """Return jittered images when ``training`` is truthy, else ``images``."""
        if training:
            return self._color_jitter(images)
        return images

    def get_config(self):
        """Return the layer config, including the normalized jitter factors."""
        config = super().get_config()
        config.update(
            {
                "brightness_factor": self.brightness_factor,
                "contrast_factor": self.contrast_factor,
                "saturation_factor": self.saturation_factor,
                "hue_factor": self.hue_factor,
                "seed": self.seed,
            }
        )
        return config
# Example usage: jitter a dummy batch of 10 all-ones 224x224 3-channel images.
images = tf.ones(shape=(10, 224, 224, 3))
jitter_layer = ColorJitter()
cjit_image = jitter_layer(images)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment