Skip to content

Instantly share code, notes, and snippets.

@ReemRashwan
Last active December 22, 2023 17:59
Show Gist options
  • Save ReemRashwan/8c92086d3104d01978a16e05ca93a165 to your computer and use it in GitHub Desktop.
Save ReemRashwan/8c92086d3104d01978a16e05ca93a165 to your computer and use it in GitHub Desktop.
Keras DICOM image data generator and augmenter for DataFrames (builds on ImageDataGenerator).
import numpy as np
import pandas as pd
import pydicom
import cv2
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras_preprocessing.image.dataframe_iterator import DataFrameIterator
# Written against TF 2.1 / keras_preprocessing's DataFrameIterator.
class DCMDataFrameIterator(DataFrameIterator):
    """DataFrameIterator that loads DICOM files with pydicom instead of PIL.

    Accepts the same arguments as ``DataFrameIterator``. Each sample is read
    with :meth:`read_dcm_as_array`; if an ``image_data_generator`` is attached,
    its random augmentation (``apply_transform``) and its normalization
    (``standardize``: ``rescale``, ``preprocessing_function``, centering, ...)
    are applied per sample, in the same order the parent iterator uses.

    Note: the parent's filename validation (``validate_filenames=True`` by
    default) keeps only paths ending in one of ``white_list_formats``; pass
    ``validate_filenames=False`` for extension-less DICOM files.
    """

    # Extensions accepted by the parent's filename validation. It must be a
    # tuple (``('dcm')`` is just the string 'dcm'). It is defined at class level
    # so it is already in place when ``DataFrameIterator.__init__`` filters rows.
    white_list_formats = ('dcm',)

    def __init__(self, dataframe, *args, **kwargs):
        super(DCMDataFrameIterator, self).__init__(dataframe, *args, **kwargs)
        # Full, unfiltered input table (kept for backward compatibility).
        self.dataframe = dataframe
        # The parent drops invalid files/classes and keeps only the requested
        # ``subset`` ('training' / 'validation'); batch indices refer to that
        # reduced set. Paths and labels must therefore come from the parent,
        # not from the full dataframe. Reading them from the full dataframe
        # made the 'validation' subset overlap the training rows.
        # ``target_size`` and ``color_mode`` are likewise already set by the
        # parent, including its defaults when they are not passed.
        self.x = pd.Series(self.filepaths)
        self.y = pd.Series(self.labels)

    def _get_batches_of_transformed_samples(self, indices_array):
        """Load, augment and normalize the samples at ``indices_array``.

        Returns:
            ``(batch_x, batch_y)``. ``batch_x`` has shape
            ``(len(indices_array), height, width, channels)`` and dtype
            ``self.dtype``. ``batch_y`` holds the labels for those samples,
            cast to uint8.
        """
        # Build the batch as floats: DICOM pixel data is integer, and writing
        # augmented/rescaled values back into an integer array would truncate
        # them (and in-place rescaling of an int array fails).
        batch_x = np.array(
            [self.read_dcm_as_array(dcm_path, self.target_size, color_mode=self.color_mode)
             for dcm_path in self.x.iloc[indices_array]],
            dtype=self.dtype,
        )
        batch_y = np.array(self.y.iloc[indices_array].astype(np.uint8))  # labels may be str

        if self.image_data_generator is not None:
            for i, x in enumerate(batch_x):
                transform_params = self.image_data_generator.get_random_transform(x.shape)
                x = self.image_data_generator.apply_transform(x, transform_params)
                # standardize() is what applies rescale / preprocessing_function /
                # centering; without it those ImageDataGenerator options do nothing.
                batch_x[i] = self.image_data_generator.standardize(x)
                # You can change y here as well, e.g. in semantic segmentation you
                # would transform the masks with the same transform_params.
        return batch_x, batch_y

    @staticmethod
    def read_dcm_as_array(dcm_path, target_size=(256, 256), color_mode='rgb'):
        """Read a DICOM file and return its pixels resized to ``target_size``.

        Args:
            dcm_path: path to the DICOM file.
            target_size: ``(height, width)``, following the Keras convention.
            color_mode: ``'rgb'`` replicates the gray channel into 3 channels;
                any other value keeps a single channel.

        Returns:
            Array of shape ``(height, width, 1)``, or ``(height, width, 3)``
            for ``'rgb'``.
        """
        # NOTE(review): assumes pixel_array is a single 2-D grayscale frame;
        # multi-frame or colour DICOMs are not handled. TODO confirm for your data.
        image_array = pydicom.dcmread(dcm_path).pixel_array
        # cv2.resize expects dsize as (width, height), the reverse of Keras'
        # (height, width) target_size; swap so non-square sizes come out right.
        height, width = target_size
        image_array = cv2.resize(image_array, (width, height), interpolation=cv2.INTER_NEAREST)
        image_array = np.expand_dims(image_array, -1)  # (H, W) -> (H, W, 1)
        if color_mode == 'rgb':
            image_array = cv2.cvtColor(image_array, cv2.COLOR_GRAY2RGB)
        return image_array
# Read the data.
# The CSV is assumed to have two columns:
#   image_path: path to each DICOM image, including its extension
#   target:     label (0s and 1s here -> binary classification)
# dtype=str keeps labels as strings, which DataFrameIterator's
# 'binary'/'categorical' class modes expect.
df = pd.read_csv("yourDfPath.csv", dtype=str)
# Hold out 20% of the rows for testing. A fixed random_state makes the split,
# and therefore the reported test metrics, reproducible across runs.
SPLIT_SEED = 1337
train_df, test_df = train_test_split(df, test_size=0.2, random_state=SPLIT_SEED)
# Augmentation parameters.
# If you use a pretrained network, you can pass its preprocessing_function
# instead of rescale in all generators.
# NOTE(review): rescale=1/255 assumes 8-bit pixel values; DICOM pixel data
# may be 12/16-bit, in which case a different scale is needed.
train_augmentation_parameters = {
    'rescale': 1.0 / 255.0,
    'rotation_range': 10,
    'zoom_range': 0.2,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
    'brightness_range': [0.8, 1.2],
    # Should equal the validation generator's validation_split so that the
    # 'training' and 'validation' subsets cover disjoint rows.
    'validation_split': 0.2,
}
valid_augmentation_parameters = {
    'rescale': 1.0 / 255.0,
    'validation_split': 0.2,
}
test_augmentation_parameters = {
    'rescale': 1.0 / 255.0,
}
# Training parameters.
BATCH_SIZE = 32
CLASS_MODE = 'binary'
COLOR_MODE = 'grayscale'
TARGET_SIZE = (300, 300)  # (height, width)
EPOCHS = 10
SEED = 1337

# Iterator keyword arguments. The training and validation sets differ only in
# which part of the validation_split they read.
train_consts = dict(
    seed=SEED,
    batch_size=BATCH_SIZE,
    class_mode=CLASS_MODE,
    color_mode=COLOR_MODE,
    target_size=TARGET_SIZE,
    subset='training',
)
valid_consts = {**train_consts, 'subset': 'validation'}
test_consts = dict(
    batch_size=1,  # one image per evaluation step
    class_mode=CLASS_MODE,
    color_mode=COLOR_MODE,
    target_size=TARGET_SIZE,  # input images are resized to this
    shuffle=False,  # deterministic iteration order for evaluation
)
# Training-phase generators. Both read train_df; the `subset` entry of
# train_consts / valid_consts selects which part of validation_split each uses.
train_augmenter = ImageDataGenerator(**train_augmentation_parameters)
valid_augmenter = ImageDataGenerator(**valid_augmentation_parameters)

train_generator = DCMDataFrameIterator(
    dataframe=train_df,
    x_col='image_path',
    y_col='target',
    image_data_generator=train_augmenter,
    **train_consts,
)
valid_generator = DCMDataFrameIterator(
    dataframe=train_df,  # same frame; subset='validation' selects the held-out part
    x_col='image_path',
    y_col='target',
    image_data_generator=valid_augmenter,
    **valid_consts,
)
# Define (and compile) the model architecture as you normally would.
model = ...
# Training. Model.fit accepts generators / keras Sequences directly since
# TF 2.1; Model.fit_generator is deprecated there and removed in later Keras.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=EPOCHS,
    validation_data=valid_generator,
    validation_steps=len(valid_generator),
)
# Evaluate on the held-out test set after training. The test
# ImageDataGenerator is configured from test_augmentation_parameters
# (no random augmentation is requested there).
test_augmenter = ImageDataGenerator(**test_augmentation_parameters)
test_generator = DCMDataFrameIterator(
    dataframe=test_df,
    x_col='image_path',
    y_col='target',
    image_data_generator=test_augmenter,
    **test_consts,
)
# NOTE(review): unpacking into two values assumes the model was compiled with
# exactly one metric (e.g. accuracy).
n_test_steps = len(test_generator)
test_loss, test_accuracy = model.evaluate(test_generator, steps=n_test_steps)
@krekiehn
Copy link

krekiehn commented Nov 4, 2021

Hi! I am trying to use 16-bit (uint16) DICOM images with your DICOM data loader. Thanks for sharing it!
The rescaling doesn't work properly: the values in the output batch don't change when I change the rescale value from 0 to 1/255 or 1/(2**16-1).

The problem is in _get_batches_of_transformed_samples: it never calls the "image_data_generator.standardize" method.
After the for loop I added:

batch_x = self.image_data_generator.standardize(batch_x.astype('float64'))

now it works for me

@ReemRashwan
Copy link
Author

Hi! I am trying to use 16-bit (uint16) DICOM images with your DICOM data loader. Thanks for sharing it! The rescaling doesn't work properly: the values in the output batch don't change when I change the rescale value from 0 to 1/255 or 1/(2**16-1).

The problem is in _get_batches_of_transformed_samples: it never calls the "image_data_generator.standardize" method. After the for loop I added:

batch_x = self.image_data_generator.standardize(batch_x.astype('float64'))

now it works for me

You are welcome! and thanks for the note.

@osbm
Copy link

osbm commented Jun 11, 2022

Both keras_preprocessing and tensorflow.keras.preprocessing are deprecated.

@ReemRashwan
Copy link
Author

Both keras_preprocessing and tensorflow.keras.preprocessing are deprecated.

Thanks for the note, this may be a reference for future contributors.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment