h5py demo
# -*- coding: utf-8 -*-

'''
This is the demo code for my article about HDF5 in Python at the link below.
https://duongnt.com/hdf5-with-h5py

Please install these packages before running this script.
- pip install h5py
- pip install imutils
- pip install keras
- pip install tensorflow
- pip install opencv-contrib-python
'''

import os

import h5py
import numpy as np
from imutils import paths
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing import image_dataset_from_directory
class HDF5Writer:
    '''Buffer images and labels in memory and write them to an HDF5 file in batches.'''

    def __init__(self, output_path, buffer_size, dims):
        self.output_path = output_path
        self.buffer_size = buffer_size
        self.dims = dims

        # Create the HDF5 file with one dataset for image data and one for labels
        self.db = h5py.File(output_path, 'w')
        self.data = self.db.create_dataset('data', dims, dtype='float32')
        self.labels = self.db.create_dataset('labels', (dims[0],), dtype='int')

        self.buffer = {
            'data': [],
            'labels': []
        }
        self.idx = 0  # Index in database

    def write(self, data, label):
        self.buffer['data'].append(data)
        self.buffer['labels'].append(label)

        # The buffer is full, write it to disk
        if len(self.buffer['data']) >= self.buffer_size:
            self.flush()

    def flush(self):
        # Write buffer to disk
        i = self.idx + len(self.buffer['data'])
        self.data[self.idx:i] = self.buffer['data']
        self.labels[self.idx:i] = self.buffer['labels']
        self.idx = i

        # Reset buffer
        self.buffer = {
            'data': [],
            'labels': []
        }

    def close(self):
        # If the buffer still contains data, flush it all to disk
        if len(self.buffer['data']) > 0:
            self.flush()

        # Close the database
        self.db.close()
def create_hdf5_file():
    data_dir = 'pokemon_jpg'
    data_paths = list(paths.list_images(data_dir))
    # Each image's label is its parent directory name (assumed to be an integer)
    labels = [int(n.split(os.path.sep)[-2]) for n in data_paths]

    output_path = 'pokemon_jpeg.hdf5'
    buffer_size = 100
    dims = (len(data_paths), 256, 256, 3)
    writer = HDF5Writer(output_path, buffer_size, dims)

    # Shuffle the image order before writing, since the HDF5 file will be read sequentially
    arr = np.arange(len(data_paths))
    np.random.shuffle(arr)

    for i in arr:
        image = load_img(data_paths[i], target_size=(256, 256), interpolation='bilinear')
        image = img_to_array(image, data_format='channels_last')
        writer.write(image, labels[i])

    writer.close()
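# Optional sanity check (a sketch): open the file produced by create_hdf5_file
# and print the dataset shapes. The helper name inspect_hdf5_file is hypothetical.
def inspect_hdf5_file(db_path='pokemon_jpeg.hdf5'):
    with h5py.File(db_path, 'r') as db:
        print(db['data'].shape)    # expected: (num_images, 256, 256, 3)
        print(db['labels'].shape)  # expected: (num_images,)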
def create_hdf5_generator(db_path, batch_size):
    db = h5py.File(db_path, 'r')
    db_size = db['data'].shape[0]

    # Loop forever so the generator can keep yielding batches indefinitely
    while True:
        for i in np.arange(0, db_size, batch_size):
            images = db['data'][i:i+batch_size]
            labels = db['labels'][i:i+batch_size]

            yield images, labels
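# How the generator above could feed training (a sketch only; `model` is assumed
# to be an already compiled Keras model, which this demo does not build):
def train_from_hdf5(model, db_path='pokemon_jpeg.hdf5', batch_size=32, epochs=10):
    with h5py.File(db_path, 'r') as db:
        db_size = db['data'].shape[0]

    # The generator loops forever, so steps_per_epoch must be set explicitly
    model.fit(
        create_hdf5_generator(db_path, batch_size),
        steps_per_epoch=db_size // batch_size,
        epochs=epochs
    )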
def test_normal_read():
    dataset = image_dataset_from_directory(
        directory='pokemon_jpg'
    )
    dataset = dataset.repeat()

    curr = 0
    for sample, _ in dataset:
        if curr >= 100:
            break

        curr += 1
        print(sample.shape)
def test_hdf5_read():
    db_path = 'pokemon_jpeg.hdf5'
    batch_size = 32
    hdf5_gen = create_hdf5_generator(db_path, batch_size)

    curr = 0
    for sample, _ in hdf5_gen:
        if curr >= 100:
            break

        curr += 1
        print(sample.shape)
def run_benchmark():
    import timeit

    # Baseline generator: decode the JPEG files from disk with image_dataset_from_directory
    dataset = image_dataset_from_directory(
        directory='pokemon_jpg',
        color_mode='rgb',
        batch_size=32,
        image_size=(256, 256)
    )
    normal_gen = iter(dataset.repeat())

    # HDF5 generator: read pre-decoded batches from the HDF5 file
    db_path = 'pokemon_jpeg.hdf5'
    batch_size = 32
    hdf5_gen = create_hdf5_generator(db_path, batch_size)

    # Time 1,000 batch reads from each source
    rs_normal = timeit.timeit(lambda: next(normal_gen), number=1000)
    rs_hdf5 = timeit.timeit(lambda: next(hdf5_gen), number=1000)

    print(f'Baseline: {rs_normal}')
    print(f'HDF5 benchmark: {rs_hdf5}')
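# A possible entry point for running the demo end to end (a sketch; the gist
# defines the functions above without calling them):
if __name__ == '__main__':
    create_hdf5_file()  # build pokemon_jpeg.hdf5 from the images under pokemon_jpg/
    run_benchmark()     # compare reading speed: image files vs. the HDF5 file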