Skip to content

Instantly share code, notes, and snippets.

@nevers
Last active February 7, 2019 21:27
Show Gist options
  • Save nevers/45a96a10c5686d0008932187e4283ee3 to your computer and use it in GitHub Desktop.
Save nevers/45a96a10c5686d0008932187e4283ee3 to your computer and use it in GitHub Desktop.
A simple speed test comparing how long it takes to read the same data through a TFRecord dataset versus a text-line (CSV) dataset.
#!/usr/bin/env python3
import os
import time
import multiprocessing as mp
import tensorflow as tf
# Locations of the two encodings of the training set that the benchmark reads;
# "~" is expanded to the current user's home directory.
TFRECORDS_PATH, CSV_PATH = (
    os.path.expanduser(location)
    for location in (
        "~/.datasets/infuse-dl-dataset-v0.0.8/train.tfrecords",
        "~/.datasets/infuse-dl-dataset-v0.0.8/train.csv",
    )
)

# Number of examples grouped into each batch pulled from a dataset.
BATCH_SIZE = 8

# Image height and width; these equal the first two dimensions that the
# TFRecord parser reshapes each image to.
ORIG_IMG_H = 406
ORIG_IMG_W = 528
def main():
    """Run the read benchmark twice inside a single TensorFlow session.

    The first pass calls ``read`` with ``threads=None`` (which ``read``
    reports as the sequential run); the second pass uses one thread per CPU
    as reported by ``multiprocessing.cpu_count()``.
    """
    with tf.Session() as session:
        read(session, threads=None)
        worker_count = mp.cpu_count()
        read(session, threads=worker_count)
def read(sess, threads):
    """Time a full pass over the text-line dataset and the TFRecord dataset.

    Both pipelines are built first, then each one is drained in turn
    (text-line first) and the image counts, durations and the difference
    between the two durations are printed.

    Args:
        sess: Session used to evaluate the iterator ops.
        threads: Forwarded to both dataset builders; a falsy value is
            reported as the sequential run, anything else as the parallel run.
    """

    def _next_batch_op(dataset):
        # Batch the dataset and return the op that yields its next batch.
        return dataset.batch(BATCH_SIZE).make_one_shot_iterator().get_next()

    print("* Parallel read test" if threads else "* Sequential read")

    # Construct both input pipelines up front so only the reads are timed.
    tfrecords_next = _next_batch_op(tfrecords_dataset(TFRECORDS_PATH, threads))
    textline_next = _next_batch_op(textline_dataset(CSV_PATH, threads))

    print("read textline dataset")
    count, textline_duration = time_out_of_range(sess, textline_next)
    print(f"textline dataset read {count} images in {textline_duration:.2f}s")

    print("read tfrecords dataset")
    count, tfrecords_duration = time_out_of_range(sess, tfrecords_next)
    print(f"tfrecords dataset read {count} images in {tfrecords_duration:.2f}s")

    print(f"difference: {tfrecords_duration - textline_duration:.2f}s")
def tfrecords_dataset(path, threads):
    """Build a dataset that parses every example stored in a TFRecord file.

    Args:
        path: Location of the TFRecord file to read.
        threads: Passed to ``Dataset.map`` as ``num_parallel_calls``;
            ``None`` leaves the mapping at its default.

    Returns:
        A ``tf.data.Dataset`` whose elements are dicts with the keys
        ``"label"``, ``"image"`` and ``"name"``.
    """

    def _parse_tfrecord(record):
        # Label and image are parsed as variable-length int64 sequences;
        # the name is a single string.
        features = {
            'label': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
            'image': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
            'name': tf.FixedLenFeature([], tf.string)
        }
        example = tf.parse_single_example(record, features)
        label = tf.to_float(example["label"])
        image = tf.to_float(example["image"])
        # Restore the flat pixel sequence to height x width x 1 channel.
        # The module-level constants are used instead of repeating the
        # literal dimensions so the two cannot drift apart.
        image = tf.reshape(image, (ORIG_IMG_H, ORIG_IMG_W, 1))
        name = example["name"]
        return {"label": label, "image": image, "name": name}

    ds = tf.data.TFRecordDataset(path)
    return ds.map(_parse_tfrecord, num_parallel_calls=threads)
def textline_dataset(path, threads):
    """Build a dataset that reads a CSV index and decodes the PNG it names.

    Each CSV line is parsed as ``filename, x, y``. The image is loaded from
    the ``train`` directory located next to the CSV file.

    Args:
        path: Location of the CSV file.
        threads: Second argument handed to ``Dataset.map``.

    Returns:
        A ``tf.data.Dataset`` whose elements are dicts with the keys
        ``"label"`` (the stacked ``x`` and ``y``), ``"image"`` and ``"name"``.
    """
    # Directory that contains the CSV file; it does not depend on the line
    # being parsed, so it is resolved once here.
    parent_dir = os.path.abspath(os.path.join(path, os.pardir))

    def _parse_line(line):
        columns = tf.decode_csv(line, record_defaults=[[''], [0.], [0.]])
        filename, x, y = columns
        png_bytes = tf.read_file(f"{parent_dir}/train/" + filename)
        image = tf.image.decode_png(png_bytes, channels=1)
        return {"label": tf.stack([x, y]), "image": image, "name": filename}

    lines = tf.data.TextLineDataset(path)
    return lines.map(_parse_line, threads)
def time_out_of_range(sess, op):
    """Evaluate ``op`` until its input is exhausted and time the whole loop.

    Args:
        sess: Object whose ``run(op)`` evaluates one batch.
        op: Op to evaluate; each result must support ``result["name"]``,
            and the length of that entry is counted as the batch size.

    Returns:
        A ``(count, seconds)`` tuple: the total number of items seen and the
        elapsed time in seconds, including the final call that raised
        ``tf.errors.OutOfRangeError``.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    count = 0
    while True:
        # Keep the try body to the single call that signals exhaustion.
        try:
            batch = sess.run(op)
        except tf.errors.OutOfRangeError:
            break
        count += len(batch["name"])
    return count, time.perf_counter() - start_time
# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment