@wassname
Last active October 11, 2020 23:32
How to do data augmentation on a keras HDF5Matrix
"""Another way, note this one will load the whole array into memory ."""
from keras.preprocessing.image import ImageDataGenerator
import h5py
from keras.utils.io_utils import HDF5Matrix
seed=0
batch_size=32
# we create two instances with the same arguments
data_gen_args = dict(
rotation_range=90.,
width_shift_range=0.05,
height_shift_range=0.05,
zoom_range=0.2,
channel_shift_range=0.005,
horizontal_flip=True,
vertical_flip=True,
fill_mode='constant',
data_format="channels_last",
)
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)
X_train = HDF5Matrix(os.path.join(out_dir, 'train_X_3band.h5'), 'X')
y_train = HDF5Matrix(os.path.join(out_dir, 'train_y_3class.h5'), 'y')
image_generator = image_datagen.flow(
X_train, None,
seed=seed,
batch_size=batch_size,
)
mask_generator = mask_datagen.flow(
y_train, None,
seed=seed,
batch_size=batch_size,
)
# combine generators into one which yields image and masks
train_generator = zip(image_generator, mask_generator)
train_generator
X, y = next(train_generator)
X.shape, y.shape
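What keeps the images and masks aligned above is the shared `seed` passed to both `flow()` calls. A minimal self-contained check of that behavior (random arrays stand in for the real HDF5 data; the variable names here are illustrative, not from the gist):

```python
import numpy as np
try:
    from keras.preprocessing.image import ImageDataGenerator
except ImportError:  # fall back to the TensorFlow-bundled keras
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen_args = dict(horizontal_flip=True, vertical_flip=True)
image_datagen = ImageDataGenerator(**datagen_args)
mask_datagen = ImageDataGenerator(**datagen_args)

X = np.random.random((8, 16, 16, 3)).astype("float32")
y = X.copy()  # "masks" identical to the images, so any misalignment is visible

image_gen = image_datagen.flow(X, seed=0, batch_size=4)
mask_gen = mask_datagen.flow(y, seed=0, batch_size=4)

xb, yb = next(zip(image_gen, mask_gen))
# Same seed => same shuffling order and the same flips applied to both batches
assert np.allclose(xb, yb)
```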
"""How to do data augmentation on a keras HDF5Matrix"""
from keras.utils.io_utils import HDF5Matrix
class AugumentedHDF5Matrix(HDF5Matrix):
"""Wraps HDF5Matrixs with image augumentation."""
def __init__(self, image_datagen, seed, *args, **kwargs):
self.image_datagen = image_datagen
self.seed = seed
self.i = 0
super().__init__(*args, **kwargs)
def __getitem__(self, key):
x = super().__getitem__(key)
self.i += 1
if len(x.shape) == 3:
return self.image_datagen.random_transform(
x, seed=self.seed + self.i)
else:
return np.array([
self.image_datagen.random_transform(
xx, seed=self.seed + self.i) for xx in x
])
# Test
from keras.preprocessing.image import ImageDataGenerator
import h5py
import numpy as np
from matplotlib import pyplot as plt
# a keras imagedata generator
image_datagen = ImageDataGenerator(
width_shift_range=0.05,
height_shift_range=0.05,
zoom_range=0.1,
channel_shift_range=0.005,
horizontal_flip=True,
vertical_flip=True,
fill_mode='constant',
data_format="channels_last",
rescale=1 / 255.0)
# test h5 file
images = np.random.random((100, 244, 244, 3))
images[:, 20:30, 20:50, :] = 1
images[:, 50:70, 20:30, :] = 0
datapath = "/tmp/testfile5.hdf5"
with h5py.File(datapath, "w") as f:
dst = f.create_dataset("X", data=images)
# Test
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X')
a = X[0].mean()
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X')
b = X[0].mean()
assert a == b, 'should be repeatable'
c = X[0].mean()
assert b != c, 'and random'
# Should be able to slice
X[1:2][0]
X[[1, 2]][0]
# View
for _ in range(5):
plt.imshow(X[0])
plt.show()
@wassname (Author)

[sample augmented output images attached in the original gist]
@helderc commented Jul 4, 2018

Interesting code, but I have one question: is it scalable when my data inside the HDF5 file does not fit into memory?

@wassname (Author) commented Mar 9, 2019

Probably not. I've moved on to making multiple HDF5 files of ~400 MB each, then loading the whole lot as a dask array.
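A rough sketch of that multi-file approach: assuming shard files matching a glob pattern, each holding a dataset under the key "X" (both names are illustrative assumptions, not from this gist), dask can concatenate them lazily so only the chunks actually read are pulled into memory:

```python
import glob
import h5py
import dask.array as da

def load_shards(pattern, key="X", chunk_rows=32):
    """Lazily concatenate several HDF5 shards along the first axis."""
    arrays = []
    for path in sorted(glob.glob(pattern)):
        f = h5py.File(path, "r")  # kept open; dask reads from it on demand
        dset = f[key]
        arrays.append(da.from_array(dset, chunks=(chunk_rows,) + dset.shape[1:]))
    return da.concatenate(arrays, axis=0)
```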

@muxizju commented Apr 27, 2019

As explained in this issue https://github.com/keras-team/keras/issues/2674#issuecomment-218036900, the data is not loaded into memory but read from disk, so it is not necessary for the HDF5 data to be small enough to fit in memory.

@dbersan commented Dec 26, 2019

Ideally, to work with keras' fit_generator() function, AugumentedHDF5Matrix should be implemented as an iterator.
This might still work, since you can iterate over the object (because of a legacy behavior of Python, as explained here), but I wouldn't count on it...
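One way to follow that suggestion is a `keras.utils.Sequence` wrapper, which fit_generator (and plain fit in later versions) supports directly. A sketch under the assumption that batches are read straight from the open HDF5 dataset; the class and parameter names here are made up for illustration:

```python
import h5py
import numpy as np
try:
    from keras.utils import Sequence
except ImportError:  # fall back to the TensorFlow-bundled keras
    from tensorflow.keras.utils import Sequence

class HDF5Sequence(Sequence):
    """Reads batches from an HDF5 dataset on demand, optionally augmenting them."""

    def __init__(self, path, key, batch_size=32, image_datagen=None, seed=0):
        super().__init__()
        self.file = h5py.File(path, "r")
        self.dset = self.file[key]
        self.batch_size = batch_size
        self.image_datagen = image_datagen
        self.seed = seed

    def __len__(self):
        # Number of batches per epoch, counting the final partial batch
        return int(np.ceil(self.dset.shape[0] / self.batch_size))

    def __getitem__(self, idx):
        batch = self.dset[idx * self.batch_size:(idx + 1) * self.batch_size]
        if self.image_datagen is not None:
            batch = np.array([
                self.image_datagen.random_transform(x, seed=self.seed + idx)
                for x in batch])
        return batch
```

Because each batch is addressed by index rather than by shared iterator state, keras can shuffle and prefetch batches safely, which is the concern raised above.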
