Skip to content

Instantly share code, notes, and snippets.

@renexu
Created December 28, 2017 05:46
Show Gist options
  • Star 14 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save renexu/859d05fa3df4509b676fd31bd220ec1b to your computer and use it in GitHub Desktop.
Save renexu/859d05fa3df4509b676fd31bd220ec1b to your computer and use it in GitHub Desktop.
Keras HDF5Matrix and fit_generator for huge hdf5 dataset
import functools
import threading

from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import Adam
from keras.utils.io_utils import HDF5Matrix
class threadsafe_iter:
    """Wrap an iterator/generator so that concurrent ``next`` calls are serialized.

    Python generators cannot be advanced from two threads at the same time;
    holding a lock around every ``next`` call makes the wrapped iterator safe to
    share between threads.
    """

    def __init__(self, it):
        self.it = it  # the wrapped iterator / generator object
        self.lock = threading.Lock()  # serializes access to self.it

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            # Use the builtin ``next`` rather than calling ``__next__`` directly:
            # it is the standard protocol entry point on both Python 2 and 3.
            return next(self.it)

    # Python 2 iterator-protocol alias, so the wrapper is an iterator there too.
    next = __next__
def threadsafe_generator(f):
    """Decorate a generator function so each call returns a thread-safe iterator.

    The wrapper calls ``f`` with the same positional and keyword arguments and
    wraps the resulting generator in ``threadsafe_iter``. ``functools.wraps``
    keeps ``f``'s ``__name__``, ``__doc__`` and ``__wrapped__`` on the wrapper,
    so introspection and tracebacks still show the original function.
    """
    @functools.wraps(f)
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g
@threadsafe_generator
def generator(hdf5_file, batch_size):
    """Yield ``(x_batch, y_batch)`` slices from an HDF5 file forever, in file order.

    Batches are contiguous slices of the ``'x'`` and ``'y'`` datasets, each
    opened via ``HDF5Matrix``. The number of samples is taken from the ``'x'``
    dataset. When a slice reaches the end of the data, iteration wraps back to
    index 0. If the sample count is not a multiple of ``batch_size``, the last
    batch of each pass is shorter. No empty batch is ever yielded.

    Args:
        hdf5_file: Path to an HDF5 file containing ``'x'`` and ``'y'`` datasets.
        batch_size: Maximum number of samples per batch; must be positive.

    Raises:
        ValueError: If ``batch_size`` is not positive or the ``'x'`` dataset is
            empty. Because this is a generator, the error surfaces on the first
            ``next`` call, not when ``generator(...)`` is called.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))
    x = HDF5Matrix(hdf5_file, 'x')
    y = HDF5Matrix(hdf5_file, 'y')
    size = x.end
    if size <= 0:
        # Without this guard the loop would spin forever yielding empty batches.
        raise ValueError("dataset 'x' in %s is empty" % (hdf5_file,))
    idx = 0
    while True:
        end = min(idx + batch_size, size)
        yield x[idx:end], y[idx:end]
        # Wrap as soon as the slice touches the end of the data. The original
        # `>` comparison missed the exact-multiple case and then yielded an
        # empty x[size:size] batch before wrapping.
        idx = 0 if end >= size else end
def data_statistic(train_dataset, test_dataset):
    """Return the sample counts of the training and test HDF5 files.

    Each count is the ``end`` attribute of an ``HDF5Matrix`` opened on the
    file's ``'x'`` dataset.

    Returns:
        tuple: ``(train_count, test_count)``.
    """
    train_count = HDF5Matrix(train_dataset, 'x').end
    test_count = HDF5Matrix(test_dataset, 'x').end
    return train_count, test_count
def build_model():
    """Construct an uncompiled InceptionV3 model without loading pretrained weights."""
    return InceptionV3(weights=None)
if __name__ == '__main__':
    # Train InceptionV3 from scratch on HDF5 data streamed batch-by-batch.
    batch_size = 32
    train_dataset = 'train.h5'
    test_dataset = 'test.h5'

    model = build_model()
    model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])

    train_generator = generator(train_dataset, batch_size)
    test_generator = generator(test_dataset, batch_size)

    nb_train_samples, nb_test_samples = data_statistic(train_dataset, test_dataset)
    print('train samples: %d, test samples: %d' % (nb_train_samples, nb_test_samples))

    # Ceiling division: when the sample count is not a multiple of batch_size,
    # the generator yields a final, shorter batch before wrapping to index 0.
    # Counting that batch makes every epoch (and every validation pass) cover
    # each sample exactly once. With floor division the short batch spills into
    # the next epoch, so later epochs start at shifted offsets and validation
    # metrics are computed on different subsets each epoch.
    train_steps = (nb_train_samples + batch_size - 1) // batch_size
    test_steps = (nb_test_samples + batch_size - 1) // batch_size

    model.fit_generator(
        epochs=10,
        generator=train_generator, steps_per_epoch=train_steps,
        validation_data=test_generator, validation_steps=test_steps,
        max_queue_size=10,  # pick a value so batch_size * image_size * max_queue_size fits in CPU memory
        workers=1,  # a single loader thread; the generators are lock-serialized, so more threads add little
        use_multiprocessing=False,  # keep loading in this process; HDF5Matrix file handles are opened here
        shuffle=False)  # only applies to keras.utils.Sequence inputs; shuffle the data before writing the .h5 files
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment