Skip to content

Instantly share code, notes, and snippets.

@prerakmody
Last active December 26, 2020 16:29
Show Gist options
  • Save prerakmody/a1db11f9b72b1f31f5c540ff76b76380 to your computer and use it in GitHub Desktop.
Save prerakmody/a1db11f9b72b1f31f5c540ff76b76380 to your computer and use it in GitHub Desktop.
Medical Dataloader in TensorFlow
class HaNMICCAI2015Dataset:
    """Skeleton of the dataset class used by the dataloader outline.

    The two methods are placeholders: the constructor is where map/filter
    parameters belong, and ``generator`` is where the dataset pipeline is
    meant to be built.
    """

    def __init__(self):
        # initialize map, filter parameters
        pass

    def generator(self):
        """Return the dataset pipeline (not implemented in this outline).

        Raises:
            NotImplementedError: always, until the generator, map and filter
                functions are filled in.
        """
        # define generator, map and filter functions
        raise NotImplementedError(
            'define generator, map and filter functions'
        )
# Pipeline settings. These are placeholders and must be replaced with real
# values before the script can run.
batchsize = None  # user defined value
prefetch_buffer = None  # user defined value
shuffle_buffer = None  # user defined value
epochs = None  # user defined value

dataset = HaNMICCAI2015Dataset()

# Order matters: shuffle individual samples, then group them into batches,
# and prefetch last so that fully prepared batches are buffered ahead of the
# consumer. Shuffling after batching would only reorder whole batches.
dataset_generator = (
    dataset.generator()
    .shuffle(shuffle_buffer)
    .batch(batchsize)
    .prefetch(prefetch_buffer)
)

for _ in range(epochs):
    # Each fresh iteration over the dataset is one pass over the data, so no
    # explicit ``repeat(1)`` is needed.
    for (X, Y, meta1, meta2) in dataset_generator:
        print(' - ', X.shape, Y.shape, meta1.numpy())
import tensorflow as tf
class HaNMICCAI2015Dataset:
    """Expose a Python sample generator as a ``tf.data.Dataset``."""

    def generator(self):
        """Build a dataset that pulls its elements from ``_generator3D``.

        Every element is declared as a 4-tuple with dtypes
        ``(tf.int16, tf.uint8, tf.int32, tf.string)``.

        Returns:
            The ``tf.data.Dataset`` created by ``from_generator``.
        """
        dataset = tf.data.Dataset.from_generator(
            self._generator3D,
            output_types=(tf.int16, tf.uint8, tf.int32, tf.string),
            args=(),
        )
        return dataset

    def _generator3D(self):
        """Yield the items of ``data`` one at a time."""
        # NOTE(review): `data` is not defined anywhere in this snippet; it
        # stands in for the iterable of samples and must be supplied before
        # this generator is iterated.
        yield from data
import tensorflow as tf
class Trainer:
    """Minimal custom training loop configured through a ``params`` dict.

    The keys read from ``params`` are ``'dataloader'``, ``'model'``,
    ``'optimizer'`` and ``'epochs'``.
    """

    def __init__(self, params):
        """Store ``params`` and keep direct handles on dataloader and model."""
        self.params = params
        self.dataloader = params['dataloader']
        self.model = params['model']

    @tf.function
    def _loss(self, Y, Y_predict):
        """Return the loss for targets ``Y`` and predictions ``Y_predict``."""
        # do some processing
        # NOTE(review): the loss computation is a placeholder in this
        # snippet; `loss` has to be assigned above before it is returned.
        return loss

    @tf.function
    def _train_step(self, X, Y):
        """Run one optimisation step on the batch ``(X, Y)``."""
        model = self.model
        optimizer = self.params['optimizer']
        # The forward pass and the loss must be recorded on a tape, otherwise
        # there is nothing to differentiate in the gradient call below.
        with tf.GradientTape() as tape:
            Y_predict = model(X)
            loss = self._loss(Y, Y_predict)
        trainable_vars = model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)  # dL/dW
        optimizer.apply_gradients(zip(gradients, trainable_vars))

    def train(self):
        """Call ``_train_step`` on every batch, once per epoch."""
        epochs = self.params['epochs']
        # An integer epoch count cannot be iterated directly; an iterable of
        # epochs is still accepted as before.
        epoch_iter = range(epochs) if isinstance(epochs, int) else epochs
        for _ in epoch_iter:
            for X, Y in self.dataloader.repeat(1):
                self._train_step(X, Y)
import tensorflow as tf
class HaNMICCAI2015Dataset:
    """Generator-backed ``tf.data`` pipeline with a map and a filter stage."""

    def generator(self):
        """Assemble the pipeline: generator source, then map, then filter.

        Returns:
            The filtered ``tf.data.Dataset``.
        """
        element_dtypes = (tf.int16, tf.uint8, tf.int32, tf.string)
        source = tf.data.Dataset.from_generator(
            self._generator3D, output_types=element_dtypes, args=()
        )
        # One parallel call, with permission to emit results out of order.
        mapped = source.map(
            self._get_data_3D, num_parallel_calls=1, deterministic=False
        )
        return mapped.filter(self._filter3D)

    def _generator3D(self):
        """Placeholder body; the sample-yielding logic is not shown here."""
        pass

    @tf.function
    def _get_data_3D(self, vol_img, vol_mask, meta1, meta2):
        """Map stage: currently passes the four components through unchanged."""
        # do some processing
        processed_sample = (vol_img, vol_mask, meta1, meta2)
        return processed_sample

    @tf.function
    def _filter3D(self, vol_img, vol_mask, meta1, meta2):
        """Filter stage: the returned flag decides whether a sample is kept."""
        keep_sample = True  # or False
        return keep_sample
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment