Skip to content

Instantly share code, notes, and snippets.

@Hdooster
Last active April 7, 2020 11:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Hdooster/f09dc50b12727bf1f3935db014fe8bac to your computer and use it in GitHub Desktop.
Save Hdooster/f09dc50b12727bf1f3935db014fe8bac to your computer and use it in GitHub Desktop.
Generates images and masks from memory, and can augment them with an imgaug sequence provided as input.
class ImageMaskGenerator(Sequence):
    """Batch generator yielding ``(images, masks)`` pairs held in memory.

    Each sample's mask is wrapped in an imgaug ``SegmentationMapsOnImage`` so
    that the optional preprocessing and augmentation sequences transform the
    image and its mask together.

    Args:
        inputs: Indexable collection of images. ``inputs[i].shape`` is read,
            so elements must expose ``.shape`` (NumPy arrays in practice).
        outputs: Indexable collection of masks, aligned index-for-index with
            ``inputs``.
        batch_size: Number of samples per batch; the last batch of an epoch
            may be smaller.
        dim_after_preprocessing: Spatial size ``(height, width)`` every sample
            must have once ``preprocessing_seq`` has run. Defaults to
            ``inputs[0].shape[:2]``.
        shuffle: If true, ``on_epoch_end`` randomly permutes the sample order.
        input_channels: Channel count of each image in a batch.
        output_channels: Channel count of each mask in a batch.
        preprocessing_seq: Optional imgaug augmenter applied to every sample
            before ``augment_seq``.
        augment_seq: Optional imgaug augmenter applied after preprocessing.

    NOTE(review): ``Sequence`` and ``SegmentationMapsOnImage`` are imported
    elsewhere (presumably ``keras.utils.Sequence`` and imgaug) — confirm.
    """

    def __init__(self, inputs, outputs, batch_size=32,
                 dim_after_preprocessing=None, shuffle=True,
                 input_channels=3, output_channels=1,
                 preprocessing_seq=None, augment_seq=None):
        self.inputs = inputs
        self.outputs = outputs
        self.input_channels = input_channels
        self.output_channels = output_channels
        # Without an explicit target size, every sample is assumed to keep
        # the spatial size of the first image.
        self.dim_after_preprocessing = dim_after_preprocessing or self.inputs[0].shape[:2]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = None
        self.preprocessing_seq = preprocessing_seq
        self.augment_seq = augment_seq
        # Build the initial (possibly shuffled) sample order.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (the last may be partial)."""
        # Integer ceiling division: avoids a float round-trip through np.ceil.
        return (len(self.inputs) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        """Return the ``(X, Y)`` batch at position ``index`` of the current order."""
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        return self.__data_generation(batch_indexes)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it in place when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.inputs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, indices):
        """Build one batch from the samples at ``indices``.

        Returns:
            X: uint8 array of shape
                ``(len(indices), *dim_after_preprocessing, input_channels)``.
            Y: uint8 array of shape
                ``(len(indices), *dim_after_preprocessing, output_channels)``
                where each mask value ``v`` becomes ``ceil(v / 255)``: 0 stays
                0 and 1..255 become 1, so both 0/1 and 0/255 masks come out
                binary (values above 255 would map to 2 or more).
        """
        n_samples = len(indices)
        # X is allocated as uint8 directly: NumPy's item assignment applies the
        # same truncating cast the former float64 buffer + astype(uint8) did,
        # without holding a float64 copy of the whole batch.
        X = np.empty((n_samples, *self.dim_after_preprocessing, self.input_channels),
                     dtype=np.uint8)
        Y = np.empty((n_samples, *self.dim_after_preprocessing, self.output_channels))

        for i, idx in enumerate(indices):
            img_in = self.inputs[idx]
            segmap_out = SegmentationMapsOnImage(self.outputs[idx], shape=img_in.shape)

            # Image and mask go through each sequence in the same call so any
            # geometric transform is applied to both consistently.
            if self.preprocessing_seq is not None:
                img_in, segmap_out = self.preprocessing_seq(
                    image=img_in, segmentation_maps=segmap_out)
            if self.augment_seq is not None:
                img_in, segmap_out = self.augment_seq(
                    image=img_in, segmentation_maps=segmap_out)

            if np.ndim(img_in) == 2:
                # A 2-D grayscale image cannot broadcast into the (H, W, C)
                # slot; add the channel axis (broadcasting then repeats it
                # when input_channels > 1).
                img_in = np.expand_dims(img_in, axis=-1)

            X[i] = img_in
            Y[i] = segmap_out.arr

        # Binarize the masks (see Returns above).
        Y = np.ceil(Y / 255.).astype(np.uint8)
        return X, Y
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment