Last active
August 19, 2017 14:45
-
-
Save wassname/43eb38005a6c287e4dc64c1e9bb6db4b to your computer and use it in GitHub Desktop.
Modification that allows Keras's ImageDataGenerator to handle any number of channels instead of just 1, 3 or 4. Tested with keras>=1.2.2.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Modification that allows the ImageDataGenerator to handle any number of channels instead of just 1, 3 or 4.
Modified from https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
""" | |
import os

import numpy as np
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator as _ImageDataGenerator
from keras.preprocessing.image import Iterator
from keras.preprocessing.image import array_to_img
from path import Path
from scipy import linalg
class NumpyArrayIterator(Iterator):
    """Iterate over a rank-4 Numpy array in batches, augmenting each sample.

    Unlike the upstream Keras iterator this one places no restriction on the
    size of the channels axis, so arrays with any number of channels are
    accepted.

    # Arguments
        x: array-like of input samples; must have rank 4 once converted
            with `np.asarray`.
        y: array-like of targets with the same length as `x`, or None.
        image_data_generator: object whose `random_transform` and
            `standardize` methods are applied to every sample.
        batch_size, shuffle, seed: forwarded unchanged to the base
            `Iterator` constructor.
        dim_ordering: dimension ordering string; 'default' is resolved
            through `K.image_dim_ordering()`.
        save_to_dir: if truthy, every generated sample is also written
            to this directory.
        save_prefix: filename prefix used when saving samples.
        save_format: file extension/format used when saving samples.

    # Raises
        ValueError: if `x` and `y` differ in length or `x` is not rank 4.
    """

    def __init__(self, x, y, image_data_generator,
                 batch_size=32, shuffle=False, seed=None,
                 dim_ordering='default',
                 save_to_dir=None, save_prefix='', save_format='jpeg'):
        if y is not None and len(x) != len(y):
            raise ValueError('X (images tensor) and y (labels) '
                             'should have the same length. '
                             'Found: X.shape = %s, y.shape = %s' %
                             (np.asarray(x).shape, np.asarray(y).shape))
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        self.x = np.asarray(x)
        if self.x.ndim != 4:
            # Build a single message string; passing the shape as a second
            # positional argument would make the error render as a tuple.
            raise ValueError('Input data in `NumpyArrayIterator` '
                             'should have rank 4. You passed an array '
                             'with shape ' + str(self.x.shape))
        # The upstream check that the channels axis has size 1, 3 or 4 is
        # intentionally omitted: allowing any channel count is the purpose
        # of this class.
        self.y = np.asarray(y) if y is not None else None
        self.image_data_generator = image_data_generator
        self.dim_ordering = dim_ordering
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        # Use the converted array so that list input (accepted by
        # `np.asarray` above) does not fail on a missing `.shape`.
        super(NumpyArrayIterator, self).__init__(self.x.shape[0], batch_size, shuffle, seed)

    def next(self):
        """Return the next batch.

        # Returns
            `batch_x` alone when no targets were supplied, otherwise the
            tuple `(batch_x, batch_y)`.
        """
        # for python 2.x.
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch
        # see http://anandology.com/blog/using-iterators-and-generators/
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        batch_x = np.zeros(tuple([current_batch_size] + list(self.x.shape)[1:]))
        for i, j in enumerate(index_array):
            x = self.x[j]
            x = self.image_data_generator.random_transform(x.astype('float32'))
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x
        if self.save_to_dir:
            # NOTE(review): `array_to_img` presumably only handles the channel
            # counts an image file format supports; saving samples with other
            # channel counts may fail - confirm before relying on it.
            for i in range(current_batch_size):
                img = array_to_img(batch_x[i], self.dim_ordering, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(prefix=self.save_prefix,
                                                                  index=current_index + i,
                                                                  hash=np.random.randint(10000),
                                                                  format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))
        if self.y is None:
            return batch_x
        batch_y = self.y[index_array]
        return batch_x, batch_y
class ImageDataGenerator(_ImageDataGenerator):
    """Keras `ImageDataGenerator` without the 1/3/4-channel restriction.

    `flow` returns this module's `NumpyArrayIterator`, and `fit` skips the
    channel-count validation, so image arrays with any number of channels
    can be used.
    """

    def flow(self, X, y=None, batch_size=32, shuffle=True, seed=None,
             save_to_dir=None, save_prefix='', save_format='jpeg'):
        """Return a `NumpyArrayIterator` over `X` (and optionally `y`).

        All arguments are forwarded to `NumpyArrayIterator`, together with
        this generator and its `dim_ordering`.
        """
        return NumpyArrayIterator(
            X, y, self,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            dim_ordering=self.dim_ordering,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format)

    def fit(self, x,
            augment=False,
            rounds=1,
            seed=None):
        """Required for featurewise_center, featurewise_std_normalization
        and zca_whitening.

        # Arguments
            x: Numpy array, the data to fit on. Should have rank 4.
                The channels axis may have any size.
            augment: Whether to fit on randomly augmented samples
            rounds: If `augment`,
                how many augmentation passes to do over the data
            seed: random seed.

        # Raises
            ValueError: in case of invalid input `x`.
        """
        x = np.asarray(x)
        if x.ndim != 4:
            raise ValueError('Input to `.fit()` should have rank 4. '
                             'Got array with shape: ' + str(x.shape))
        # The upstream check that the channels axis has size 1, 3 or 4 is
        # intentionally omitted: allowing any channel count is the purpose
        # of this class.

        if seed is not None:
            np.random.seed(seed)

        # Work on a private floating-point copy: the statistics below are
        # applied in place (`-=`, `/=`), which raises a casting error on
        # integer arrays such as uint8 images.
        if np.issubdtype(x.dtype, np.floating):
            x = np.copy(x)
        else:
            x = x.astype(K.floatx())

        if augment:
            n_samples = x.shape[0]
            ax = np.zeros(tuple([rounds * n_samples] + list(x.shape)[1:]))
            for r in range(rounds):
                for i in range(n_samples):
                    ax[i + r * n_samples] = self.random_transform(x[i])
            x = ax

        if self.featurewise_center:
            self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
            # Reshape the per-channel statistic so it broadcasts against a
            # single sample (the batch axis is dropped, hence the `- 1`).
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.mean = np.reshape(self.mean, broadcast_shape)
            x -= self.mean

        if self.featurewise_std_normalization:
            self.std = np.std(x, axis=(0, self.row_axis, self.col_axis))
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.std = np.reshape(self.std, broadcast_shape)
            x /= (self.std + K.epsilon())

        if self.zca_whitening:
            flat_x = np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
            sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
            u, s, _ = linalg.svd(sigma)
            # The small constant keeps the division finite when a singular
            # value is zero.
            self.principal_components = np.dot(np.dot(u, np.diag(1. / np.sqrt(s + 10e-7))), u.T)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This modification was accepted into Keras, so there is no longer any need to use this gist.