Skip to content

Instantly share code, notes, and snippets.

@huafengw
Created March 15, 2018 03:28
Show Gist options
  • Save huafengw/85d920c5692ab02fd958ae4036f17986 to your computer and use it in GitHub Desktop.
Save huafengw/85d920c5692ab02fd958ae4036f17986 to your computer and use it in GitHub Desktop.
TF batch and batch_join
import tensorflow as tf
# Demo: compare tf.train.batch and tf.train.batch_join, each run once with
# enqueue_many=False and once with enqueue_many=True, printing one dequeued
# batch of size 4 from each.
# NOTE(review): tf.Session, tf.train.batch, tf.train.batch_join and
# tf.train.start_queue_runners are TensorFlow 1.x graph-mode / queue-runner
# APIs; they are not exposed at the top level in TF 2.x (only under
# tf.compat.v1 with eager execution disabled) — confirm the target TF version.

# Five Python lists of four ints each.
tensor_list = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
    [17, 18, 19, 20],
]
# Same values, each wrapped in one extra list, so every element of the outer
# list is a one-tensor list (as batch_join expects a list of tensor lists).
tensor_list2 = [
    [[1, 2, 3, 4]],
    [[5, 6, 7, 8]],
    [[9, 10, 11, 12]],
    [[13, 14, 15, 16]],
    [[17, 18, 19, 20]],
]

with tf.Session() as sess:
    # Per the tf.train.batch docs: with enqueue_many=False each input tensor
    # is one example; with enqueue_many=True its first dimension indexes
    # examples.
    x1 = tf.train.batch(tensor_list, batch_size=4, enqueue_many=False)
    x2 = tf.train.batch(tensor_list, batch_size=4, enqueue_many=True)
    # batch_join takes a list of tensor lists; each inner list is fed by its
    # own enqueuing thread, so dequeued rows may mix inner lists.
    y1 = tf.train.batch_join(tensor_list, batch_size=4, enqueue_many=False)
    y2 = tf.train.batch_join(tensor_list2, batch_size=4, enqueue_many=True)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for label, batch_op in (("x1", x1), ("x2", x2), ("y1", y1), ("y2", y2)):
            print(label + " batch:" + "-" * 10)
            print(sess.run(batch_op))
        print("-" * 10)
    finally:
        # Always stop and join the queue-runner threads, even if a sess.run
        # call raises, so they are not left running while the session closes.
        coord.request_stop()
        coord.join(threads)
x1 batch:----------
[array([[1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4]]), array([[5, 6, 7, 8],
       [5, 6, 7, 8],
       [5, 6, 7, 8],
       [5, 6, 7, 8]]), array([[ 9, 10, 11, 12],
       [ 9, 10, 11, 12],
       [ 9, 10, 11, 12],
       [ 9, 10, 11, 12]]), array([[13, 14, 15, 16],
       [13, 14, 15, 16],
       [13, 14, 15, 16],
       [13, 14, 15, 16]]), array([[17, 18, 19, 20],
       [17, 18, 19, 20],
       [17, 18, 19, 20],
       [17, 18, 19, 20]])]
x2 batch:----------
[array([1, 2, 3, 4]), array([5, 6, 7, 8]), array([ 9, 10, 11, 12]), array([13, 14, 15, 16]), array([17, 18, 19, 20])]
y1 batch:----------
[array([ 1, 9, 17, 9]), array([ 2, 10, 18, 10]), array([ 3, 11, 19, 11]), array([ 4, 12, 20, 12])]
y2 batch:----------
[5 6 7 8]
----------
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment