Skip to content

Instantly share code, notes, and snippets.

Last active July 28, 2020 12:57
Show Gist options
  • Save sunnychugh/046280a39eb9a685090d77c675b20a2a to your computer and use it in GitHub Desktop.
Save sunnychugh/046280a39eb9a685090d77c675b20a2a to your computer and use it in GitHub Desktop.
import datetime
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def load_data_using_keras(folders):
    """Load the images in batches using Keras.

    One ``ImageDataGenerator`` (rescaling pixel values by 1/255) is created
    per folder and pointed at ``dir_path/<folder>``. Images are shuffled for
    the ``'train'`` folder only.

    Reads the module-level globals ``dir_path``, ``img_dims`` and
    ``batch_size``, which must be set before calling.

    Note: Keras might need the 'pillow' library to be installed. Use-
        # pip install pillow

    Args:
        folders: Iterable of sub-directory names under ``dir_path``
            (e.g. ``['train', 'val']``).

    Returns:
        Dict mapping each folder name to its data generator, to be used
        while training the model.
    """
    data_generator = {}
    for x in folders:
        image_generator = ImageDataGenerator(rescale=1. / 255)
        data_generator[x] = image_generator.flow_from_directory(
            directory=os.path.join(dir_path, x),
            target_size=(img_dims[0], img_dims[1]),
            batch_size=batch_size,
            # Only the training split is shuffled; evaluation order stays
            # stable for the other splits.
            shuffle=(x == 'train'),
            # class_mode is left at the Keras default.
        )
    return data_generator
def load_data_using_tfdata(folders):
    """Load the images in batches using Tensorflow (tfdata).

    Cache can be used to speed up the process.
    Faster method in comparison to image loading using Keras.

    Reads the module-level globals ``dir_path``, ``img_dims`` and
    ``batch_size``, which must be set before calling. Class names are taken
    from the sub-directories of ``dir_path/train``.

    Args:
        folders: Iterable of sub-directory names under ``dir_path``
            (e.g. ``['train', 'val']``). Each is expected to contain one
            sub-directory per class holding JPEG images.

    Returns:
        Dict mapping each folder name to a shuffled, repeating, batched and
        prefetched dataset, to be used while training the model.
    """
    autotune = tf.data.experimental.AUTOTUNE

    # Resolve the class list once, outside the mapped function. It is sorted
    # (os.listdir order is arbitrary) and restricted to directories so that
    # stray files do not become bogus classes.
    train_dir = os.path.join(dir_path, 'train')
    class_names = np.array(sorted(
        entry for entry in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, entry))
    ))

    def parse_image(file_path):
        """Map an image path to an ``(image, one-hot label)`` pair."""
        # convert the path to a list of path components
        parts = tf.strings.split(file_path, os.path.sep)
        # The second to last is the class-directory
        label = parts[-2] == class_names
        # load the raw data from the file as a string
        img = tf.io.read_file(file_path)
        # convert the compressed string to a 3D uint8 tensor
        img = tf.image.decode_jpeg(img, channels=3)
        # Use `convert_image_dtype` to convert to floats in the [0,1] range
        img = tf.image.convert_image_dtype(img, tf.float32)
        # resize the image to the desired size.
        img = tf.image.resize(img, [img_dims[0], img_dims[1]])
        return img, label

    def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
        """Cache, shuffle, repeat, batch and prefetch ``ds``.

        ``cache`` may be ``True`` (in-memory), ``False`` (no caching) or a
        file path string (on-disk cache for datasets that don't fit in
        memory).
        """
        # If a small dataset, only load it once, and keep it in memory.
        # use `.cache(filename)` to cache preprocessing work for datasets
        # that don't fit in memory.
        if cache:
            if isinstance(cache, str):
                ds = ds.cache(cache)
            else:
                ds = ds.cache()
        ds = ds.shuffle(buffer_size=shuffle_buffer_size)
        # Repeat forever
        ds = ds.repeat()
        ds = ds.batch(batch_size)
        # `prefetch` lets the dataset fetch batches in the background
        # while the model is training.
        ds = ds.prefetch(buffer_size=autotune)
        return ds

    data_generator = {}
    for x in folders:
        # Matches <dir_path>/<folder>/<class>/<image>
        file_pattern = os.path.join(dir_path, x, '*', '*')
        list_ds = tf.data.Dataset.list_files(file_pattern)
        # Set `num_parallel_calls` so that multiple images are
        # processed in parallel
        labeled_ds = list_ds.map(parse_image, num_parallel_calls=autotune)
        # cache = True, False, './file_name'
        # If the dataset doesn't fit in memory use a cache file. Each folder
        # gets its own cache file: sharing one file between splits would
        # make them collide / serve each other's data.
        data_generator[x] = prepare_for_training(
            labeled_ds, cache=f'./data_{x}.tfcache')
    return data_generator
def timeit(ds, steps=1000):
    """Check performance/speed for loading images using Keras or tfdata.

    Pulls ``steps`` batches from ``ds`` and prints the elapsed time and the
    resulting throughput. Reads the module-level global ``batch_size`` to
    convert batches into images.

    Args:
        ds: Any iterable yielding batches (Keras generator or tf.data
            dataset). It must be able to yield at least ``steps`` batches.
        steps: Number of batches to draw.
    """
    start = time.perf_counter()
    it = iter(ds)
    for i in range(steps):
        # Actually fetch the batch; this is the work being timed.
        next(it)
        print(' >> ', i, f'/{steps}', end='\r')
    duration = time.perf_counter() - start
    print(f'''{steps} batches: '''
          f'''{datetime.timedelta(seconds=int(duration))}''')
    if duration > 0:
        print(f'{round(batch_size * steps / duration)} Images/s')
    else:
        # Nothing measurable elapsed (e.g. steps == 0); avoid dividing by 0.
        print('inf Images/s')
if __name__ == '__main__':
    # Script configuration. These names are read as module-level globals by
    # the loader functions and by timeit.
    # Need to change this w.r.t data
    dir_path = '/home/sun/data/dog_vs_cat'
    folders = ['train', 'val']
    # Which loader to use: 'keras' or 'tfdata'
    load_data_using = 'tfdata'
    batch_size = 32
    img_dims = [256, 256]

    if load_data_using == 'keras':
        data_generator = load_data_using_keras(folders)
    elif load_data_using == 'tfdata':
        data_generator = load_data_using_tfdata(folders)
    else:
        # Fail loudly instead of leaving data_generator undefined.
        raise ValueError(
            f"load_data_using must be 'keras' or 'tfdata', "
            f"got {load_data_using!r}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment