Skip to content

Instantly share code, notes, and snippets.

@mbsariyildiz
Created November 18, 2017 18:46
Show Gist options
  • Save mbsariyildiz/f4fe854c93bd4cad2bdae94b45fd0d3a to your computer and use it in GitHub Desktop.
Save mbsariyildiz/f4fe854c93bd4cad2bdae94b45fd0d3a to your computer and use it in GitHub Desktop.
Simple example (TensorFlow 1.x) of using tf.data.Dataset to build an input pipeline that feeds in-memory (RAM) data to the GPU.
import numpy as np
import tensorflow as tf
# Single-argument print(...) behaves the same on Python 2 and Python 3.
print('tf_version:  %s' % tf.__version__)  # the script was written against TF 1.4.0
np.set_printoptions(linewidth=150, precision=3, suppress=True)
M = 10  # number of samples
d = 2   # feature dimension of each sample
# samples: an (M, d) float32 matrix of standard-normal features
X = tf.constant(np.random.randn(M, d), 'float32')
# ids of samples: a random permutation of 0..M-1, so every sample has a distinct id.
# expand_dims gives Y shape (M, 1), so each sliced element has shape (1,).
Y = tf.constant(np.expand_dims(np.random.permutation(M), 1), 'float32')
dset_items = (X, Y)
# from_tensor_slices slices every component along its first dimension, so each
# component needs rank >= 1 and all components must share that first-dimension size.
# An explicit raise (unlike assert) is not stripped when Python runs with -O.
first_dims = [item.shape.as_list()[0] for item in dset_items]
if len(set(first_dims)) != 1:
    raise ValueError(
        'first dimensions of dataset components differ: %s' % (first_dims,))
batch_size, n_epochs = 2, 5
# Build the input pipeline in one chain:
#   slice rows -> shuffle with a buffer of M (the whole dataset)
#   -> repeat for n_epochs passes -> group into batches -> prefetch 2 batches.
dset = (
    tf.data.Dataset.from_tensor_slices(dset_items)
    .shuffle(M)
    .repeat(n_epochs)
    .batch(batch_size)
    .prefetch(2)
)
# An initializable iterator must have its initializer run before get_next() yields data.
dset_iterator = dset.make_initializable_iterator()
next_batch = dset_iterator.get_next()
sess = tf.Session()
# Run the variable initializer first, then the iterator initializer (same order as before).
for init_op in (tf.global_variables_initializer(), dset_iterator.initializer):
    sess.run(init_op)
# Count how many times each sample id is fetched over all epochs.
occurrence = np.zeros([M], 'int32')
it = 0
try:
    while True:
        # Only sess.run can signal exhaustion, so keep the try body to that call.
        try:
            xb, yb = sess.run(next_batch)
        except tf.errors.OutOfRangeError:
            # Raised once all repeated epochs have been consumed.
            print('end of dataset')
            break
        it += 1
        # yb holds float ids of shape (batch, 1): flatten and cast to integer indices.
        # np.add.at is unbuffered, so an id that appears twice in one batch is counted
        # twice. Plain fancy-index `+=` would count it only once.
        np.add.at(occurrence, yb.ravel().astype(np.int64), 1)
        print('%03d, x:%s, y:%s ' % (it, str(xb.ravel()), str(yb.ravel())))
finally:
    # Release the session even if something other than OutOfRangeError is raised.
    sess.close()
# Each sample is emitted once per epoch, so every entry should equal n_epochs (not M).
print('occurrence array: %s' % (occurrence,))
assert np.all(occurrence == n_epochs), \
    'expected every sample to be fetched %d times' % n_epochs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment