Skip to content

Instantly share code, notes, and snippets.

@darrengarvey
Created April 3, 2018 11:21
Show Gist options
  • Save darrengarvey/ff05fbe28ab2061c101fe64353b467ff to your computer and use it in GitHub Desktop.
Save darrengarvey/ff05fbe28ab2061c101fe64353b467ff to your computer and use it in GitHub Desktop.
Perf test for `tensorflow.data.Dataset.list_files()`
#!/usr/bin/env python
'''
Perf test for `tensorflow.data.Dataset.list_files()`.

At least on TF 1.7, all file names are loaded into memory before
training starts. This can be a bottleneck for slower disks, especially
with large datasets.
'''
# The docstring must be the first statement to become the module's __doc__;
# `from __future__` imports are still allowed right after it.
from __future__ import print_function
import os
import shutil
import tempfile
import tensorflow as tf
import time
from functools import partial
# Command-line flags; tf.app.run(main) at the bottom of the file hands
# control to main() once they are available.
FLAGS = tf.app.flags.FLAGS
# Base location for the fresh temporary directory that load_data() creates.
tf.app.flags.DEFINE_string('dir', '/tmp/', 'Base directory to perf test')
# NOTE: load_data() only fans out at the top level: it makes `width`
# directories, each holding a single chain of nested dirs ending in one
# file, so this is also the total number of files listed.
tf.app.flags.DEFINE_integer('width', 10, 'Number of dirs at each level')
# Number of directory levels between the temporary directory and each
# file; load_data() treats values below 1 like 1.
tf.app.flags.DEFINE_integer('depth', 2, 'Number of nested dirs')
def timeit(fn, msg, N=0):
    """Call ``fn()`` once, print how long it took, and return its result.

    Args:
      fn: zero-argument callable to time.
      msg: label printed in front of the measured time.
      N: when non-zero, also report the elapsed time divided by ``N``
        (useful when ``fn`` performs ``N`` iterations of some work).

    Returns:
      Whatever ``fn()`` returned.
    """
    t0 = time.time()
    result = fn()
    elapsed_ms = 1000 * (time.time() - t0)
    parts = ['{}: time: {:.2f} ms'.format(msg, elapsed_ms)]
    if N:
        parts.append(' ({:.2f} ms per iteration)'.format(elapsed_ms / N))
    print(''.join(parts))
    return result
def load_data(base_dir=None, width=None, depth=None):
    """Create a fresh directory tree of empty ``stuff.txt`` files to glob over.

    Layout: ``<tmp>/<i>/0/1/.../<depth-2>/stuff.txt`` for each ``i`` in
    ``range(width)``. That is one file per top-level directory, at least one
    directory level below the new temporary directory ``<tmp>``.

    Args:
      base_dir: directory in which the temporary directory is created;
        defaults to ``FLAGS.dir``. Created first if it does not exist.
      width: number of top-level directories (and therefore files);
        defaults to ``FLAGS.width``.
      depth: directory nesting depth of each file; defaults to
        ``FLAGS.depth``. Values below 1 behave like 1.

    Returns:
      Path of the newly created temporary directory. The caller is
      responsible for removing it.
    """
    if base_dir is None:
        base_dir = FLAGS.dir
    if width is None:
        width = FLAGS.width
    if depth is None:
        depth = FLAGS.depth
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)
    # Use dir= rather than prefix=. A prefix only lands inside base_dir when
    # it is absolute and ends with a separator. Otherwise mkdtemp creates a
    # sibling directory, or one under the system temp dir for relative paths.
    base = tempfile.mkdtemp(dir=base_dir)
    print('saving files to dir: {}'.format(base))
    # The nested chain below each top-level dir is the same for every i.
    nested = [str(j) for j in range(depth - 1)]
    for i in range(width):
        leaf_dir = os.path.join(base, str(i), *nested)
        if not os.path.exists(leaf_dir):
            os.makedirs(leaf_dir)
        # Touch an empty file: only its name matters to list_files().
        open(os.path.join(leaf_dir, 'stuff.txt'), 'w').close()
    return base
def prep_data(base):
    """Build a ``list_files`` dataset over ``base`` and return its next-element op.

    The glob has one wildcard directory component per nesting level created
    by ``load_data``, followed by ``*.txt``.

    Args:
      base: root directory returned by ``load_data``.

    Returns:
      The ``get_next()`` tensor of a one-shot iterator over matching file
      names.
    """
    # load_data always creates at least one directory level (the per-width
    # index), even for --depth <= 0. os.path.join() also needs at least one
    # argument, so clamp to 1 instead of crashing with TypeError.
    levels = max(FLAGS.depth, 1)
    # NOTE(review): TF's glob likely treats '**' like '*' (not recursive),
    # so each '**' should match exactly one directory level -- confirm.
    wildcard_dirs = os.path.join(*(['**'] * levels))
    pattern = '{}/{}/*.txt'.format(base, wildcard_dirs)
    dataset = tf.data.Dataset.list_files(pattern)
    return dataset.make_one_shot_iterator().get_next()
def read_data(data, sess, N=1):
    """Evaluate ``data`` in ``sess`` ``N`` times, discarding the results.

    Args:
      data: the fetch (e.g. an iterator's next-element tensor) to run.
      sess: session object whose ``run`` method is called.
      N: number of fetches; zero or negative runs nothing.
    """
    fetch = sess.run
    for _step in range(N):
        fetch(data)
def main(_):
    """Run the benchmark: create files, build the dataset, time reads, clean up.

    Args:
      _: remaining argv passed in by ``tf.app.run``; unused.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # hide some uninteresting logs
    base = timeit(load_data, 'load data')
    # Always remove the generated tree, even if graph construction or a read
    # fails (e.g. the iterator running out of file names when --width < 2).
    try:
        data = timeit(partial(prep_data, base), 'prep data')
        with tf.Session() as sess:
            # The first read is timed on its own, so any up-front cost (such
            # as expanding the glob) shows up apart from later reads.
            timeit(partial(read_data, data, sess), 'read first filename')
            timeit(partial(read_data, data, sess), 'read second filename')
            # load_data writes one file per --width directory; two are
            # already consumed above.
            N = FLAGS.width - 2
            timeit(partial(read_data, data, sess, N),
                   'read {} more filenames'.format(N), N)
    finally:
        shutil.rmtree(base)
if __name__ == '__main__':
    # tf.app.run handles flag parsing and then calls main(argv).
    tf.app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment