Created
June 2, 2020 09:16
-
-
Save yohann84L/3e8291cccebc1d2b367846eb0a5abe50 to your computer and use it in GitHub Desktop.
AutoAug for imgaug pipeline
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import numpy as np | |
from PIL import Image, ImageEnhance, ImageOps | |
class ImageNetPolicy(object):
    """Callable that applies one randomly chosen sub-policy to an image.

    The instance holds a fixed list of 25 ``SubPolicy`` objects. Every call
    picks one of them uniformly at random and returns whatever that
    sub-policy returns for the given image.

    Args:
        fillcolor: Forwarded unchanged as the last argument of every
            ``SubPolicy``.
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # One row per sub-policy:
        # (p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2)
        table = (
            (0.4, "posterize", 8, 0.6, "rotate", 9),
            (0.6, "solarize", 5, 0.6, "autocontrast", 5),
            (0.8, "equalize", 8, 0.6, "equalize", 3),
            (0.6, "posterize", 7, 0.6, "posterize", 6),
            (0.4, "equalize", 7, 0.2, "solarize", 4),
            (0.4, "equalize", 4, 0.8, "rotate", 8),
            (0.6, "solarize", 3, 0.6, "equalize", 7),
            (0.8, "posterize", 5, 1.0, "equalize", 2),
            (0.2, "rotate", 3, 0.6, "solarize", 8),
            (0.6, "equalize", 8, 0.4, "posterize", 6),
            (0.8, "rotate", 8, 0.4, "color", 0),
            (0.4, "rotate", 9, 0.6, "equalize", 2),
            (0.0, "equalize", 7, 0.8, "equalize", 8),
            (0.6, "invert", 4, 1.0, "equalize", 8),
            (0.6, "color", 4, 1.0, "contrast", 8),
            (0.8, "rotate", 8, 1.0, "color", 2),
            (0.8, "color", 8, 0.8, "solarize", 7),
            (0.4, "sharpness", 7, 0.6, "invert", 8),
            (0.6, "shearX", 5, 1.0, "equalize", 9),
            (0.4, "color", 0, 0.6, "equalize", 3),
            (0.4, "equalize", 7, 0.2, "solarize", 4),
            (0.6, "solarize", 5, 0.6, "autocontrast", 5),
            (0.6, "invert", 4, 1.0, "equalize", 8),
            (0.6, "color", 4, 1.0, "contrast", 8),
            (0.8, "equalize", 8, 0.6, "equalize", 3),
        )
        self.policies = [SubPolicy(*row, fillcolor) for row in table]

    def __call__(self, img):
        """Apply one uniformly chosen sub-policy to ``img`` and return the result."""
        chosen = self.policies[random.randrange(len(self.policies))]
        return chosen(img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"
class SubPolicy(object):
    """Two image operations, each applied with its own probability.

    Each operation is selected by name and paired with a magnitude looked up
    from a 10-entry table (``magnitude_idx`` 0..9). Calling the instance
    applies operation 1 with probability ``p1`` and then operation 2 with
    probability ``p2`` to the (possibly already modified) image.

    Args:
        p1: Probability of applying ``operation1``.
        operation1: Name of the first operation; one of "shearX", "shearY",
            "translateX", "translateY", "rotate", "color", "posterize",
            "solarize", "contrast", "sharpness", "brightness",
            "autocontrast", "equalize", "invert".
        magnitude_idx1: Index (0..9) into the magnitude table of
            ``operation1``.
        p2: Probability of applying ``operation2``.
        operation2: Name of the second operation (same choices as above).
        magnitude_idx2: Index (0..9) into the magnitude table of
            ``operation2``.
        fillcolor: Fill colour passed to ``Image.transform`` by the shear and
            translate operations. NOTE: "rotate" does not use it; it always
            composites onto a fixed (128, 128, 128, 128) RGBA background.

    Raises:
        KeyError: If an operation name is not in the table.
        IndexError: If a magnitude index is outside the 10-entry table.
    """

    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        # Magnitude tables: entry i is the strength used for magnitude index i.
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # Builtin ``int`` instead of ``np.int``: the alias was deprecated
            # in NumPy 1.20 and removed in 1.24, where it raises
            # AttributeError.
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            # These operations take no magnitude; the tables are placeholders.
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10,
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotate in RGBA so the uncovered corners end up transparent, then
            # composite over a constant background using the rotated image's
            # own alpha as the mask, and convert back to the input mode.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        # Operation name -> callable(img, magnitude). Operations that use
        # random.choice([-1, 1]) pick the sign of the magnitude on every call;
        # "rotate" has no such sign flip.
        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            # The table stores a NumPy integer; hand Pillow a plain Python int.
            "posterize": lambda img, magnitude: ImageOps.posterize(img, int(magnitude)),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img),
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply each operation with its probability and return the image."""
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
from imgaug.augmenters import meta | |
from imgaug import parameters as iap | |
import imgaug.augmenters as iaa | |
class AutoAug(meta.Augmenter):
    """imgaug augmenter that runs an AutoAugment policy on selected images.

    Args:
        policy: Name of the policy to use. Only "imagenet" is accepted.
        p: Probability parameter, normalised through
            ``iap.handle_probability_param``. An image is augmented when its
            drawn sample is greater than 0.5.
        name: Forwarded to ``meta.Augmenter``.
        deterministic: Forwarded to ``meta.Augmenter``.
        random_state: Forwarded to ``meta.Augmenter``.

    Raises:
        NotImplementedError: If ``policy`` is anything other than "imagenet".
    """

    def __init__(self, policy="imagenet", p=1, name=None, deterministic=False, random_state=None):
        super().__init__(name=name, deterministic=deterministic, random_state=random_state)

        if policy != "imagenet":
            raise NotImplementedError

        self.policy = ImageNetPolicy()
        self.p = iap.handle_probability_param(p, "p")

    def _augment_images(self, images, random_state, parents, hooks):
        """Replace each selected image in ``images`` in place and return it."""
        # One draw per image decides whether the policy is applied to it.
        draws = self.p.draw_samples((len(images),), random_state=random_state)

        for idx, (array, draw) in enumerate(zip(images, draws)):
            if not draw > 0.5:
                continue
            # The policy works on PIL images: convert in, apply, convert back.
            pil_image = Image.fromarray(np.uint8(array))
            augmented = self.policy(pil_image)
            images[idx] = np.array(augmented)

        return images

    def get_parameters(self):
        """Return the augmenter's parameters: the probability parameter."""
        return [self.p]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment