Last active
April 7, 2020 11:33
-
-
Save Hdooster/f09dc50b12727bf1f3935db014fe8bac to your computer and use it in GitHub Desktop.
Generates images and masks from memory, and can augment them with an imgaug sequence provided as input.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ImageMaskGenerator(Sequence):
    """Batch generator yielding ``(images, masks)`` pairs from in-memory arrays.

    Each mask is wrapped as an imgaug ``SegmentationMapsOnImage`` so that the
    optional preprocessing and augmentation sequences transform the image and
    its mask in sync.

    Args:
        inputs: Indexable collection of images. ``inputs[0].shape[:2]`` is the
            default output size when ``dim_after_preprocessing`` is not given.
        outputs: Indexable collection of masks, aligned index-by-index with
            ``inputs``.
        batch_size: Number of samples per batch; the last batch of an epoch
            may be smaller.
        dim_after_preprocessing: ``(height, width)`` every image/mask must have
            after preprocessing and augmentation (batch arrays are allocated
            with this size).
        shuffle: Reshuffle the sample order in ``on_epoch_end`` when true.
        input_channels: Channel count of each image in ``X``.
        output_channels: Channel count of each mask in ``Y``.
        preprocessing_seq: Optional augmenter called as
            ``seq(image=..., segmentation_maps=...)`` on every sample first.
        augment_seq: Optional augmenter with the same call signature, applied
            after ``preprocessing_seq``.
    """

    def __init__(self, inputs, outputs, batch_size=32,
                 dim_after_preprocessing=None, shuffle=True,
                 input_channels=3, output_channels=1,
                 preprocessing_seq=None, augment_seq=None):
        self.inputs = inputs
        self.outputs = outputs
        self.input_channels = input_channels
        self.output_channels = output_channels
        # Explicit None check: `or` would raise on an array-valued argument.
        if dim_after_preprocessing is None:
            dim_after_preprocessing = self.inputs[0].shape[:2]
        self.dim_after_preprocessing = dim_after_preprocessing
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = None
        self.preprocessing_seq = preprocessing_seq
        self.augment_seq = augment_seq
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (the last one may be partial)."""
        # Integer ceil-division; avoids a float round-trip.
        return (len(self.inputs) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        """Return the ``(X, Y)`` batch at position ``index`` in the current order."""
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        return self.__data_generation(batch_indexes)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it with NumPy's global RNG if enabled."""
        self.indexes = np.arange(len(self.inputs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _prepare_sample(self, idx):
        """Return sample ``idx`` as ``(image, segmap)`` after preprocessing/augmentation."""
        image = self.inputs[idx]
        # Wrapping the mask as a segmentation map lets the sequences keep it
        # aligned with the image.
        segmap = SegmentationMapsOnImage(self.outputs[idx], shape=image.shape)
        # Preprocessing first, then augmentation; a falsy sequence is skipped.
        for seq in (self.preprocessing_seq, self.augment_seq):
            if seq:
                image, segmap = seq(image=image, segmentation_maps=segmap)
        return image, segmap

    def __data_generation(self, indices):
        """Build one batch from the sample positions in ``indices``.

        Returns:
            A tuple ``(X, Y)`` where ``X`` is a uint8 array of shape
            ``(len(indices), *dim_after_preprocessing, input_channels)`` and
            ``Y`` is a uint8 array of shape
            ``(len(indices), *dim_after_preprocessing, output_channels)``
            holding 1 where the mask value is positive and 0 elsewhere.
        """
        count = len(indices)
        # Allocate the final dtype directly instead of float64 buffers that
        # are converted afterwards (8x less memory per batch).
        X = np.empty((count, *self.dim_after_preprocessing, self.input_channels),
                     dtype=np.uint8)
        Y = np.empty((count, *self.dim_after_preprocessing, self.output_channels),
                     dtype=np.uint8)
        for i, idx in enumerate(indices):
            image, segmap = self._prepare_sample(idx)
            X[i] = image
            # Binarize: any positive mask value becomes 1 regardless of its
            # magnitude (ceil(v / 255) would give 2 for values above 255).
            Y[i] = segmap.arr > 0
        return X, Y
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment