Skip to content

Instantly share code, notes, and snippets.

@hsahovic
Created December 1, 2019 02:53
Show Gist options
  • Save hsahovic/05be90703c5aab869192caf245a9bb9b to your computer and use it in GitHub Desktop.
Save hsahovic/05be90703c5aab869192caf245a9bb9b to your computer and use it in GitHub Desktop.
Converting a Keras image generator to a Tensorflow Dataset
from typing import Generator
import tensorflow as tf
def image_generator_to_tf_ds(generator: Generator) -> tf.data.Dataset:
    """Convert an initialized Keras image data iterator to a tf Dataset.

    The returned dataset is unbatched: every ``(image, label)`` pair of each
    batch drawn from ``generator`` is yielded as a separate element.

    Args:
        generator: An iterator that yields ``(images, labels)`` batches via
            ``next()`` and exposes the ``samples``, ``batch_size``,
            ``image_shape`` and ``class_indices`` attributes (e.g. the object
            returned by ``ImageDataGenerator.flow_from_directory``).

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` float32 pairs. Images have
        shape ``generator.image_shape``; labels have shape
        ``(len(generator.class_indices),)``.
        NOTE(review): the label shape assumes one vector entry per class
        (one-hot labels) -- confirm against the iterator's class mode.

    Example usage:

    >>> img_generator = ImageDataGenerator()
    >>> img_generator = img_generator.flow_from_directory(
    ...     'Training data',
    ...     target_size = (128, 128),
    ...     batch_size = 128,
    ... )
    >>> img_ds = image_generator_to_tf_ds(img_generator)
    >>> model.fit(img_ds)
    ...
    """

    def generator_wrapper():
        # Ceiling division: the trailing partial batch must be consumed too,
        # otherwise up to ``batch_size - 1`` samples are silently dropped per
        # pass (and every sample is dropped when samples < batch_size).
        batch_size = generator.batch_size
        n_batches = (generator.samples + batch_size - 1) // batch_size
        for _ in range(n_batches):
            images, labels = next(generator)
            # Unbatch: emit one (image, label) pair per element.
            yield from zip(images, labels)

    return tf.data.Dataset.from_generator(
        generator_wrapper,
        output_types=(tf.float32, tf.float32),
        output_shapes=(
            generator.image_shape,
            # Explicit rank-1 shape rather than a bare int.
            (len(generator.class_indices),),
        ),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment