@pangyuteng
Last active January 22, 2019 05:16
Profiling loading of multiple tfrecords versus a single tfrecord
import sys,os
import tensorflow as tf
import numpy as np
from tensorflow.python.estimator.model_fn import ModeKeys as Modes
tf.logging.set_verbosity(tf.logging.INFO)
w = 512
h = 512
d = 300
c = 11
c_y = 1
one_hot_dim = 4
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
datatype = sys.argv[1]
if datatype == 'preformatted':
    # conventional dataset: 10 examples already sliced to the training shape (w, h, c)
    def convert_to(file_path):
        print('Writing', file_path)
        options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
        writer = tf.python_io.TFRecordWriter(file_path, options=options)
        for r in range(10):
            img = np.random.randint(0, 1000, (w, h, c)).astype(np.int16)
            label = np.random.randint(0, 10, (w, h, c_y)).astype(np.int16)
            print(img.shape, label.shape)
            x, y, z = img.shape
            image_raw = img.tostring()
            label_raw = label.tostring()
            # https://github.com/tensorflow/ecosystem/issues/61
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(x),
                'width': _int64_feature(y),
                'depth': _int64_feature(z),
                'label_raw': _bytes_feature(label_raw),
                'image_raw': _bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())
        writer.close()
    for r in range(1):
        convert_to(str(r) + '.tfrecords')
    file_list = [str(x) + '.tfrecords' for x in range(1)]
elif datatype == 'multiple_tfrecords':  # 96.28 seconds.
    # one full-resolution series per tfrecord file, 10 files total
    def convert_to(file_path):
        rand = int(np.random.randint(-5, 10, 1))
        #img = np.zeros((w,h+rand,c)).astype(np.int16)
        img = np.random.randint(0, 1000, (w, h + rand, d + rand)).astype(np.int16)
        #label = np.zeros((w,h+rand,c)).astype(np.int16)
        label = np.random.randint(0, 10, (w, h + rand, d + rand)).astype(np.int16)
        print(img.shape, label.shape)
        x, y, z = img.shape
        image_raw = img.tostring()
        label_raw = label.tostring()
        # https://github.com/tensorflow/ecosystem/issues/61
        print('Writing', file_path)
        options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
        writer = tf.python_io.TFRecordWriter(file_path, options=options)
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(x),
            'width': _int64_feature(y),
            'depth': _int64_feature(z),
            'label_raw': _bytes_feature(label_raw),
            'image_raw': _bytes_feature(image_raw),
        }))
        writer.write(example.SerializeToString())
        writer.close()
    for r in range(10):
        convert_to(str(r) + '.tfrecords')
    file_list = [str(x) + '.tfrecords' for x in range(10)]
elif datatype == 'single_tfrecords':  # 93.32 seconds.
    # all 10 full-resolution series written into one tfrecord file
    def convert_to(file_path):
        print('Writing', file_path)
        options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
        writer = tf.python_io.TFRecordWriter(file_path, options=options)
        for r in range(10):
            rand = int(np.random.randint(-5, 10, 1))
            #img = np.zeros((w,h+rand,c)).astype(np.int16)
            img = np.random.randint(0, 1000, (w, h + rand, d + rand)).astype(np.int16)
            #label = np.zeros((w,h+rand,c)).astype(np.int16)
            label = np.random.randint(0, 10, (w, h + rand, d + rand)).astype(np.int16)
            print(img.shape, label.shape)
            x, y, z = img.shape
            image_raw = img.tostring()
            label_raw = label.tostring()
            # https://github.com/tensorflow/ecosystem/issues/61
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(x),
                'width': _int64_feature(y),
                'depth': _int64_feature(z),
                'label_raw': _bytes_feature(label_raw),
                'image_raw': _bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())
        writer.close()
    for r in range(1):
        convert_to(str(r) + '.tfrecords')
    file_list = [str(x) + '.tfrecords' for x in range(1)]
else:
    raise NotImplementedError()
print(file_list)
print('-------------')
def _preformatted_parse_(serialized_example):
    # records are already stored at the training shape (w, h, c); no resizing needed
    features = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label_raw': tf.FixedLenFeature([], tf.string),
    }
    example = tf.parse_single_example(serialized_example, features)
    image = tf.decode_raw(example['image_raw'], tf.int16)
    image = tf.reshape(image, [w, h, c])
    image = tf.cast(image, tf.float32)
    label = tf.decode_raw(example['label_raw'], tf.int16)
    label = tf.reshape(label, [w, h, c_y])
    return image, label
def _parse_(serialized_example):
    # full-resolution records: reshape with the stored dimensions, then
    # crop/pad and slice down to the training shape on the fly
    features = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label_raw': tf.FixedLenFeature([], tf.string),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64),
    }
    example = tf.parse_single_example(serialized_example, features)
    height = tf.cast(example['height'], tf.int64)
    width = tf.cast(example['width'], tf.int64)
    depth = tf.cast(example['depth'], tf.int64)
    image_shape = tf.stack([height, width, depth])
    image = tf.decode_raw(example['image_raw'], tf.int16)
    image = tf.reshape(image, image_shape)
    label = tf.decode_raw(example['label_raw'], tf.int16)
    label = tf.reshape(label, image_shape)
    # move the first axis last so crop/pad operates on the variable-sized dimensions
    image = tf.transpose(image, [1, 2, 0])
    label = tf.transpose(label, [1, 2, 0])
    image = tf.image.resize_image_with_crop_or_pad(image, w, h)
    label = tf.image.resize_image_with_crop_or_pad(label, w, h)
    # keep only the first c (or c_y) slices along the last axis
    image = tf.slice(image, [0, 0, 0], [w, w, c])
    label = tf.slice(label, [0, 0, 0], [w, w, c_y])
    image = tf.reshape(image, [w, w, c])
    image = tf.cast(image, tf.float32)
    label = tf.reshape(label, [w, w, c_y])
    return image, label
print(w*w*c,w*w*c_y)
batch_size = 2
print(file_list)
def train_input_fn(batch_size=batch_size, params=None):
    return _input_fn(file_list, batch_size=batch_size)
def _input_fn(abs_path_file_list, batch_size=batch_size):
    tfrecord_dataset = tf.data.TFRecordDataset(
        filenames=abs_path_file_list, compression_type='GZIP')
    buffer_size = 5  # buffer_size = batch_size * 10
    random_seed = 42
    if datatype == 'preformatted':
        myparse = _preformatted_parse_
    else:
        myparse = _parse_
    tfrecord_dataset = tfrecord_dataset.map(lambda x: myparse(x)).shuffle(
        buffer_size, seed=random_seed, reshuffle_each_iteration=True).repeat().batch(batch_size)
    tfrecord_iterator = tfrecord_dataset.make_one_shot_iterator()
    return tfrecord_iterator.get_next()
# quick sanity check: pull a few batches through the input pipeline
sess = tf.InteractiveSession()
#tf.global_variables_initializer().run(s)
batch_size = 2
tf_images, tf_labels = train_input_fn(batch_size=batch_size)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord, sess=sess)
for _ in range(9):
    img, lbl = sess.run([tf_images, tf_labels])
    print(img.shape, lbl.shape)
coord.request_stop()
coord.join(threads)
SIGNATURE_NAME = 'serving_default'  # signature key for the PREDICT export (assumed value)
def model_fn(features, labels, mode, params):
    if mode in (Modes.PREDICT, Modes.EVAL):
        pass
    if mode in (Modes.TRAIN,):
        pass
    szx = [-1, w, h, c]
    szy = [-1, w, h, c_y]
    x = tf.reshape(features, szx, name='input_layer')
    y = tf.reshape(labels, szy, name='label_layer')
    # toy model: a single 3x3 convolution predicting the label volume
    y_hat = tf.layers.conv2d(x, c_y, [3, 3], padding='same')
    y = tf.cast(y, tf.float32)
    total_loss = tf.reduce_mean(tf.abs(y - y_hat))
    global_step = tf.train.get_or_create_global_step()
    lr = 0.01
    opt = tf.train.RMSPropOptimizer(learning_rate=lr)
    t_op = opt.minimize(total_loss, global_step=global_step)
    train_op = tf.group(t_op)
    output_layer = y_hat
    expected_output = y
    loss = total_loss
    if mode in (Modes.PREDICT, Modes.EVAL):
        probabilities = output_layer
        predictions = tf.reshape(output_layer, [-1])
    if mode in (Modes.TRAIN, Modes.EVAL):
        pass
    if mode == Modes.PREDICT:
        predictions = {
            'probabilities': probabilities
        }
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(
            mode, predictions=predictions, export_outputs=export_outputs)
    if mode == Modes.TRAIN:
        print('************ TRAINING ************')
        return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op)
    if mode == Modes.EVAL:
        eval_metric_ops = {
            'mean_squared_error': tf.metrics.mean_squared_error(expected_output, output_layer)
        }
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=eval_metric_ops)
total_steps = 10
params = {}
estimator = tf.estimator.Estimator(
    model_fn,
    model_dir='.',
    params=params,
)
import time
start = time.time()
output = estimator.train(
    train_input_fn,
    steps=total_steps,
)
end = time.time()
print(end - start)
pangyuteng commented Jan 22, 2019

3D medical image series are typically 512x512xN or 256x256xN. When creating datasets for training and testing a model or algorithm, the original series are first sliced to the desired dimensions: sometimes 512x512, sometimes 512x512xM where M < N, and sometimes smaller patches such as 32x32x32. You ultimately end up with multiple versions of training datasets that contain the same set of series, which bloats hard disks and leaves admins unsure of what can safely be deleted. The script above investigates whether individual series can instead be stored at full resolution, either one series per tfrecord (multiple_tfrecords) or all series in one giant file (single_tfrecords), and then processed to the desired dimensions while training. Training time with these tfrecords is compared against training with the conventional, preformatted dataset (preformatted). Results are listed below.
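
For a sense of scale, here is a minimal back-of-the-envelope sketch (plain NumPy, using the shapes from the script above; the number of preformatted variants is an assumed value for illustration only):

import numpy as np

# one full-resolution int16 series at the shape used above: 512 x 512 x 300
full_series_bytes = 512 * 512 * 300 * np.dtype(np.int16).itemsize
print('full series: %.1f MiB' % (full_series_bytes / 2.0**20))  # 150.0 MiB

# each preformatted variant (512x512 slices, 512x512xM stacks, 32x32x32 patches, ...)
# re-stores roughly the same voxels, so k variants cost roughly k times the original
k = 3  # assumed number of preformatted variants, for illustration only
print('with %d preformatted variants: ~%.1f MiB per series' %
      (k, k * full_series_bytes / 2.0**20))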

Training times with the preformatted single tfrecord, multiple full-resolution tfrecords, and a single full-resolution tfrecord are 1.97, 96.28, and 93.32 seconds, respectively.

python test_train.py preformatted
python test_train.py multiple_tfrecords
python test_train.py single_tfrecords
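
Not profiled above, but one possible follow-up: a minimal sketch (same TF 1.x APIs as the script; num_parallel_calls and the prefetch depth are untuned, illustrative values) of an _input_fn variant that parallelizes the on-the-fly parse and overlaps it with training.

def _parallel_input_fn(abs_path_file_list, batch_size=batch_size):
    # same pipeline as _input_fn above, plus parallel map and prefetch
    dataset = tf.data.TFRecordDataset(
        filenames=abs_path_file_list, compression_type='GZIP')
    myparse = _preformatted_parse_ if datatype == 'preformatted' else _parse_
    dataset = (dataset
               .map(myparse, num_parallel_calls=4)  # decode/crop several records at once
               .shuffle(5, seed=42, reshuffle_each_iteration=True)
               .repeat()
               .batch(batch_size)
               .prefetch(1))  # prepare the next batch while the current one trains
    return dataset.make_one_shot_iterator().get_next()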
