Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
FileListIterator for keras
import os

import numpy as np
from keras import backend as K
from keras.preprocessing.image import (Iterator, array_to_img, img_to_array,
                                       load_img)
class FileListIterator(Iterator):
    """Iterator capable of reading images from an explicit list of file paths.

    Works like Keras' `DirectoryIterator`, except that the image paths and
    their labels are supplied by the caller instead of being discovered by
    walking class subdirectories.

    # Arguments
        filenames: Sequence of paths to the image files to read. It is
            copied into a plain list, so it is always indexed by position.
        fileClasses: Sequence of class labels, one per entry of `filenames`
            and in the same order. It must have the same length as
            `filenames`, except that it may be empty (or `None`) when
            `class_mode` is `"input"` or `None`, since labels are unused then.
        image_data_generator: Instance of `ImageDataGenerator`
            to use for random transformations and normalization.
        target_size: tuple of integers, dimensions to resize input images to.
        color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images.
        classes: Optional list of class labels (e.g. `["dogs", "cats"]`);
            a label's position in this list is its class index. If not set,
            the sorted unique values of `fileClasses` are used, so the
            label-to-index mapping is the same on every run.
        class_mode: Mode for yielding the targets:
            `"binary"`: binary targets (if there are only two classes),
            `"categorical"`: categorical targets,
            `"sparse"`: integer targets,
            `"input"`: targets are images identical to input images (mainly
                used to work with autoencoders),
            `None`: no targets get yielded (only input images are yielded).
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        data_format: String, one of `channels_first`, `channels_last`.
            Defaults to `K.image_data_format()` when `None`.
        save_to_dir: Optional directory where to save the pictures
            being yielded, in a viewable format. This is useful
            for visualizing the random transformations being
            applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample
            images (if `save_to_dir` is set).
        save_format: Format to use for saving sample images
            (if `save_to_dir` is set).
        follow_links: Ignored. Accepted only so the signature mirrors
            `DirectoryIterator`.
        subset: Subset of data (`"training"` or `"validation"`) if
            validation_split is set in ImageDataGenerator. The split is made
            within each class, in the given file order: the first
            `validation_split` fraction of a class's files forms the
            validation subset and the remainder forms the training subset.
        interpolation: Interpolation method used to resample the image if the
            target size is different from that of the loaded image.
            Supported methods are "nearest", "bilinear", and "bicubic".
            If PIL version 1.1.3 or newer is installed, "lanczos" is also
            supported. If PIL version 3.4.0 or newer is installed, "box" and
            "hamming" are also supported. By default, "nearest" is used.

    # Raises
        ValueError: for an invalid `color_mode`, `class_mode` or `subset`, or
            when `fileClasses` and `filenames` have different lengths.
        KeyError: if `classes` is given and `fileClasses` contains a label
            that is not in it.

    # Examples
    ```python
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    filenames = ['path/to/file1.png', 'path/to/file2.png', ...]
    fileClasses = ['scottish_deerhound', 'entlebucher', ...]
    fileListIterator = FileListIterator(
        filenames,
        fileClasses,
        train_datagen,
        target_size=(256, 256),
        color_mode='grayscale',
        classes=None,
        class_mode='categorical',
        data_format=train_datagen.data_format,
        batch_size=32,
        shuffle=True,
        seed=None,
        save_to_dir=None,
        save_prefix='',
        save_format='png',
        follow_links=False,
        subset=None,
        interpolation='nearest')
    ```
    """

    def __init__(self,
                 filenames,
                 fileClasses,
                 image_data_generator,
                 target_size=(256, 256),
                 color_mode='rgb',
                 classes=None,
                 class_mode='categorical',
                 batch_size=32,
                 shuffle=True,
                 seed=None,
                 data_format=None,
                 save_to_dir=None,
                 save_prefix='',
                 save_format='png',
                 follow_links=False,
                 subset=None,
                 interpolation='nearest'):
        if data_format is None:
            data_format = K.image_data_format()
        self.image_data_generator = image_data_generator
        self.target_size = tuple(target_size)
        if color_mode not in {'rgb', 'grayscale'}:
            raise ValueError('Invalid color mode:', color_mode,
                             '; expected "rgb" or "grayscale".')
        self.color_mode = color_mode
        self.data_format = data_format
        channels = (3,) if color_mode == 'rgb' else (1,)
        if data_format == 'channels_last':
            self.image_shape = self.target_size + channels
        else:
            self.image_shape = channels + self.target_size
        if class_mode not in {'categorical', 'binary', 'sparse',
                              'input', None}:
            raise ValueError('Invalid class_mode:', class_mode,
                             '; expected one of "categorical", '
                             '"binary", "sparse", "input"'
                             ' or None.')
        self.class_mode = class_mode
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        self.interpolation = interpolation

        # Resolve the (start, stop) fraction of each class to keep.
        if subset is not None:
            validation_split = self.image_data_generator._validation_split
            if subset == 'validation':
                split = (0, validation_split)
            elif subset == 'training':
                split = (validation_split, 1)
            else:
                raise ValueError('Invalid subset name: ', subset,
                                 '; expected "training" or "validation"')
        else:
            split = None
        self.subset = subset

        # Positional copies: a pandas Series or similar would otherwise be
        # indexed by label in _get_batches_of_transformed_samples.
        filenames = list(filenames)
        fileClasses = [] if fileClasses is None else list(fileClasses)
        labels_needed = class_mode not in {'input', None}
        if ((fileClasses or labels_needed)
                and len(fileClasses) != len(filenames)):
            raise ValueError('fileClasses must contain one label per '
                             'filename; got %d filenames and %d labels.'
                             % (len(filenames), len(fileClasses)))

        if not classes:
            # sorted(): iteration order of a set of strings depends on hash
            # randomization, so an unsorted list would give a different
            # label -> index mapping in each interpreter run.
            classes = sorted(set(fileClasses))
        self.num_classes = len(classes)
        self.class_indices = dict(zip(classes, range(len(classes))))

        if fileClasses:
            labels = np.array([self.class_indices[c] for c in fileClasses],
                              dtype='int32')
        else:
            # Labels are unused for this class_mode; keep zero placeholders.
            labels = np.zeros((len(filenames),), dtype='int32')

        if split is not None:
            keep = self._split_positions(labels, split)
            filenames = [filenames[k] for k in keep]
            labels = labels[keep]

        self.filenames = filenames
        self.classes = labels
        self.samples = len(filenames)
        print('Found %d images belonging to %d classes.' %
              (self.samples, self.num_classes))
        super(FileListIterator, self).__init__(self.samples, batch_size,
                                               shuffle, seed)

    @staticmethod
    def _split_positions(labels, split):
        """Return sorted sample positions kept by a `(start, stop)` split.

        `start` and `stop` are fractions in [0, 1]. They are applied to each
        class's positions separately (keeping their original order), and the
        kept positions are returned in ascending order as an integer array.
        """
        keep = []
        for label in np.unique(labels):
            members = np.flatnonzero(labels == label)
            start = int(split[0] * len(members))
            stop = int(split[1] * len(members))
            keep.extend(members[start:stop].tolist())
        return np.array(sorted(keep), dtype=np.intp)

    def _get_batches_of_transformed_samples(self, index_array):
        """Load, augment and label the samples at the given positions.

        # Arguments
            index_array: Array of sample positions into `self.filenames`.

        # Returns
            `batch_x` alone when `class_mode` is `None`, otherwise the tuple
            `(batch_x, batch_y)`.
        """
        batch_x = np.zeros((len(index_array),) + self.image_shape,
                           dtype=K.floatx())
        grayscale = self.color_mode == 'grayscale'
        # build batch of image data
        for i, j in enumerate(index_array):
            img = load_img(self.filenames[j],
                           grayscale=grayscale,
                           target_size=self.target_size,
                           interpolation=self.interpolation)
            x = img_to_array(img, data_format=self.data_format)
            x = self.image_data_generator.random_transform(x)
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x
        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = array_to_img(batch_x[i], self.data_format, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(
                    prefix=self.save_prefix,
                    index=j,
                    hash=np.random.randint(10 ** 7),
                    format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))
        # build batch of labels
        if self.class_mode == 'input':
            batch_y = batch_x.copy()
        elif self.class_mode == 'sparse':
            batch_y = self.classes[index_array]
        elif self.class_mode == 'binary':
            batch_y = self.classes[index_array].astype(K.floatx())
        elif self.class_mode == 'categorical':
            batch_y = np.zeros((len(batch_x), self.num_classes),
                               dtype=K.floatx())
            # One-hot: set column `label` of each row to 1.
            batch_y[np.arange(len(batch_x)), self.classes[index_array]] = 1.
        else:
            return batch_x
        return batch_x, batch_y

    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment