Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active August 19, 2017 14:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wassname/43eb38005a6c287e4dc64c1e9bb6db4b to your computer and use it in GitHub Desktop.
Save wassname/43eb38005a6c287e4dc64c1e9bb6db4b to your computer and use it in GitHub Desktop.
Mod to allow Keras's ImageDataGenerator to accept any number of channels instead of just 1, 3 or 4. Tested with keras>=1.2.2.
"""
Mod to allow the ImageDataGenerator to accept any number of channels instead of just 1, 3 or 4.
Modified from https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
"""
import os

import numpy as np
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator as _ImageDataGenerator
from keras.preprocessing.image import Iterator
from keras.preprocessing.image import array_to_img
from path import Path
from scipy import linalg


class NumpyArrayIterator(Iterator):
    """Iterator yielding batches of transformed samples from a Numpy array.

    Unlike the upstream implementation this class was adapted from, no
    restriction is placed on the number of channels of the input data.

    # Arguments
        x: Array-like of input samples. Must have rank 4 once converted
            with `np.asarray`.
        y: Array-like of targets with the same length as `x`, or `None`
            to yield inputs only.
        image_data_generator: Object whose `random_transform` and
            `standardize` methods are applied to every sample.
        batch_size: Forwarded to the base `Iterator`.
        shuffle: Forwarded to the base `Iterator`.
        seed: Forwarded to the base `Iterator`.
        dim_ordering: Dimension ordering string; `'default'` resolves to
            `K.image_dim_ordering()`.
        save_to_dir: If truthy, directory in which every generated
            sample is saved as an image.
        save_prefix: Filename prefix for saved images.
        save_format: Filename extension for saved images.

    # Raises
        ValueError: If `x` and `y` differ in length or `x` is not rank 4.
    """

    def __init__(self, x, y, image_data_generator,
                 batch_size=32, shuffle=False, seed=None,
                 dim_ordering='default',
                 save_to_dir=None, save_prefix='', save_format='jpeg'):
        if y is not None and len(x) != len(y):
            raise ValueError('X (images tensor) and y (labels) '
                             'should have the same length. '
                             'Found: X.shape = %s, y.shape = %s' %
                             (np.asarray(x).shape, np.asarray(y).shape))
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        self.x = np.asarray(x)
        if self.x.ndim != 4:
            # Format the shape into the message itself; passing it as a
            # second exception argument renders the error as a tuple.
            raise ValueError('Input data in `NumpyArrayIterator` '
                             'should have rank 4. You passed an array '
                             'with shape %s' % (self.x.shape,))
        # NOTE: the channel count is deliberately not validated, so that
        # arrays with any number of channels are accepted.
        self.y = np.asarray(y) if y is not None else None
        self.image_data_generator = image_data_generator
        self.dim_ordering = dim_ordering
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        # Use the converted array: the raw `x` may be a plain list,
        # which has no `.shape`.
        super(NumpyArrayIterator, self).__init__(
            self.x.shape[0], batch_size, shuffle, seed)

    def next(self):
        """Return the next batch.

        # Returns
            `batch_x` when no targets were given, otherwise the tuple
            `(batch_x, batch_y)`.
        """
        # for python 2.x.
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch
        # see http://anandology.com/blog/using-iterators-and-generators/
        with self.lock:
            index_array, current_index, current_batch_size = next(
                self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        batch_x = np.zeros((current_batch_size,) + self.x.shape[1:])
        for i, j in enumerate(index_array):
            sample = self.x[j].astype('float32')
            sample = self.image_data_generator.random_transform(sample)
            batch_x[i] = self.image_data_generator.standardize(sample)
        if self.save_to_dir:
            for i in range(current_batch_size):
                img = array_to_img(batch_x[i], self.dim_ordering, scale=True)
                # A random suffix makes name collisions between batches
                # unlikely.
                fname = '{prefix}_{index}_{hash}.{format}'.format(
                    prefix=self.save_prefix,
                    index=current_index + i,
                    hash=np.random.randint(10000),
                    format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))
        if self.y is None:
            return batch_x
        batch_y = self.y[index_array]
        return batch_x, batch_y
class ImageDataGenerator(_ImageDataGenerator):
    """`ImageDataGenerator` variant without the 1/3/4-channel restriction.

    `flow` returns the channel-agnostic `NumpyArrayIterator` defined in
    this module, and `fit` computes the feature-wise statistics without
    validating the number of channels.
    """

    def flow(self, X, y=None, batch_size=32, shuffle=True, seed=None,
             save_to_dir=None, save_prefix='', save_format='jpeg'):
        """Build a `NumpyArrayIterator` over `X` (and optionally `y`).

        All arguments are forwarded to `NumpyArrayIterator`, together
        with this generator and its `dim_ordering`.
        """
        return NumpyArrayIterator(
            X, y, self,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            dim_ordering=self.dim_ordering,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format)

    def fit(self, x,
            augment=False,
            rounds=1,
            seed=None):
        """Required for featurewise_center, featurewise_std_normalization
        and zca_whitening.

        # Arguments
            x: Numpy array, the data to fit on. Should have rank 4.
                The channels axis may have any size. Non-floating-point
                input is cast to `K.floatx()` before statistics are
                computed; the caller's array is never modified.
            augment: Whether to fit on randomly augmented samples
            rounds: If `augment`,
                how many augmentation passes to do over the data
            seed: random seed.

        # Raises
            ValueError: in case of invalid input `x`.
        """
        x = np.asarray(x)
        if x.ndim != 4:
            raise ValueError('Input to `.fit()` should have rank 4. '
                             'Got array with shape: ' + str(x.shape))
        # NOTE: the channel count is deliberately not validated.

        if seed is not None:
            np.random.seed(seed)

        # Work on a private floating-point copy: the centering and
        # scaling below are done in place, which NumPy refuses to do on
        # integer arrays (e.g. uint8 images).
        if np.issubdtype(x.dtype, np.floating):
            x = np.copy(x)
        else:
            x = x.astype(K.floatx())

        if augment:
            n_samples = x.shape[0]
            ax = np.zeros((rounds * n_samples,) + x.shape[1:])
            for r in range(rounds):
                for i in range(n_samples):
                    ax[i + r * n_samples] = self.random_transform(x[i])
            x = ax

        if self.featurewise_center:
            self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))
            # Reshape to a per-sample shape with singleton row/col axes
            # so the statistic broadcasts along the channel axis.
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.mean = np.reshape(self.mean, broadcast_shape)
            x -= self.mean

        if self.featurewise_std_normalization:
            self.std = np.std(x, axis=(0, self.row_axis, self.col_axis))
            broadcast_shape = [1, 1, 1]
            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
            self.std = np.reshape(self.std, broadcast_shape)
            x /= (self.std + K.epsilon())

        if self.zca_whitening:
            flat_x = np.reshape(
                x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
            sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
            u, s, _ = linalg.svd(sigma)
            # The small constant keeps the inverse square root finite
            # for zero singular values.
            self.principal_components = np.dot(
                np.dot(u, np.diag(1. / np.sqrt(s + 10e-7))), u.T)
@wassname
Copy link
Author

This modification was accepted into Keras, so there is no longer any need to use this gist.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment