Skip to content

Instantly share code, notes, and snippets.

@emuccino
Created March 20, 2020 03:13
Show Gist options
  • Save emuccino/98fbc21b5bb88508eae6036e4387887b to your computer and use it in GitHub Desktop.
Save emuccino/98fbc21b5bb88508eae6036e4387887b to your computer and use it in GitHub Desktop.
class ImageGenerator:
    """Yield batches of (images, integer class labels) from a class-per-folder tree.

    ``directory`` must contain one sub-directory per class; every entry inside
    a class sub-directory is loaded as an image of that class. Class labels are
    integers assigned in sorted order of the sub-directory names, so generators
    built over directories with the same class folders (e.g. train/test
    splits) agree on the label mapping. Non-directory entries at the root are
    ignored.

    Images in a batch may differ in size; each batch is zero-padded (centered)
    up to the largest height and width found in that batch.

    Args:
        directory: Root directory containing one sub-directory per class.
        batch_size: Maximum number of images per batch; the last batch may be
            smaller. Must be >= 1.
        shuffle: If true, sample order is reshuffled at the start of every
            call to the generator.
        max_dimension: If truthy, images whose longest spatial side exceeds it
            are downsampled, preserving aspect ratio, so that side equals
            ``max_dimension``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """

    def __init__(self, directory, batch_size=16, shuffle=False, max_dimension=None):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        # Root directory as passed in (attribute name kept for compatibility).
        self.directories = directory
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.max_dimension = max_dimension
        # Sorted so the label mapping does not depend on the filesystem's
        # arbitrary listdir order; stray files at the root are skipped rather
        # than crashing the per-class listdir below.
        self.class_names = sorted(
            entry for entry in listdir(directory)
            if path.isdir(path.join(directory, entry))
        )
        image_paths = []
        class_labels = []
        for class_label, class_dir in enumerate(self.class_names):
            class_path = path.join(directory, class_dir)
            # List each class directory exactly once so paths and labels
            # always line up one-to-one.
            file_names = sorted(listdir(class_path))
            image_paths.extend(path.join(class_path, f) for f in file_names)
            class_labels.extend([class_label] * len(file_names))
        self.image_paths = np.array(image_paths)
        self.class_labels = np.array(class_labels, dtype=int)
        # Index array, permuted in place when shuffling.
        self.idx = np.arange(len(self.image_paths))

    def __len__(self):
        """Return the number of batches in one pass over the data."""
        # Integer ceiling division: a trailing partial batch still counts.
        return (len(self.image_paths) + self.batch_size - 1) // self.batch_size

    def _target_size(self, height, width):
        """Return the downsampled ``(width, height)`` for an image, or ``None``.

        ``None`` means no resize is needed: ``max_dimension`` is unset/falsy
        or the longest side already fits. Otherwise both sides are scaled by
        ``max_dimension / longest_side`` (floored, but never below 1 pixel).
        """
        longest = max(height, width)
        if not self.max_dimension or longest <= self.max_dimension:
            return None
        return tuple(max(1, d * self.max_dimension // longest) for d in (width, height))

    def _load_image(self, img_path):
        """Load one image as an RGB array, optionally downsample it, then preprocess it."""
        img = img_to_array(load_img(img_path, color_mode='rgb', interpolation='nearest'))
        # Assumes img_to_array returns (height, width, channels), i.e.
        # channels_last — TODO confirm. Only the spatial axes count toward
        # max_dimension.
        new_size = self._target_size(img.shape[0], img.shape[1])
        if new_size is not None:
            # new_size is (width, height), matching cv2.resize's dsize order —
            # presumably `resize` is cv2.resize; confirm against the imports.
            img = resize(img, new_size)
        # Model-specific scaling; presumably a Keras applications
        # preprocess_input — confirm which model it belongs to.
        img = preprocess_input(img)
        return img

    def _pad_images(self, img, shape):
        """Zero-pad ``img`` so its first two axes match ``shape`` (height, width).

        Padding is split as evenly as possible between both sides of each
        axis, with any odd pixel going after (bottom/right). The third
        (channel) axis is left unpadded, so ``img`` must be 3-D.
        """
        pad_width = []
        for axis in range(2):
            extra = shape[axis] - img.shape[axis]
            before = extra // 2
            pad_width.append((before, extra - before))
        pad_width.append((0, 0))
        return np.pad(img, pad_width, mode='constant', constant_values=0.)

    def __call__(self):
        """Generate ``(batch_images, batch_class_labels)`` tuples for one epoch.

        ``batch_images`` stacks the padded images of the batch along a new
        leading axis; ``batch_class_labels`` holds the matching integer labels.
        """
        if self.shuffle:
            np.random.shuffle(self.idx)
        for start in range(0, len(self.image_paths), self.batch_size):
            batch_idx = self.idx[start:start + self.batch_size]
            batch_image_paths = self.image_paths[batch_idx]
            batch_class_labels = self.class_labels[batch_idx]
            batch_images = [self._load_image(p) for p in batch_image_paths]
            # Largest height and width in this batch; every image is padded up to it.
            max_resolution = tuple(
                max(img.shape[axis] for img in batch_images) for axis in range(2)
            )
            batch_images = np.array([self._pad_images(img, max_resolution) for img in batch_images])
            yield batch_images, batch_class_labels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment