@tdeboissiere
Created May 16, 2016 06:29
Keras fit_generator speed test
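The script below generates roughly 1 GB of random image data on disk, then compares two ways of feeding it to a small convnet for one epoch: Keras' built-in `fit_generator`, and a custom generator that fills a `multiprocessing.Queue` with batches in worker processes and trains with `train_on_batch`.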
from __future__ import print_function
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras.utils import generic_utils
import multiprocessing
import os
import time
def gen_dummy():
    """
    Generate ~ 1 GB of dummy images
    (1000 * 3 * 224 * 224 int64 values is roughly 1.2 GB on disk)
    """
    arr_data = np.random.randint(0, 256, (1000, 3, 224, 224))
    arr_labels = np.random.randint(0, 2, 1000)
    np.savez("data_dummy", **{"data": arr_data, "labels": arr_labels})
def get_model():
    """ Simple convnet """
    # Build a NN
    model = Sequential()
    model.add(Convolution2D(32, 3, 3, input_shape=(3, 224, 224)))
    model.add(Activation('relu'))
    model.add(Convolution2D(32, 3, 3))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(2))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adadelta')
    return model
def fit_generator_test():
    """ Test the speed of fit_generator """
    # Load the data (not in memory)
    arr = np.load("data_dummy.npz")
    # Define some parameters
    batch_size = 32
    n_samples = 1000
    max_q_size = 20
    samples_per_epoch = 320

    # Define a generator function
    def myGenerator():
        while True:
            batch_index = np.random.randint(0, n_samples - batch_size)
            start = batch_index
            end = start + batch_size
            X = arr["data"][start: end]
            y = arr["labels"][start: end]
            y = np_utils.to_categorical(y, nb_classes=2)
            yield X, y

    # Load model
    model = get_model()
    start = time.time()
    model.fit_generator(myGenerator(),
                        samples_per_epoch=samples_per_epoch,
                        nb_epoch=1,
                        verbose=1,
                        max_q_size=max_q_size)
    print("Time fit_generator for %s samples: %s" % (samples_per_epoch, time.time() - start))
def multiprocessing_test():
    """ Test the speed of custom generator """
    # Define some parameters
    batch_size = 32
    n_samples = 1000
    max_q_size = 20
    maxproc = 8
    samples_per_epoch = 320

    # Define a generator function
    def myGenerator():
        """ Use multiprocessing to generate batches in parallel. """
        try:
            queue = multiprocessing.Queue(maxsize=max_q_size)

            # Define producer (putting items into queue).
            # Each producer process loads the data, builds a single random
            # batch, puts it on the queue and then exits.
            def producer():
                try:
                    # Load the data (not in memory)
                    arr = np.load("data_dummy.npz")
                    batch_index = np.random.randint(0, n_samples - batch_size)
                    start = batch_index
                    end = start + batch_size
                    X = arr["data"][start: end]
                    y = arr["labels"][start: end]
                    y = np_utils.to_categorical(y, nb_classes=2)
                    # Put the data in a queue
                    queue.put((X, y))
                except:
                    print("Nothing here")

            processes = []

            def start_process():
                for i in range(len(processes), maxproc):
                    thread = multiprocessing.Process(target=producer)
                    time.sleep(0.01)
                    thread.start()
                    processes.append(thread)

            # Run as consumer (read items from queue, in current thread),
            # keeping up to maxproc producer processes alive.
            while True:
                processes = [p for p in processes if p.is_alive()]
                if len(processes) < maxproc:
                    start_process()
                yield queue.get()
        except:
            print("Finishing")
            for th in processes:
                th.terminate()
            queue.close()
            raise

    # Load model
    model = get_model()
    samples_seen = 0
    start = time.time()
    progbar = generic_utils.Progbar(samples_per_epoch)
    print("Epoch 1/1")
    for X, y in myGenerator():
        train_loss = model.train_on_batch(X, y)
        progbar.add(batch_size, values=[("train loss", train_loss)])
        samples_seen += batch_size
        if samples_seen == samples_per_epoch:
            break
    print("Time multiprocessing for %s samples: %s" % (samples_per_epoch, time.time() - start))
if __name__ == '__main__':
    if not os.path.isfile("data_dummy.npz"):
        gen_dummy()
    # fit_generator_test()
    multiprocessing_test()
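A note on the design: in `multiprocessing_test`, each producer process opens `data_dummy.npz`, slices out one random batch, puts it on the shared `multiprocessing.Queue` and exits, while the consumer loop in `myGenerator` keeps up to `maxproc` producers alive and feeds the batches to `train_on_batch`. The speedup over `fit_generator` reported by the timing printouts presumably comes from doing the batch loading and `to_categorical` conversion in separate processes rather than in the generator thread Keras 1.x uses to fill its internal queue.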
jskDr commented May 22, 2016

It is awesome. The speed is so much improved.
