Last active
December 16, 2018 16:10
-
-
Save kleysonr/a41f0d72891afec8a49990c8cc24f5e4 to your computer and use it in GitHub Desktop.
Custom generator function to be used with keras fit_generator()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math, os | |
import numpy as np | |
import cv2 | |
from sklearn.preprocessing import LabelBinarizer | |
from imutils import paths | |
import time | |
""" | |
File name: keras_batch_generator.py | |
Author: Kleyson Rios | |
Email: kleysonr@gmail.com | |
""" | |
class KerasBatchGenerator():
    """Custom batch generator to be used with keras fit_generator().

    Expects images on disk laid out as dataset_path/{class_name}/*.
    The training split is balanced: every class contributes the same
    number of training images, capped by the size of the smallest class.
    """

    def __init__(self, dataset_path, test_ratio=0.25, batch_size=32,
                 imagesize=(300, 300), preprocessors=None):
        """Index the dataset and build the train/test split.

        :param dataset_path: root folder; each subfolder name is a class label
        :param test_ratio: fraction of the smallest class held out for testing
        :param batch_size: number of images per batch yielded by generate()
        :param imagesize: (width, height) target passed to cv2.resize
        :param preprocessors: optional list of objects exposing a
            preprocess(image) method, applied in order before resizing
        """
        # Dict mapping class label -> list of image paths
        self.data = {}
        # Dict holding the shuffled 'train' and 'test' file lists
        self.train_test = {}
        # Total number of images found under dataset_path
        self.datasetsize = 0
        # Number of images sent in each chunk
        self.batch_size = batch_size
        # Number of images in the smallest class
        self.minsize = math.inf
        # Ratio of the full dataset reserved for testing
        self.test_ratio = test_ratio
        # Image size to feed into the network
        self.imagesize = imagesize
        # List of preprocessors to apply.  A None sentinel replaces the
        # original mutable default ([]), which would have been shared
        # across all instances created with the default.
        self.preprocessors = [] if preprocessors is None else preprocessors
        # Per-set index to control the number of images per epoch
        self.current_idx = {'train': 0, 'test': 0}
        # Mapping between class name and its one-hot encoding
        self.onehotencoding = None
        self.lb = LabelBinarizer()

        # Walk dataset/{class}/* and group the files by class label
        for file in paths.list_images(dataset_path):
            # The parent directory name is the label
            label = file.split(os.path.sep)[-2]
            self.data.setdefault(label, []).append(file)
            self.datasetsize += 1

        # The smallest class bounds the balanced split
        for files in self.data.values():
            self.minsize = min(self.minsize, len(files))

        # Calculate the offset where test samples begin, based on the
        # smallest class.  Forces balanced classes for training.
        self.offset = int(self.minsize * (1.0 - self.test_ratio))

        # Create the one-hot encoding.
        # NOTE(review): with exactly 2 classes LabelBinarizer emits a single
        # 0/1 column, not two columns — confirm the downstream loss expects that.
        classes_name = list(self.data.keys())
        self.onehotencoding = dict(zip(classes_name, self.lb.fit_transform(classes_name)))

        # Split the full dataset into train and test sets
        self.split_train_test()

    def split_train_test(self):
        """Shuffle each class and partition its files into train/test lists."""
        _train = []
        _test = []
        for files in self.data.values():
            # Shuffle in place so repeated calls produce a fresh split
            np.random.shuffle(files)
            _train += files[:self.offset]
            _test += files[self.offset:]
        np.random.shuffle(_train)
        np.random.shuffle(_test)
        self.train_test['train'] = _train
        self.train_test['test'] = _test

    def getNumberOfClasses(self):
        """Return the number of distinct class labels."""
        return len(self.data)

    def getDatasetSize(self):
        """Return the total number of images discovered."""
        return self.datasetsize

    def getBatchSize(self):
        """Return the configured batch size."""
        return self.batch_size

    def getTrainingSize(self):
        """Return the balanced training-set size (classes * per-class quota)."""
        return len(self.data) * int(self.minsize * (1.0 - self.test_ratio))

    def getTestingSize(self):
        """Return every image not allocated to training."""
        return self.getDatasetSize() - self.getTrainingSize()

    def generate(self, set='train'):
        """Yield (images, labels) numpy batches forever for the given set.

        :param set: 'train' or 'test'
        :raises AssertionError: if set is neither 'train' nor 'test'
        """
        try:
            assert set == 'train' or set == 'test'
        except AssertionError as e:
            e.args += ('Sets valid: train or test', set)
            raise

        datasets_size = {
            'train': self.getTrainingSize(),
            'test': self.getTestingSize()
        }

        batch = 0
        while True:
            images = []
            labels = []
            for i in range(self.batch_size):
                # Wrap around at the end of an epoch; the partial batch
                # collected so far is yielded as-is
                if self.current_idx[set] >= datasets_size[set]:
                    self.current_idx[set] = 0
                    print('{} --{}-- New epoch'.format(int(time.time()), set))
                    break
                file = self.train_test[set][self.current_idx[set]]
                label = file.split(os.path.sep)[-2]
                image = self._processImage(file)
                images.append(image)
                labels.append(self.onehotencoding[label])
                self.current_idx[set] += 1
                print('Batch: {}-{} <<{}>> {}'.format(batch, i, set, file))
            batch += 1
            yield np.array(images), np.array(labels)

    def _processImage(self, filename):
        """Load, preprocess, resize and scale one image to [0, 1] floats."""
        image = cv2.imread(filename)
        # Apply any configured preprocessors in order
        for p in self.preprocessors:
            image = p.preprocess(image)
        image = cv2.resize(image, self.imagesize, interpolation=cv2.INTER_AREA)
        # Removed a dead `image.astype('float')` statement whose result was
        # discarded; the conversion below is the one that matters.
        return image.astype('float') / 255.0

    def __repr__(self):
        return('\tDataset size: {}\n\tTraining size: {}\n\tTest size: {}\n\tClasses: {}\n\tBatch Size: {}'.format(self.getDatasetSize(), self.getTrainingSize(), self.getTestingSize(), self.getNumberOfClasses(), self.getBatchSize()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
How to use: