Skip to content

Instantly share code, notes, and snippets.

@usr-ein
Created December 1, 2020 23:23
Show Gist options
  • Save usr-ein/0abfc5b312cd11028eb36bf9b76c5dda to your computer and use it in GitHub Desktop.
Save usr-ein/0abfc5b312cd11028eb36bf9b76c5dda to your computer and use it in GitHub Desktop.
Loading and augmenting image dataset in Keras
from collections import Counter
from typing import Tuple, Dict
from os import PathLike
from pathlib import Path
from os.path import isdir, join as join_paths
from tensorflow.python.keras.preprocessing.image import (
ImageDataGenerator,
DirectoryIterator,
)
from tensorflow.python.keras.applications.vgg19 import (
preprocess_input as vgg19_preprocessing_func,
)
from tensorflow.python.keras.applications.vgg19 import (
preprocess_input as vgg19_preprocessing_func,
)
def preprocess_function(tensor):
    """Run the VGG19 input preprocessing on ``tensor`` and return the result.

    This simply delegates to ``vgg19_preprocessing_func`` (the VGG19
    ``preprocess_input`` imported at the top of the file), which handles
    both 3D and 4D data.
    """
    preprocessed = vgg19_preprocessing_func(tensor)
    return preprocessed
def get_data_augmentation(validation_split: float = 0.2) -> ImageDataGenerator:
    """Create the ``ImageDataGenerator`` used to load and augment the images.

    Args:
        validation_split: Value forwarded unchanged to the
            ``validation_split`` argument of ``ImageDataGenerator`` (the
            fraction of images reserved for validation). Defaults to 0.2,
            the value that used to be hard-coded here, so existing callers
            that pass no argument get the same generator as before.

    Returns:
        An ``ImageDataGenerator`` that applies ``preprocess_function`` to
        each image and is configured with small zoom, shear, rotation and
        shift ranges plus horizontal flipping.
    """
    return ImageDataGenerator(
        preprocessing_function=preprocess_function,
        zoom_range=0.1,
        shear_range=0.1,
        rotation_range=7,
        width_shift_range=0.1,
        height_shift_range=0.1,
        horizontal_flip=True,
        zca_whitening=False,  # May be useful, test beforehand
        validation_split=validation_split,
    )
def load_dataset_generator(
    data_generator: ImageDataGenerator,
    folder_path: PathLike,
    window_size: int,
    batch_size: int,
) -> Tuple[DirectoryIterator, DirectoryIterator]:
    """Build the training and validation iterators for one image folder.

    ``data_generator.flow_from_directory`` is called twice with identical
    options, once with ``subset="training"`` and once with
    ``subset="validation"``.

    Args:
        data_generator: Generator whose ``flow_from_directory`` method
            produces the iterators.
        folder_path: Directory passed as ``directory``; the class names
            requested are ``"negative"`` and ``"positive"``.
        window_size: Side length used for the square ``target_size``.
        batch_size: Batch size passed to both iterators.

    Returns:
        A ``(training, validation)`` tuple of iterators.
    """
    # Every option below is shared by both subsets, including the fixed seed.
    shared_options = {
        "directory": folder_path,
        "target_size": (window_size, window_size),
        "classes": ["negative", "positive"],
        "class_mode": "binary",
        "batch_size": batch_size,
        "shuffle": True,
        "seed": 42,
    }
    # The training iterator is created first, then the validation one.
    train_iterator, validation_iterator = (
        data_generator.flow_from_directory(subset=subset_name, **shared_options)
        for subset_name in ("training", "validation")
    )
    return train_iterator, validation_iterator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment