Skip to content

Instantly share code, notes, and snippets.

@mayureshagashe2105
Last active September 2, 2022 13:33
Show Gist options
  • Save mayureshagashe2105/d06d7161a7ebe7c4a682ff44cde7dde6 to your computer and use it in GitHub Desktop.
Save mayureshagashe2105/d06d7161a7ebe7c4a682ff44cde7dde6 to your computer and use it in GitHub Desktop.
Custom Data Generators
class DataGenerator(tf.keras.utils.Sequence):
    """Generate batches of valid patches from whole-slide images with mask-derived labels.

    Each slide listed in `df_images` is tiled into non-overlapping patches of
    `patch_size` at resolution `level`. A patch is kept only when its mask
    contains enough non-background (non-zero) pixels; its label is inferred
    from the largest mask value found in the patch.

    Args:
        img_dir: str. Directory path where all the training images are present.
        mask_dir: str. Directory path where all the training masks are present.
        df_images: pd.DataFrame. One row per slide. The code reads the
            `image_id` and `data_provider` columns.
        batch_size: int. Number of slides processed per batch.
        patch_size: tuple. (x-length, y-length). Image will be divided into
            patches of the desired size.
            ``Example : (128, 128), Output will be of shape (-1, 128, 128, channels).``
        level: int. Should be in the range [0, 2]. Resolution level the
            patches are extracted from.
        tissue_threshold: float. Should be in the range [0, 1]. Fraction of
            tissue a patch must exceed to be considered valid.
        write_to_disk: bool. Default: False. If True, extracted patches are
            written to a directory while iterating (see `fn_write_to_disk`).
        save_format: str. Default: None. Extension of the files stored on
            disk (written with the `.`). ``Example: '.png', '.jpg', etc.``
        save_prefix: str. Default: "". Prefix attached to filenames before
            saving to the disk.
        is_training: bool. Default: True. Selects the `train` (True) or
            `validate` (False) sub-directory used when writing patches.
            NOTE: this class does not shuffle the slide order itself.

    Raises:
        ValueError: If `level` is not in range [0, 2].
        ValueError: If `tissue_threshold` is not a fraction in range [0, 1].
        AssertionError: If `patch_size` is not a tuple with length = 2.
        TypeError: If `save_format` is None while `write_to_disk` is True.
    """

    # Class attributes.
    # Root of the on-disk patch store; repr(instance) is appended to it.
    _base_dir = '/content/patches_and_labels'
    # Directories of the most recently created writing instance. Each writing
    # instance also keeps its own copies as instance attributes.
    _store_dir = ""
    _benign_dir = ""
    _malignant_dir = ""
    # Number of instances constructed so far.
    n_objs = 0

    def __init__(self,
                 img_dir: str,
                 mask_dir: str,
                 df_images: pd.DataFrame,
                 batch_size: int,
                 patch_size: tuple,
                 level: int,
                 tissue_threshold=0.50,
                 write_to_disk=False,
                 save_format=None,
                 save_prefix="",
                 is_training=True,
                 ):
        super().__init__()

        # `0 > x > 2` can never be true, so test the valid range directly.
        if not 0 <= level <= 2:
            raise ValueError("level should be in the range [0, 2].")
        if not 0 <= tissue_threshold <= 1:
            raise ValueError("tissue_threshold should be in the range [0, 1].")
        # Raised explicitly (instead of `assert`) so the check survives `python -O`.
        if len(patch_size) != 2:
            raise AssertionError("length of `patch_size` should be equal to 2.")

        # Fixed per-instance id: keeps repr() stable and equal to the suffix
        # of the directory created below.
        self._obj_id = DataGenerator.n_objs

        if write_to_disk:
            if save_format is None:
                raise TypeError("save format should not be None")
            root_dir = DataGenerator._base_dir + repr(self)
            split_name = 'train' if is_training else 'validate'
            store_dir = f'{root_dir}/{split_name}'
            benign_dir = f'{store_dir}/benign'
            malignant_dir = f'{store_dir}/malignant'
            for directory in (benign_dir, malignant_dir):
                os.makedirs(directory, exist_ok=True)
            # Per-instance paths, so a later instance cannot redirect this one.
            self._store_dir = store_dir
            self._benign_dir = benign_dir
            self._malignant_dir = malignant_dir
            # Class-level values are still updated for code that reads them.
            DataGenerator._store_dir = store_dir
            DataGenerator._benign_dir = benign_dir
            DataGenerator._malignant_dir = malignant_dir

        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.df = df_images
        self.patch_size = patch_size
        self.batch_size = batch_size
        self.level = level
        self.thresh = tissue_threshold
        self.training = is_training
        self.write_to_disk = write_to_disk
        self.save_format = save_format
        self.save_prefix = save_prefix
        self.down_sampling_scale = None
        self.indices = range(len(self.df))
        # Copy of the frame indexed by image_id for per-slide metadata lookup.
        self.indexed_df = self.df.set_index('image_id')
        DataGenerator.n_objs += 1

    def __repr__(self):
        """Returns the string representation of an instance.

        Returns:
            `_DataGen<id>`, where `<id>` is the id assigned in `__init__`
            (falls back to the class counter if `__init__` has not run).
        """
        obj_id = getattr(self, "_obj_id", DataGenerator.n_objs)
        return f"_DataGen{obj_id}"

    def __len__(self):
        """Number of batches in the Sequence.

        Returns:
            The number of batches, rounded up so that the slides left over
            after the last full batch are still processed.
        """
        return -(-len(self.df) // self.batch_size)

    def __is_valid(self, unique: tuple) -> bool:
        """Decides whether a patch is kept, based on its mask statistics.

        A patch is valid when it contains at least one non-background pixel
        and its background (mask value 0) fraction is below `1 - self.thresh`.

        Args:
            unique: tuple[np.ndarray, np.ndarray]. tuple[0]: unique mask
                values in the patch. tuple[1]: frequency of each value.

        Returns:
            True if both filtering conditions are satisfied, False otherwise.
        """
        values = np.asarray(unique[0])
        counts = np.asarray(unique[1])
        # Count background explicitly: the first unique value is not
        # background when the patch is fully covered by tissue.
        background_pixels = counts[values == 0].sum()
        has_tissue = bool(np.any(values != 0))
        patch_area = self.patch_size[0] * self.patch_size[1]
        return has_tissue and background_pixels / patch_area < (1 - self.thresh)

    def __get_next_patch(self, image_id: str):
        """Generates valid patches and their labels from the slide `image_id`.

        Args:
            image_id: str. Name of the slide to extract the patches from.

        Yields:
            img_data: np.ndarray. Extracted patch; float32 scaled to [0, 1],
                or uint8 when `write_to_disk` is True.
            label: int. 0 if the patch is benign (non-cancerous).
                1 if the patch is malignant (cancerous).
        """
        mask_path = f'{self.mask_dir}/{image_id}_mask.tiff'
        image_path = f'{self.img_dir}/{image_id}.tiff'
        data_center = self.indexed_df.loc[image_id].data_provider
        patch_w, patch_h = self.patch_size
        writers = []
        try:
            # `with` closes both slides even on error or early generator exit.
            with openslide.OpenSlide(mask_path) as mask, \
                    openslide.OpenSlide(image_path) as img:
                level_scales = {level: int(factor)
                                for level, factor in enumerate(mask.level_downsamples)}
                if self.down_sampling_scale is None:
                    self.down_sampling_scale = level_scales
                # Use this slide's own factor rather than the first slide's.
                scaling_factor = level_scales[self.level]
                level_w, level_h = mask.level_dimensions[self.level]
                cols, rows = level_w // patch_w, level_h // patch_h
                if self.write_to_disk:
                    benign_path = f'{self._benign_dir}/{self.save_prefix}{image_id}'
                    malignant_path = f'{self._malignant_dir}/{self.save_prefix}{image_id}'
                for yidx in range(rows):
                    for xidx in range(cols):
                        # The location is scaled by the level's downsample factor.
                        location = (xidx * patch_w * scaling_factor,
                                    yidx * patch_h * scaling_factor)
                        mask_data = mask.read_region(location, self.level, self.patch_size)
                        # np.array gives a writable copy for the in-place edit below.
                        mask_data = np.array(mask_data.convert('RGB'), dtype=np.float32)
                        mask_data = mask_data[:, :, 0]
                        if data_center == 'radboud':
                            # Fold radboud mask value 2 into 1 so that the
                            # `< 2` label rule below treats it as benign.
                            mask_data[mask_data == 2.0] = 1.0
                        unique = np.unique(mask_data, return_counts=True)
                        if not self.__is_valid(unique):
                            continue
                        img_data = img.read_region(location, self.level, self.patch_size)
                        img_data = img_data.convert('RGB')
                        # np.unique sorts, so [-1] is the largest mask value.
                        label = 0 if unique[0][-1] < 2 else 1
                        if self.write_to_disk:
                            img_data = np.asarray(img_data, np.uint8)
                            patch_img = Image.fromarray(img_data)
                            class_path = benign_path if label == 0 else malignant_path
                            final_path = f'{class_path}_{yidx}_{xidx}{self.save_format}'
                            writer = threading.Thread(
                                target=DataGenerator.__thread_write_to_disk,
                                args=(patch_img, final_path))
                            writer.start()
                            writers.append(writer)
                        else:
                            img_data = np.asarray(img_data, np.float32) / 255.0
                        yield img_data, label
        finally:
            # Wait for pending writes so every file exists once the slide is done.
            for writer in writers:
                writer.join()

    @classmethod
    def __thread_write_to_disk(cls, patch_img: "Image.Image", final_path: str):
        """Writes an extracted patch to the disk; used as a thread target.

        Args:
            patch_img: PIL.Image. Image object built from the extracted patch.
            final_path: str. Full file path of the patch, of the form
                `<class_dir>/<prefix><image_id>_<yidx>_<xidx><save_format>`.
        """
        patch_img.save(final_path)

    def __get_patches(self, image_id: str) -> tuple:
        """Collects every valid patch of a slide into tensors.

        Args:
            image_id: str. Name of the slide to extract the patches from.

        Returns:
            patches: tf.Tensor. Valid patches from the image.
            labels: tf.Tensor. Labels for the extracted patches.
            Both have a leading dimension of 0 if the slide has no valid patch.
        """
        patches, labels = [], []
        for patch, label in self.__get_next_patch(image_id):
            patches.append(patch)
            labels.append(label)
        if not patches:
            # Shape the empty result like a real one so it can be concatenated
            # with other slides' patches (arrays are height x width x RGB).
            patch_w, patch_h = self.patch_size
            patch_dtype = tf.uint8 if self.write_to_disk else tf.float32
            return (tf.zeros((0, patch_h, patch_w, 3), dtype=patch_dtype),
                    tf.zeros((0,), dtype=tf.int32))
        return tf.stack(patches), tf.stack(labels)

    def __getitem__(self, idx):
        """Gets the batch at position `idx`.

        Args:
            idx: position of the batch in the Sequence.

        Returns:
            `(batch_patches, batch_labels)` with one-hot labels when
            `write_to_disk` is False. When `write_to_disk` is True the
            patches are written to disk and None is returned.
        """
        batch_indices = self.indices[idx * self.batch_size: (idx + 1) * self.batch_size]
        batch_image_ids = self.df['image_id'].iloc[batch_indices].values
        if self.write_to_disk:
            # Only the side effect matters here: drain the generator without
            # holding every patch in memory.
            for image_id in batch_image_ids:
                for _ in self.__get_next_patch(image_id):
                    pass
            return None
        batch_training_data = [self.__get_patches(image_id) for image_id in batch_image_ids]
        batch_patches = tf.concat([data[0] for data in batch_training_data], axis=0)
        batch_labels = tf.concat([data[1] for data in batch_training_data], axis=0)
        batch_labels = ku.to_categorical(batch_labels, num_classes=2, dtype='int32')
        return batch_patches, batch_labels

    def fn_write_to_disk(self, batch_size: int) -> tf.keras.preprocessing.image.ImageDataGenerator:
        """Iterates over this instance to write all extracted patches to the disk.

        Args:
            batch_size: int. Batch size for reading extracted patches back
                from the disk. If not None it replaces `self.batch_size`
                once writing has finished.

        Returns:
            datagen: the iterator returned by
                `ImageDataGenerator.flow_from_directory`, pointing at the
                directory that contains the extracted patches.
        """
        batches = len(self)
        total_slides = len(self.df)
        for count, _ in enumerate(self):
            # Cap the running total: the last batch may be a partial one.
            processed = min((count + 1) * self.batch_size, total_slides)
            print(f'{count + 1} / {batches} batches written to the disk. ({processed} images written to the disk.)')
        print("Disk Write-up complete.")
        if batch_size is not None:
            self.batch_size = batch_size
        datagenerator = ImageDataGenerator(rescale=1./255., dtype='float32')
        # target_size is (height, width) while patch_size is (x, y).
        datagen = datagenerator.flow_from_directory(
            self._store_dir,
            target_size=(self.patch_size[1], self.patch_size[0]),
            batch_size=self.batch_size)
        return datagen
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment