-
-
Save mayureshagashe2105/d06d7161a7ebe7c4a682ff44cde7dde6 to your computer and use it in GitHub Desktop.
Custom Data Generators
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DataGenerator(tf.keras.utils.Sequence):
    """Generates batches of valid patches of the images and infers the corresponding labels for those patches via given mask encoding.

    Args:
        img_dir: str. Directory path where all the training images are present.
        mask_dir: str. Directory path where all the training masks are present.
        df_images: pd.DataFrame. Pandas DataFrame object that contains information about each image in the
            image directory. The code reads its `image_id` and `data_provider` columns.
        batch_size: int. Number of slides (rows of `df_images`) processed per batch.
        patch_size: tuple. (x-length, y-length). Image will be divided into patches of the desired size.
            ``Example : (128, 128), Output will be of shape (-1, 128, 128, channels).``
        level: int. Should be in the range [0, 2]. Indicates from which resolution level the patch has to be extracted.
        tissue_threshold: float. Should be in range [0, 1]. Fraction of tissue region in a patch to consider
            that patch to be valid.
        write_to_disk: bool. Default: False. If True, extracted patches are written to a directory upon
            iteration instead of being returned (see `fn_write_to_disk`).
        save_format: str. Default: None. Extension of the file to store on the disk (should be written with the `.`).
            ``Example: '.png', '.jpg', etc.``
        save_prefix: str. Default: "". Prefix to attach to filenames before saving to the disk.
        is_training: bool. Default: True. Selects the `train` (True) or `validate` (False) sub-directory when
            writing to disk. NOTE(review): no shuffling is implemented in this class; slides are always
            taken in DataFrame order.

    Raises:
        ValueError: If `level` is not in range [0, 2].
        ValueError: If `tissue_threshold` is not a fraction in range [0, 1].
        AssertionError: If `patch_size` is not a tuple with length = 2.
        TypeError: If `save_format` is None while `write_to_disk` is True.
    """

    # Class Attributes
    _base_dir = '/content/patches_and_labels'
    # Empty defaults; each instance created with `write_to_disk=True` shadows
    # them with its own paths so that several generators do not overwrite
    # each other's output locations.
    _store_dir = ""
    _benign_dir = ""
    _malignant_dir = ""
    # Number of successfully constructed instances; used to name directories.
    n_objs = 0

    def __init__(self,
                 img_dir: str,
                 mask_dir: str,
                 df_images: pd.DataFrame,
                 batch_size: int,
                 patch_size: tuple,
                 level: int,
                 tissue_threshold=0.50,
                 write_to_disk=False,
                 save_format=None,
                 save_prefix="",
                 is_training=True,
                 ):
        super().__init__()

        # `0 > x > 2` means `0 > x and x > 2`, which can never hold; test the
        # closed interval explicitly instead.
        if not 0 <= level <= 2:
            raise ValueError("level should be in the range [0, 2].")
        if not 0 <= tissue_threshold <= 1:
            raise ValueError("tissue_threshold should be in the range [0, 1].")
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O`, while keeping the documented exception type.
        if len(patch_size) != 2:
            raise AssertionError("length of `patch_size` should be equal to 2.")
        if write_to_disk and save_format is None:
            raise TypeError("save format should not be None")

        # Freeze this instance's number now so `repr(self)` (and therefore the
        # directory name) stays the same for the lifetime of the object.
        self._obj_id = DataGenerator.n_objs

        if write_to_disk:
            split = 'train' if is_training else 'validate'
            self._store_dir = f'{DataGenerator._base_dir + repr(self)}/{split}'
            self._benign_dir = f'{self._store_dir}/benign'
            self._malignant_dir = f'{self._store_dir}/malignant'
            for directory in (self._benign_dir, self._malignant_dir):
                os.makedirs(directory, exist_ok=True)

        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.df = df_images
        self.patch_size = patch_size
        self.batch_size = batch_size
        self.level = level
        self.thresh = tissue_threshold
        self.training = is_training
        self.write_to_disk = write_to_disk
        self.save_format = save_format
        self.save_prefix = save_prefix
        # Filled lazily from the first mask opened: {level: int downsample}.
        self.down_sampling_scale = None
        self.indices = range(len(self.df))
        # Copy indexed by `image_id` for per-slide metadata lookups.
        self.indexed_df = self.df.copy()
        self.indexed_df.set_index('image_id', inplace=True)
        DataGenerator.n_objs += 1

    def __repr__(self):
        """Returns the string representation of an instance.

        Returns:
            Custom string representation of an instance, `_DataGen<N>`, where N is
            the value of `n_objs` at the time the instance was constructed.
        """
        # Fall back to the class counter for objects built without __init__.
        return f"_DataGen{getattr(self, '_obj_id', DataGenerator.n_objs)}"

    def __len__(self):
        """Number of batch in the Sequence.

        Returns:
            The number of batches in the Sequence.
        """
        # NOTE(review): floor division drops a trailing partial batch, so the
        # last `len(df) % batch_size` slides are never visited -- confirm this
        # is intended, especially for `fn_write_to_disk`.
        return len(self.df) // self.batch_size

    def __is_valid(self, unique: tuple) -> bool:
        """Validates the extracted patch by: 1. Dropping patches that contain a single mask value only.
        2. Dropping patches that have less than `self.thresh` fraction of tissue in it.

        Args:
            unique: tuple[np.ndarray, np.ndarray]. tuple[0]: np.ndarray of sorted unique elements in the
                extracted mask patch. tuple[1]: np.ndarray of frequency of elements in tuple[0].

        Returns:
            True: If, both the filtering conditions are satisfied.
            False: Otherwise.
        """
        # NOTE(review): the count of the smallest value is treated as the
        # background count, which presumes value 0 is present in the patch --
        # confirm against the mask encoding.
        values, counts = unique
        patch_area = self.patch_size[0] * self.patch_size[1]
        return len(values) > 1 and counts[0] / patch_area < (1 - self.thresh)

    def __get_next_patch(self, image_id: str):
        """Generates subsequent patches and their corresponding label from the Whole Slide Image (WSI) with id = `image_id`

        When `write_to_disk` is True nothing is yielded; each valid patch is saved
        by a background thread, and all of those threads are joined before the
        generator finishes.

        Args:
            image_id: str. Name of the slide to extract the patches from.

        Yields:
            img_data: np.ndarray. Extracted patch from the original image, float32 scaled by 1/255.
            label: int. 0 if the patch is benign (non-cancerous).
                1 if the patch is malignant (cancerous).
        """
        mask_path = f'{self.mask_dir}/{image_id}_mask.tiff'
        image_path = f'{self.img_dir}/{image_id}.tiff'
        data_center = self.indexed_df.loc[image_id].data_provider
        writers = []

        # Context managers close both slides even if the consumer abandons the
        # generator early or an exception is raised mid-slide.
        with openslide.OpenSlide(mask_path) as mask, \
                openslide.OpenSlide(image_path) as img:
            if self.down_sampling_scale is None:
                self.down_sampling_scale = {
                    level: int(scaling_factor)
                    for level, scaling_factor in enumerate(mask.level_downsamples)
                }
            level_width, level_height = mask.level_dimensions[self.level]
            width = level_width // self.patch_size[0]
            height = level_height // self.patch_size[1]
            scaling_factor = self.down_sampling_scale[self.level]

            if self.write_to_disk:
                benign_path = f'{self._benign_dir}/{self.save_prefix}{image_id}'
                malignant_path = f'{self._malignant_dir}/{self.save_prefix}{image_id}'

            try:
                for yidx in range(height):
                    for xidx in range(width):
                        # Patch indices are in level coordinates; multiply by
                        # the downsample factor before passing to read_region.
                        location = (xidx * self.patch_size[0] * scaling_factor,
                                    yidx * self.patch_size[1] * scaling_factor)

                        mask_data = mask.read_region(location, self.level, self.patch_size)
                        mask_data = mask_data.convert('RGB')
                        mask_data = np.asarray(mask_data, np.float32)
                        # Only the first channel of the mask is inspected.
                        mask_data = mask_data[:, :, 0]
                        if data_center == 'radboud':
                            # Fold value 2 into 1 so that the `< 2` test below
                            # applies to both data providers.
                            mask_data[mask_data == 2.0] = 1.0

                        unique = np.unique(mask_data, return_counts=True)
                        if not self.__is_valid(unique):
                            continue

                        img_data = img.read_region(location, self.level, self.patch_size)
                        img_data = img_data.convert('RGB')
                        # Largest mask value present decides the label.
                        label = 0 if unique[0][-1] < 2 else 1

                        if self.write_to_disk:
                            img_data = np.asarray(img_data, np.uint8)
                            patch_img = Image.fromarray(img_data)
                            if label == 0:
                                final_path = f'{benign_path}_{yidx}_{xidx}{self.save_format}'
                            else:
                                final_path = f'{malignant_path}_{yidx}_{xidx}{self.save_format}'
                            writer = threading.Thread(
                                target=DataGenerator.__thread_write_to_disk,
                                args=(patch_img, final_path))
                            writer.start()
                            writers.append(writer)
                        else:
                            img_data = np.asarray(img_data, np.float32) / 255.0
                            yield img_data, label
            finally:
                # Make sure every patch is on disk before the caller moves on
                # (e.g. before the output directory is scanned).
                for writer in writers:
                    writer.join()

    @classmethod
    def __thread_write_to_disk(cls, patch_img: Image, final_path: str):
        """Writes extracted patch to the disk; used as the target of a separate thread.

        Args:
            patch_img: PIL.Image. Image object from array of the extracted patch
            final_path: str. Filepath/name_of_the_extracted_patch.
                Follows naming convention:
                `<class dir>/<prefix><image_id>_<yidx>_<xidx><save_format>`
        """
        patch_img.save(final_path)

    def __get_patches(self, image_id: str) -> tuple:
        """Calls the generator function to collect all patches of one slide.

        Args:
            image_id: str. Name of the slide to extract the patches from.

        Returns:
            patches: tf.Tensor. Valid patches from an image.
            labels: tf.Tensor. labels for extracted patches.
            Both are empty along axis 0 when the slide yields no patch (which is
            always the case when `write_to_disk` is True).
        """
        patches, labels = [], []
        for patch, label in self.__get_next_patch(image_id):
            patches.append(patch)
            labels.append(label)

        if not patches:
            # Keep rank and dtype compatible with non-empty results so that
            # `tf.concat` in `__getitem__` still works for this slide.
            empty_patches = tf.zeros(
                (0, self.patch_size[1], self.patch_size[0], 3), dtype=tf.float32)
            empty_labels = tf.zeros((0,), dtype=tf.int32)
            return empty_patches, empty_labels
        return tf.stack(patches), tf.stack(labels)

    def __getitem__(self, idx):
        """Gets batch at position `idx`.

        Args:
            idx: position of the batch in the Sequence.

        Returns:
            A batch `(patches, one_hot_labels)` built from `batch_size` slides, or
            None when `write_to_disk` is True (patches are saved instead).
        """
        batch_indices = self.indices[idx * self.batch_size: (idx + 1) * self.batch_size]
        batch_image_ids = self.df['image_id'].iloc[batch_indices].values
        batch_training_data = [self.__get_patches(image_id) for image_id in batch_image_ids]

        if self.write_to_disk:
            return None

        batch_patches = tf.concat([data[0] for data in batch_training_data], axis=0)
        batch_labels = tf.concat([data[1] for data in batch_training_data], axis=0)
        batch_labels = ku.to_categorical(batch_labels, num_classes=2, dtype='int32')
        return batch_patches, batch_labels

    def fn_write_to_disk(self, batch_size: int) -> tf.keras.preprocessing.image.ImageDataGenerator:
        """Iterates over an instance of class `DataGenerator` for writing extracted patches to the disk.

        Args:
            batch_size: int. Batch size for reading extracted patches from the disk. If None, the
                generator's own `batch_size` is used.

        Returns:
            datagen: the iterator returned by `ImageDataGenerator.flow_from_directory`, pointing towards
                the directory which contains extracted patches.

        Raises:
            ValueError: If the generator was not created with `write_to_disk=True`.
        """
        if not self.write_to_disk:
            raise ValueError("fn_write_to_disk requires a generator created with write_to_disk=True.")

        batches = len(self)
        for count, _ in enumerate(self):
            print(f'{count + 1} / {batches} batches written to the disk. ({(count + 1) * self.batch_size} images written to the disk.)')
        print("Disk Write-up complete.")

        # Use a local value: overwriting `self.batch_size` would silently
        # change `__len__` / `__getitem__` for any later use of this generator.
        read_batch_size = self.batch_size if batch_size is None else batch_size
        datagenerator = ImageDataGenerator(rescale=1./255., dtype='float32')
        datagen = datagenerator.flow_from_directory(self._store_dir,
                                                    target_size=self.patch_size,
                                                    batch_size=read_batch_size)
        return datagen
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment