Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rahulremanan/a0f25a477195efe5cfdff697506998fc to your computer and use it in GitHub Desktop.
Save rahulremanan/a0f25a477195efe5cfdff697506998fc to your computer and use it in GitHub Desktop.
A DICOM reader implemented as a sequence object on top of the TensorFlow Keras `Sequence` API.
class DicomGenerator(tf.keras.utils.Sequence):
    """Batch loader for DICOM files built on the Keras ``Sequence`` API.

    Three access styles are supported:

    * indexing (``gen[i]``) returns batch ``i`` as a NumPy array of
      ``dtype``; together with ``__len__`` this is the ``Sequence`` contract;
    * iteration (``next(gen)`` / ``for batch in gen``) returns the batches in
      order, each cast to a ``tf.Tensor`` of ``dtype``;
    * calling (``gen()``) rewinds the iteration cursor and returns the object
      itself (presumably so it can be handed to APIs that expect a generator
      factory -- confirm against callers).

    Every image is min-max scaled to ``[0, 1]`` independently; a constant
    image is left unscaled.

    Args:
        dicom_path: Sequence of DICOM file paths. A private list copy is kept
            so that shuffling never reorders the caller's sequence.
        batch_size: Number of files per batch; must be >= 1.
        dtype: dtype *name* (e.g. ``'float32'``) looked up on both ``numpy``
            and ``tensorflow``.
        shuffle: If True, ``on_epoch_end`` reshuffles the file order.
        drop_remainder: If True, a trailing batch smaller than ``batch_size``
            is not produced: ``__len__`` excludes it and ``__getitem__``
            raises ``StopIteration`` for it.
        preserve_batch_size: If True (and ``drop_remainder`` is False), the
            trailing short batch is replaced by the last ``batch_size`` files,
            so it overlaps the previous batch. With fewer than ``batch_size``
            files in total it is the whole dataset.
        **kwargs: Stored as instance attributes and also kept in
            ``self.kwargs``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, dicom_path, batch_size=1, dtype='float32',
                 shuffle=False, drop_remainder=False,
                 preserve_batch_size=False, **kwargs):
        # Lets Keras 3's PyDataset initialise its worker settings; on older
        # tf.keras this is a harmless no-op.
        super().__init__()
        # A batch size of 0 would otherwise divide by zero in __len__ and
        # loop forever in __next__.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.i = 0  # iteration cursor for __next__; rewound by __call__
        self.batch_size = batch_size
        self.dtype = dtype
        # on_epoch_end shuffles in place, so keep our own copy rather than
        # mutating the list the caller passed in.
        self.dicom_path = list(dicom_path)
        self.n = len(self.dicom_path)
        self.drop_remainder = drop_remainder
        self.preserve_batch_size = preserve_batch_size
        self.shuffle = shuffle
        self.__dict__.update(**kwargs)
        self.kwargs = kwargs

    def __load__(self, dicom_filename):
        """Read one DICOM file and min-max scale it to ``[0, 1]``.

        The scaling is skipped when max == min, avoiding a division by zero.
        """
        # NOTE(review): assumes read_dicom returns a NumPy array -- confirm.
        img = read_dicom(dicom_filename)
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (img - lo) / (hi - lo)
        return np.asarray(img, dtype=getattr(np, self.dtype))

    def __getitem__(self, index):
        """Return batch ``index`` as a NumPy array of ``self.dtype``.

        NOTE: indices are not bounds-checked.

        Raises:
            StopIteration: If ``drop_remainder`` is set and batch ``index``
                would hold fewer than ``batch_size`` files.
        """
        n = len(self.dicom_path)
        start = index * self.batch_size
        stop = start + self.batch_size
        if stop <= n:
            batch_paths = self.dicom_path[start:stop]
        elif self.drop_remainder:
            raise StopIteration()
        elif self.preserve_batch_size:
            # Back-fill the short final batch with the preceding files. The
            # start is clamped at 0 so that a dataset smaller than one batch
            # yields all of its files instead of a negative-index slice.
            batch_paths = self.dicom_path[max(0, n - self.batch_size):n]
        else:
            batch_paths = self.dicom_path[start:n]
        images = [self.__load__(path) for path in batch_paths]
        return np.array(images, dtype=getattr(np, self.dtype))

    def __iter__(self):
        # The object is its own iterator; only __call__ rewinds the cursor.
        return self

    def __next__(self):
        """Return the next batch as a ``tf.Tensor`` cast to ``self.dtype``.

        Raises:
            StopIteration: Once every file has been consumed, or when the
                next batch is a partial batch dropped by ``drop_remainder``.
        """
        if self.i * self.batch_size >= len(self.dicom_path):
            raise StopIteration()
        batch = self.__getitem__(self.i)
        self.i += 1
        return tf.cast(batch, dtype=getattr(tf, self.dtype))

    def __call__(self):
        """Rewind the iteration cursor and return ``self``."""
        self.i = 0
        return self

    def on_epoch_end(self):
        """Reshuffle the (privately owned) file list when ``shuffle`` is set."""
        if self.shuffle:
            random.SystemRandom().shuffle(self.dicom_path)

    def __len__(self):
        """Return the number of batches per epoch.

        The trailing partial batch is counted unless ``drop_remainder`` is
        set, so the count matches what ``__getitem__`` and iteration produce.
        """
        if self.drop_remainder:
            return self.n // self.batch_size
        return -(-self.n // self.batch_size)  # ceiling division
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment