@rstml
Last active June 9, 2023 06:02
Keras center and random crop support for ImageDataGenerator


The preprocess_crop.py script below adds center and random cropping to Keras's flow_from_directory data generator.

It first resizes the image while preserving its aspect ratio and then performs the crop. The size of the resized image is derived from crop_fraction, which is hardcoded but can be changed: see the line crop_fraction = 0.875. The value 0.875 appears to be the most common choice, e.g. a 224px crop from a 256px image.
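As a rough illustration of that arithmetic, the sketch below computes the intermediate size for a given image and target size. The helper name size_before_crop is hypothetical and is not part of preprocess_crop.py; it only mirrors the resize step of the script, assuming the scale factor is the larger of the two per-axis ratios so that neither side ends up smaller than the crop.

```python
def size_before_crop(image_size, target_size, crop_fraction=0.875):
    """Return (width, height) of the aspect-preserving resize done before cropping.

    image_size is (width, height), the order PIL uses for Image.size.
    target_size is (height, width), the order Keras uses.
    Hypothetical helper, for illustration only.
    """
    width, height = image_size
    target_height, target_width = target_size
    # Scale so that both sides are at least target / crop_fraction.
    ratio = max(target_width / crop_fraction / width,
                target_height / crop_fraction / height)
    return int(width * ratio), int(height * ratio)


print(size_before_crop((512, 512), (224, 224)))   # (256, 256)
print(size_before_crop((1024, 512), (224, 224)))  # (512, 256)
```

With crop_fraction = 0.875 a 224px target gives a 256px short side, which matches the 224-from-256 convention mentioned above.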

Note that this is implemented by monkey patching the keras_preprocessing.image.utils.load_img function, as I couldn't find any other way to perform the crop before resizing without rewriting many of the classes above it.

Due to these limitations, the cropping method is encoded in the interpolation field. The two methods are delimited by ":", where the first part is the interpolation method and the second is the crop method, e.g. lanczos:random. Supported crop methods are none, center, and random. When no crop method is specified, none is assumed.
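A minimal sketch of how such a combined string can be decoded is shown here. The function name split_interpolation is hypothetical (the script does this inline), and validating the crop method at this point is an assumption made for the example.

```python
def split_interpolation(value):
    """Split "interpolation:crop" into (interpolation, crop).

    Hypothetical helper, for illustration only. The crop part defaults
    to "none" when the string contains no ":" delimiter.
    """
    if ":" in value:
        interpolation, crop = value.split(":")
    else:
        interpolation, crop = value, "none"
    if crop not in ("none", "center", "random"):
        raise ValueError("Invalid crop method {} specified.".format(crop))
    return interpolation, crop


print(split_interpolation("lanczos:random"))  # ('lanczos', 'random')
print(split_interpolation("bilinear"))        # ('bilinear', 'none')
```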

How to use it

Just drop preprocess_crop.py into your project and import it to enable cropping. The example below shows how you can use random cropping for training and center cropping for validation:

import preprocess_crop
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.inception_v3 import preprocess_input

#...

# Training with random crop

train_datagen = ImageDataGenerator(
    rotation_range=20,
    channel_shift_range=20,
    horizontal_flip=True,
    preprocessing_function=preprocess_input
)

train_img_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size = (IMG_SIZE, IMG_SIZE),
    batch_size  = BATCH_SIZE,
    class_mode  = 'categorical',
    interpolation = 'lanczos:random', # <--------- random crop
    shuffle = True
)

# Validation with center crop

validate_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input
)

validate_img_generator = validate_datagen.flow_from_directory(
    validate_dir,
    target_size = (IMG_SIZE, IMG_SIZE),
    batch_size  = BATCH_SIZE,
    class_mode  = 'categorical',
    interpolation = 'lanczos:center', # <--------- center crop
    shuffle = False
)

preprocess_crop.py

import random
import keras_preprocessing.image


def load_and_crop_img(path, grayscale=False, color_mode='rgb', target_size=None,
                      interpolation='nearest'):
    """Wraps keras_preprocessing.image.utils.load_img() and adds cropping.

    The cropping method is encoded in the interpolation argument.

    # Arguments
        path: Path to image file.
        grayscale: DEPRECATED, use `color_mode="grayscale"`.
        color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb".
            The desired image format.
        target_size: Either `None` (default to original size)
            or tuple of ints `(img_height, img_width)`.
        interpolation: Interpolation and crop methods used to resample and crop the image
            if the target size is different from that of the loaded image.
            Methods are delimited by ":" where first part is interpolation and second is crop
            e.g. "lanczos:random".
            Supported interpolation methods are "nearest", "bilinear", "bicubic", "lanczos",
            "box", "hamming". By default, "nearest" is used.
            Supported crop methods are "none", "center", "random".

    # Returns
        A PIL Image instance.

    # Raises
        ImportError: if PIL is not available.
        ValueError: if the interpolation or crop method is not supported.
    """

    # Decode interpolation string. Allowed crop methods: none, center, random
    interpolation, crop = interpolation.split(":") if ":" in interpolation else (interpolation, "none")

    # No cropping requested: defer entirely to the original Keras loader
    if crop == "none":
        return keras_preprocessing.image.utils.load_img(path,
                                                        grayscale=grayscale,
                                                        color_mode=color_mode,
                                                        target_size=target_size,
                                                        interpolation=interpolation)

    # Load original size image using Keras
    img = keras_preprocessing.image.utils.load_img(path,
                                                   grayscale=grayscale,
                                                   color_mode=color_mode,
                                                   target_size=None,
                                                   interpolation=interpolation)

    # Crop fraction of total image
    crop_fraction = 0.875

    if target_size is not None:
        # Read the target dimensions only after the None check
        target_width = target_size[1]
        target_height = target_size[0]

        if img.size != (target_width, target_height):

            if crop not in ["center", "random"]:
                raise ValueError('Invalid crop method {} specified.'.format(crop))

            if interpolation not in keras_preprocessing.image.utils._PIL_INTERPOLATION_METHODS:
                raise ValueError(
                    'Invalid interpolation method {} specified. Supported '
                    'methods are {}'.format(
                        interpolation,
                        ", ".join(keras_preprocessing.image.utils._PIL_INTERPOLATION_METHODS.keys())))

            resample = keras_preprocessing.image.utils._PIL_INTERPOLATION_METHODS[interpolation]

            width, height = img.size

            # Resize keeping aspect ratio.
            # The result should be no smaller than the target size, including crop fraction overhead
            target_size_before_crop = (target_width / crop_fraction, target_height / crop_fraction)
            ratio = max(target_size_before_crop[0] / width, target_size_before_crop[1] / height)
            target_size_before_crop_keep_ratio = int(width * ratio), int(height * ratio)
            img = img.resize(target_size_before_crop_keep_ratio, resample=resample)

            width, height = img.size

            if crop == "center":
                left_corner = int(round(width / 2)) - int(round(target_width / 2))
                top_corner = int(round(height / 2)) - int(round(target_height / 2))
                return img.crop((left_corner, top_corner,
                                 left_corner + target_width, top_corner + target_height))
            elif crop == "random":
                left_shift = random.randint(0, int(width - target_width))
                down_shift = random.randint(0, int(height - target_height))
                return img.crop((left_shift, down_shift,
                                 target_width + left_shift, target_height + down_shift))

    return img


# Monkey patch
keras_preprocessing.image.iterator.load_img = load_and_crop_img

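
The crop-box arithmetic from the script can be checked in isolation, without Keras or PIL installed. The sketch below is a hypothetical helper (crop_box is not part of preprocess_crop.py) that returns the (left, top, right, bottom) tuple in the form PIL's Image.crop() accepts. The rng parameter is an addition made for this example so that a seeded random.Random instance can be passed in for reproducible results.

```python
import random


def crop_box(image_size, target_size, crop="center", rng=random):
    """Return the (left, top, right, bottom) box for a center or random crop.

    image_size is (width, height) of the already resized image;
    target_size is (height, width), the order Keras uses.
    Hypothetical helper, for illustration only.
    """
    width, height = image_size
    target_height, target_width = target_size
    if crop == "center":
        left = int(round(width / 2)) - int(round(target_width / 2))
        top = int(round(height / 2)) - int(round(target_height / 2))
    elif crop == "random":
        # randint is inclusive on both ends, so the box always fits
        left = rng.randint(0, width - target_width)
        top = rng.randint(0, height - target_height)
    else:
        raise ValueError("Invalid crop method {} specified.".format(crop))
    return left, top, left + target_width, top + target_height


print(crop_box((256, 256), (224, 224)))  # (16, 16, 240, 240)
print(crop_box((341, 256), (224, 224), crop="random", rng=random.Random(0)))
```

Two random.Random instances created with the same seed produce the same sequence of boxes, which is one possible starting point when two crops need to agree.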
@crearo

crearo commented Jun 14, 2019

Thanks a bunch for this! Why isn't this a part of Keras already?

@liwei46

liwei46 commented Oct 16, 2019

Hello, how can I use this code with tf.keras? Thanks.

@nicemanis

Thanks for the script. Is it possible to synchronize the cropping between zipped generators? I am using a generator for both the input image and the segmentation mask. After inspecting the output, I noticed that the mask does not align with the cropped input image.
