Last active
December 26, 2020 16:29
-
-
Save prerakmody/a1db11f9b72b1f31f5c540ff76b76380 to your computer and use it in GitHub Desktop.
Medical Dataloader in TensorFlow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class HaNMICCAI2015Dataset:
    """Outline of a medical-image dataset wrapper (MICCAI 2015 Head-and-Neck
    segmentation data). This snippet is a skeleton; the full implementation
    appears further down the page.
    """

    def __init__(self):
        # NOTE(review): original was `def __init__():` — missing `self`,
        # which raises TypeError the moment the class is instantiated.
        # Initialize map and filter parameters here.
        pass

    def generator(self):
        # Define the generator, map and filter functions and return a
        # tf.data.Dataset; stubbed out in this skeleton.
        pass
# User-defined pipeline hyper-parameters (set before running).
batchsize = None        # samples per batch
prefetch_buffer = None  # number of batches to prefetch
shuffle_buffer = None   # number of samples held in the shuffle buffer
epochs = None           # total training epochs

dataset = HaNMICCAI2015Dataset()
# NOTE(review): the original chained .batch().prefetch().shuffle(), which
# (a) shuffles whole batches rather than individual samples and
# (b) prefetches before shuffling, defeating the purpose of prefetch.
# The idiomatic tf.data order is shuffle -> batch -> prefetch.
dataset_generator = (
    dataset.generator()
    .shuffle(shuffle_buffer)
    .batch(batchsize)
    .prefetch(prefetch_buffer)
)

for _ in range(epochs):
    for (X, Y, meta1, meta2) in dataset_generator.repeat(1):
        print(' - ', X.shape, Y.shape, meta1.numpy())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
class HaNMICCAI2015Dataset:
    """Wraps a Python generator of 3D volumes as a tf.data.Dataset.

    Each sample is a (vol_img, vol_mask, meta1, meta2) tuple with dtypes
    (int16, uint8, int32, string), as declared in `output_types`.
    """

    def generator(self):
        """Return a tf.data.Dataset built from `_generator3D`."""
        dataset = tf.data.Dataset.from_generator(
            self._generator3D,
            output_types=(tf.int16, tf.uint8, tf.int32, tf.string),
            args=(),
        )
        return dataset

    def _generator3D(self):
        # NOTE(review): the original line was `def _generator3D(self)` —
        # missing the trailing colon, a SyntaxError. It also iterated a
        # name `data` that was never defined (NameError at first call).
        """Yield (vol_img, vol_mask, meta1, meta2) tuples one sample at a time."""
        data = []  # TODO: load the 3D volumes here before iterating
        for each in data:
            yield each
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
class Trainer:
    """Minimal TensorFlow training loop.

    `params` must provide: 'dataloader' (a tf.data.Dataset), 'model'
    (a tf.keras.Model or callable with trainable_variables), 'optimizer'
    (a tf.keras optimizer), and 'epochs' (int).
    """

    def __init__(self, params):
        self.params = params
        self.dataloader = params['dataloader']
        self.model = params['model']

    @tf.function
    def _loss(self, Y, Y_predict):
        """Return a scalar loss for the batch.

        NOTE(review): the original returned `loss` without ever defining
        it (NameError). A mean-squared-error placeholder is used here —
        replace with the task's real loss (e.g. Dice/CE for segmentation).
        """
        loss = tf.reduce_mean(
            tf.square(tf.cast(Y, tf.float32) - tf.cast(Y_predict, tf.float32))
        )
        return loss

    @tf.function
    def _train_step(self, X, Y):
        """Run one optimization step on a single (X, Y) batch."""
        model = self.params['model']
        optimizer = self.params['optimizer']
        # NOTE(review): the original used `tape` without ever opening a
        # tf.GradientTape (NameError). The forward pass and loss must be
        # computed inside the tape for gradients to be recorded.
        with tf.GradientTape() as tape:
            Y_predict = model(X)
            loss = self._loss(Y, Y_predict)
        trainable_vars = model.trainable_variables  # avoid shadowing builtin `vars`
        gradients = tape.gradient(loss, trainable_vars)  # dL/dW
        optimizer.apply_gradients(zip(gradients, trainable_vars))

    def train(self):
        """Iterate the dataloader for the configured number of epochs."""
        # NOTE(review): the original wrote `for epoch in self.params['epochs']`,
        # which fails when epochs is an int; use range().
        for epoch in range(self.params['epochs']):
            for X, Y in self.params['dataloader'].repeat(1):
                self._train_step(X, Y)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
class HaNMICCAI2015Dataset:
    """tf.data input pipeline: generator source -> per-sample map -> filter.

    Samples are (vol_img, vol_mask, meta1, meta2) tuples with dtypes
    (int16, uint8, int32, string).
    """

    def generator(self):
        """Assemble and return the full tf.data.Dataset pipeline."""
        ds = tf.data.Dataset.from_generator(
            self._generator3D,
            output_types=(tf.int16, tf.uint8, tf.int32, tf.string),
            args=(),
        )
        ds = ds.map(self._get_data_3D, num_parallel_calls=1, deterministic=False)
        return ds.filter(self._filter3D)

    def _generator3D(self):
        # Raw sample source; left unimplemented in this snippet.
        pass

    @tf.function
    def _get_data_3D(self, vol_img, vol_mask, meta1, meta2):
        # Per-sample preprocessing hook ("do some processing" in the
        # original); currently passes the tuple through unchanged.
        return (vol_img, vol_mask, meta1, meta2)

    @tf.function
    def _filter3D(self, vol_img, vol_mask, meta1, meta2):
        # Keep-or-drop predicate applied after mapping.
        return True  # or False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment