Last active
February 7, 2019 21:27
-
-
Save nevers/45a96a10c5686d0008932187e4283ee3 to your computer and use it in GitHub Desktop.
A simple speed test comparing how fast data is read through TFRecord datasets versus TextLine datasets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import os | |
import time | |
import multiprocessing as mp | |
import tensorflow as tf | |
# Input files of the benchmark: the same training split stored once as a
# TFRecord file and once as a CSV index of PNG files.
TFRECORDS_PATH = os.path.expanduser("~/.datasets/infuse-dl-dataset-v0.0.8/train.tfrecords")
CSV_PATH = os.path.expanduser("~/.datasets/infuse-dl-dataset-v0.0.8/train.csv")

# Number of examples grouped into each batch by the read benchmark.
BATCH_SIZE = 8

# Image height and width, in pixels.
ORIG_IMG_H = 406
ORIG_IMG_W = 528
def main():
    """Run the read benchmark twice inside one TF1 session.

    The first pass uses a sequential ``map`` (``threads=None``). The second
    passes ``mp.cpu_count()`` as the ``map`` parallelism.
    """
    # NOTE(review): both passes share one session/default graph, and the
    # second pass may benefit from an OS page cache warmed by the first.
    with tf.Session() as session:
        for n_threads in (None, mp.cpu_count()):
            read(session, threads=n_threads)
def read(sess, threads):
    """Time one full pass over the textline and TFRecord datasets.

    Args:
        sess: TF1 session used to evaluate each dataset's ``get_next`` op.
        threads: value forwarded to both datasets' ``map`` as the
            parallelism. A falsy value (e.g. ``None``) is reported as a
            sequential run.

    Prints the image count and duration for each dataset, then the
    difference (tfrecords duration minus textline duration).
    """
    print("* Parallel read test" if threads else "* Sequential read")

    def _batched_next(dataset):
        # Batch the dataset and expose a one-shot iterator's next element.
        batched = dataset.batch(BATCH_SIZE)
        return batched.make_one_shot_iterator().get_next()

    # Both iterators are built before either is consumed (tfrecords first).
    tfrecords_next = _batched_next(tfrecords_dataset(TFRECORDS_PATH, threads))
    textline_next = _batched_next(textline_dataset(CSV_PATH, threads))

    print("read textline dataset")
    n_text, text_secs = time_out_of_range(sess, textline_next)
    print(f"textline dataset read {n_text} images in {text_secs:.2f}s")

    print("read tfrecords dataset")
    n_rec, rec_secs = time_out_of_range(sess, tfrecords_next)
    print(f"tfrecords dataset read {n_rec} images in {rec_secs:.2f}s")

    print(f"difference: {rec_secs - text_secs:.2f}s")
def tfrecords_dataset(path, threads):
    """Build a dataset of parsed examples from the TFRecord file at ``path``.

    Each record is parsed into:
      * ``label``: variable-length int64 list, cast to float32;
      * ``image``: flat int64 pixel list of ORIG_IMG_H * ORIG_IMG_W values,
        cast to float32 and reshaped to (ORIG_IMG_H, ORIG_IMG_W, 1);
      * ``name``: scalar string.

    Args:
        path: path of the TFRecord file.
        threads: parallelism passed to ``Dataset.map`` (``None`` means
            sequential).

    Returns:
        A ``tf.data.Dataset`` of dicts with keys ``label``, ``image``, ``name``.
    """
    features = {
        'label': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'image': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'name': tf.FixedLenFeature([], tf.string),
    }

    def _parse_tfrecord(record):
        example = tf.parse_single_example(record, features)
        # tf.cast replaces the deprecated tf.to_float with identical results.
        label = tf.cast(example["label"], tf.float32)
        image = tf.cast(example["image"], tf.float32)
        # Pixels are stored flat. Restore the single-channel layout from the
        # module-level image-size constants instead of duplicating literals.
        image = tf.reshape(image, (ORIG_IMG_H, ORIG_IMG_W, 1))
        return {"label": label, "image": image, "name": example["name"]}

    return tf.data.TFRecordDataset(path).map(_parse_tfrecord, threads)
def textline_dataset(path, threads):
    """Build a dataset that loads one PNG image per line of the CSV at ``path``.

    Each line is decoded as a filename followed by two float columns
    (``x``, ``y``; default 0.0). The image is read from the ``train/``
    directory next to the CSV file and decoded as single-channel PNG.

    Args:
        path: path of the CSV index file.
        threads: parallelism passed to ``Dataset.map`` (``None`` means
            sequential).

    Returns:
        A ``tf.data.Dataset`` of dicts with keys ``label`` (stacked x, y),
        ``image`` and ``name`` (the filename from the CSV).
    """
    # Absolute path of the directory containing the CSV file.
    csv_dir = os.path.abspath(os.path.join(path, os.pardir))
    image_prefix = f"{csv_dir}/train/"

    def _decode_row(line):
        filename, x, y = tf.decode_csv(line, record_defaults=[[''], [0.], [0.]])
        # Python str + string tensor yields a string tensor path.
        png_bytes = tf.read_file(image_prefix + filename)
        image = tf.image.decode_png(png_bytes, channels=1)
        return {"label": tf.stack([x, y]), "image": image, "name": filename}

    lines = tf.data.TextLineDataset(path)
    return lines.map(_decode_row, threads)
def time_out_of_range(sess, op):
    """Evaluate ``op`` repeatedly until its input is exhausted and time the loop.

    Args:
        sess: TF1 session used to evaluate ``op``.
        op: a batched ``get_next`` result. Each evaluation is expected to
            return a dict whose ``"name"`` entry has one element per example.

    Returns:
        ``(count, seconds)``: the total number of examples fetched and the
        elapsed time of the whole loop.
    """
    # perf_counter is monotonic and high-resolution. time.time() can jump
    # when the system clock is adjusted, which would corrupt the measurement.
    start = time.perf_counter()
    count = 0
    while True:
        try:
            batch = sess.run(op)
        except tf.errors.OutOfRangeError:
            # Signals that the iterator has no more elements.
            break
        count += len(batch["name"])
    return count, time.perf_counter() - start
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment