Skip to content

Instantly share code, notes, and snippets.

@thunterdb
Last active December 8, 2017 23:12
Show Gist options
  • Save thunterdb/d5f86c79457eea0f1021117ea4bce0ba to your computer and use it in GitHub Desktop.
Save thunterdb/d5f86c79457eea0f1021117ea4bce0ba to your computer and use it in GitHub Desktop.
tensorflow reading issue with S3
import tensorflow as tf
# Base S3 location of the dataset; the reader below globs for 'train-*'
# files under this prefix and reads them as TFRecord files.
data_dir = "s3://databricks-public-datasets/mnist"
import os
from datetime import datetime
def curr_time():
    """Return the current local time as 'HH:MM:SS mm/dd/yy'.

    Used to timestamp progress messages. The date part is spelled out as
    '%m/%d/%y' rather than the '%D' shorthand: '%D' is a C-library
    extension that is not available on every platform, while the expanded
    form produces the same text everywhere.
    """
    return datetime.now().strftime("%H:%M:%S %m/%d/%y")
# Dataset API code from https://www.tensorflow.org/versions/r1.4/api_docs/python/tf/contrib/data/Dataset#shard
def test_s3_read_singlemachine(worker_index, num_workers, data_dir, num_readers, shuffle_buffer_size, batch_size, num_prefetch_batches, num_steps, s3_region='us-west-2'):
    """Read batches of TFRecords from S3 in a single process and log progress.

    Builds a tf.data input pipeline over the files matching
    ``<data_dir>/train-*`` and pulls ``num_steps`` batches from it inside a
    TensorFlow session, printing a progress line every 30 steps.

    Args:
        worker_index: Index of the shard of the file list this process reads.
        num_workers: Total number of shards the file list is split into.
        data_dir: Directory (here an ``s3://`` URL) containing the
            ``train-*`` TFRecord files.
        num_readers: Number of files read concurrently by ``interleave``
            (its ``cycle_length``).
        shuffle_buffer_size: Buffer size used when shuffling file names.
        batch_size: Number of records per batch.
        num_prefetch_batches: Number of batches to prefetch.
        num_steps: Number of batches to pull before returning.
        s3_region: Value exported as the ``S3_REGION`` environment variable
            before the dataset is built. Defaults to ``'us-west-2'``, the
            value this function previously hard-coded.

    Side effects:
        Overwrites ``os.environ["S3_REGION"]`` for the whole process and
        prints to stdout.
    """
    # NOTE(review): presumably TensorFlow's S3 filesystem picks the region up
    # from this variable -- confirm against the TF version in use.
    os.environ["S3_REGION"] = s3_region

    tf_record_pattern = os.path.join(data_dir, 'train-*')
    print("Reading training data from files matching glob pattern %s"%tf_record_pattern)

    # The pipeline operates on file names until `interleave` turns each name
    # into a stream of records.
    d = tf.data.Dataset.list_files(tf_record_pattern)
    d = d.shard(num_workers, worker_index)
    d = d.repeat()
    d = d.shuffle(shuffle_buffer_size)
    # NOTE(review): this second repeat() looks redundant, since the dataset
    # is already repeated indefinitely above. It is kept so the graph stays
    # identical to the one that reproduces the reported issue.
    d = d.repeat()
    d = d.interleave(tf.data.TFRecordDataset,
                     cycle_length=num_readers, block_length=1)
    d = d.batch(batch_size)
    d = d.prefetch(num_prefetch_batches)

    iterator = d.make_initializable_iterator()
    it_op = iterator.get_next()

    with tf.Session() as sess:
        sess.run(iterator.initializer)
        # `range` instead of the Python-2-only `xrange`, which raises
        # NameError on Python 3.
        for i in range(num_steps):
            res = sess.run(it_op)
            if i % 30 == 0:
                print("Finished step %s, time: %s, result: %s"%(i, curr_time(), len(res)))
# Run the read test as worker 0 of 2 against the public MNIST bucket.
# int(1e7) steps makes this effectively a long-running soak test.
test_s3_read_singlemachine(
    worker_index=0,
    num_workers=2,
    data_dir=data_dir,
    num_readers=2,
    shuffle_buffer_size=32,
    batch_size=32,
    num_prefetch_batches=1,
    num_steps=int(1e7),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment