Skip to content

Instantly share code, notes, and snippets.

@miladfa7
Created December 28, 2019 20:36
Show Gist options
Save miladfa7/d912e05bd543c678227ef0759defa2f2 to your computer and use it in GitHub Desktop.
ImageDataGenerator (in-place augmentation)
import os

import tensorflow
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Training ImageDataGenerator: multiplies pixel values by 1/255 and applies
# random geometric augmentation (rotation up to 50 degrees, 20% width/height
# shifts, 30% zoom, horizontal and vertical flips). Pixels uncovered by a
# transform are filled with the constant value `cval` (0).
train_data_gen = ImageDataGenerator(rotation_range=50,
                                    width_shift_range=0.2,
                                    height_shift_range=0.2,
                                    zoom_range=0.3,
                                    horizontal_flip=True,
                                    vertical_flip=True,
                                    fill_mode='constant',
                                    cval=0,
                                    rescale=1./255)
# Validation ImageDataGenerator: rescaling only, no random augmentation.
# Augmenting validation data would make validation loss/accuracy measure
# performance on randomly perturbed images, fluctuate from epoch to epoch,
# and stop being comparable with the (un-augmented) test generator below.
valid_data_gen = ImageDataGenerator(rescale=1./255)
# Test ImageDataGenerator: rescaling only, for the same reason.
test_data_gen = ImageDataGenerator(rescale=1./255)
# Fixed seed for reproducibility: it seeds TensorFlow's global RNG here and
# is also passed to every flow_from_directory() call that follows.
SEED = 1234
tensorflow.random.set_seed(seed=SEED)
# One directory per split, all located under `dataset_dir` (defined
# elsewhere). flow_from_directory() reads images from sub-directories of
# each split directory.
training_dir, valid_dir, test_dir = (
    os.path.join(dataset_dir, split_name)
    for split_name in ('training', 'valid', 'testing')
)

# Training split: shuffled; targets are one-hot encoded
# (class_mode='categorical') using the class order given by `classes`.
train_gen = train_data_gen.flow_from_directory(
    training_dir,
    classes=classes,
    class_mode='categorical',
    target_size=(256, 256),
    batch_size=Batch_size,
    shuffle=True,
    seed=SEED,
)

# Validation split: same labelling as training, but in a fixed order
# (shuffle=False).
valid_gen = valid_data_gen.flow_from_directory(
    valid_dir,
    classes=classes,
    class_mode='categorical',
    target_size=(256, 256),
    batch_size=Batch_size,
    shuffle=False,
    seed=SEED,
)

# Test split: images only (class_mode=None yields no targets), fixed order,
# batches of 10.
test_gen = test_data_gen.flow_from_directory(
    test_dir,
    class_mode=None,
    target_size=(256, 256),
    batch_size=10,
    shuffle=False,
    seed=SEED,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment