Last active
August 27, 2022 15:34
-
-
Save MSHADroo/aed17ff5436addb56d13d0db519f6212 to your computer and use it in GitHub Desktop.
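TensorFlow/Keras custom data generator (a `tf.keras.utils.Sequence`) for Siamese networks: it reads a class-labelled image directory and pairs each image with one positive (same-class) and one negative (different-class) image.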
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
from keras.preprocessing.image import ImageDataGenerator | |
import math | |
class CustomDataGen(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches of image pairs for a siamese network.

    Every image found under ``path`` is used as an anchor twice: once paired
    with a randomly chosen image of the same class (label ``1``) and once
    paired with a randomly chosen image of a different class (label ``0``).
    Only file paths are kept in memory; images are loaded from disk when a
    batch is requested.
    """

    def __init__(self, path, batch_size, target_size=(224, 224, 3), shuffle=False):
        """Enumerate the images under ``path`` and build the pair table.

        Args:
            path: Directory handed to ``ImageDataGenerator.flow_from_directory``
                to discover image files and their class labels.
            batch_size: Number of pairs per batch; must be at least 1.
            target_size: ``(height, width, channels)``. Only the first two
                entries are used, to resize images on load.
            shuffle: When true, the pair order is reshuffled at construction
                time and at the end of every epoch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1, or if the
                directory holds images of fewer than two classes.
        """
        # Newer Keras versions warn when the base initializer is skipped;
        # on versions where Sequence defines no __init__ this is a no-op.
        super().__init__()
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        # The directory iterator is used purely as a file/label index; no
        # image data is read through it.
        listing = ImageDataGenerator().flow_from_directory(path, class_mode="sparse")
        self.pairedImage, self.pairedLabels = self.__make_pairs(
            listing.filepaths, listing.classes
        )
        self.batch_size = batch_size
        self.input_size = target_size
        self.shuffle = shuffle
        # Two pairs are generated per source image, so the epoch length must
        # be the number of pairs, not the number of files; otherwise the
        # second half of the pair table is never served.
        self.datalen = len(self.pairedImage)
        self.on_epoch_end()

    def __getitem__(self, index):
        """Return batch number ``index`` as ``([first_images, second_images], labels)``.

        Raises:
            IndexError: If ``index`` is outside ``range(len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"batch index {index} out of range for {len(self)} batches"
            )
        start = index * self.batch_size
        batch_indexes = self.indexes[start : start + self.batch_size]
        return self.__get_data(
            self.pairedImage[batch_indexes], self.pairedLabels[batch_indexes]
        )

    def __len__(self):
        """Return the number of batches per epoch (the last one may be short)."""
        return math.ceil(self.datalen / self.batch_size)

    def on_epoch_end(self):
        """Reset the pair order, reshuffling it when ``shuffle`` is enabled."""
        self.indexes = np.arange(self.datalen)
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __make_pairs(self, all_paths, all_labels):
        """Build one positive and one negative pair for every image.

        Args:
            all_paths: Sequence of image file paths.
            all_labels: Class label of each path, aligned with ``all_paths``.

        Returns:
            A tuple ``(pairs, labels)``: ``pairs`` has shape ``(2 * n, 2)`` and
            holds ``[anchor_path, other_path]`` rows; ``labels`` has shape
            ``(2 * n, 1)`` with ``1`` for same-class and ``0`` for
            different-class pairs. Rows alternate positive, negative.

        Raises:
            ValueError: If images exist but belong to fewer than two classes,
                because no negative pair can be formed.
        """
        labels = np.asarray(all_labels)
        classes = np.unique(labels)
        if len(labels) and len(classes) < 2:
            raise ValueError(
                "at least two classes with images are required to build negative pairs"
            )

        # Key by the actual label value rather than by position so that gaps
        # in the label range (e.g. an empty class directory) cannot misalign
        # or overflow the lookup.
        same_class = {cls: np.flatnonzero(labels == cls) for cls in classes}

        pair_paths, pair_labels = [], []
        # The negative pool depends only on the anchor's label. Directory
        # listings are grouped by class, so caching the most recent pool
        # avoids rescanning all labels for every single image.
        pool_label, negative_pool = None, None
        for anchor_path, label in zip(all_paths, labels):
            if negative_pool is None or label != pool_label:
                pool_label = label
                negative_pool = np.flatnonzero(labels != label)

            # NOTE: the positive partner may be the anchor itself.
            positive_path = all_paths[np.random.choice(same_class[label])]
            pair_paths.append([anchor_path, positive_path])
            pair_labels.append([1])

            negative_path = all_paths[np.random.choice(negative_pool)]
            pair_paths.append([anchor_path, negative_path])
            pair_labels.append([0])

        # reshape keeps the 2-D layout even when no images were found.
        return (
            np.asarray(pair_paths).reshape(-1, 2),
            np.asarray(pair_labels).reshape(-1, 1),
        )

    def __get_data(self, image_batches, label_batches):
        """Load the images of a batch of path pairs.

        Args:
            image_batches: Rows of ``[first_path, second_path]``.
            label_batches: Pair labels aligned with ``image_batches``.

        Returns:
            ``([first_images, second_images], labels)`` where each image array
            stacks the loaded images of one side of the pairs.
        """
        first_images = np.asarray(
            [self.__get_input(first, self.input_size) for first, _ in image_batches]
        )
        second_images = np.asarray(
            [self.__get_input(second, self.input_size) for _, second in image_batches]
        )
        return [first_images, second_images], np.asarray(label_batches)

    def __get_input(self, path, target_size):
        """Load the image at ``path``, resized to ``target_size[:2]``, as a uint8 array."""
        image = tf.keras.preprocessing.image.load_img(
            path, target_size=(target_size[0], target_size[1])
        )
        return tf.keras.preprocessing.image.img_to_array(image, dtype="uint8")
# Example usage. PATH_TO_DIRECTORY is a placeholder that must be defined
# before this runs; it is passed to CustomDataGen as the image directory.
traingen = CustomDataGen(PATH_TO_DIRECTORY, batch_size=64, target_size=(48, 48, 3))

# Fetch the first batch: x is the pair of image arrays, y the pair labels.
x, y = traingen[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment