@akiross
Created April 3, 2017 16:50
This notebook aims to explain how queues are used in TensorFlow, in a slightly more practical way than the official docs. I developed and tested it with TF 1.0.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Understanding Queues (and their role) in TensorFlow\n",
"\n",
"I am trying to understand how queues can (and shall) be used in TF, but I am getting a bit frustrated by the docs... Info is superficial and sparse, and I had to research for a while to get my mind over it. I want to share this knowledge with you, so... Shall we start?"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"from threading import Thread\n",
"from time import sleep\n",
"from random import randint"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"## Queues\n",
"\n",
"Using basic queues in TF is not hard. Consider a queue like a variable that stores a number of objects. This number is limited by queue's capacity. When you created this \"variable\" and you initialized it, you can enqueue and dequeue data.\n",
"\n",
"What is important to understand is that queue and dequeue function calls will not perform the queue/dequeue immediately, but they return an enqueue/dequeue operation.\n",
"\n",
"So, they return a code that you must run to make it happen. Let's see an example."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n",
"2.0\n",
"3.0\n"
]
}
],
"source": [
"tf.reset_default_graph() # So we have a clean graph every cell\n",
"with tf.Session() as sess:\n",
" # Create a queue that can store at most 5 float elements\n",
" q = tf.FIFOQueue(capacity=5, dtypes='float')\n",
" \n",
" # Create an operation that enqueues some data into the queue\n",
" # This is NOT executed immediately\n",
" enq_data = q.enqueue_many([[1.0, 2.0, 3.0]])\n",
" \n",
" # Create an operation that returns the next element in queue\n",
" get_next = q.dequeue()\n",
"\n",
" # Create an operation to initialize variables\n",
" init = tf.group(\n",
" tf.global_variables_initializer(),\n",
" tf.local_variables_initializer())\n",
" \n",
" # Initialize the variables\n",
" init.run()\n",
" \n",
" # Actually enqueue the data\n",
" # If you don't run this operation, the queue will hang waiting\n",
" # for some data to be ready (in our case, it will deadlock)\n",
" enq_data.run()\n",
" \n",
" # Print elements in queue \n",
" print(get_next.eval())\n",
" print(get_next.eval())\n",
" print(get_next.eval())\n",
"\n",
" # If we execute the following, the code will hang because\n",
" # there are no other elements in the queue\n",
" # print(get_next.eval())"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"source": [
"This is all nice and easy, and once you understand and internalize that `queue()` and `dequeue()` are returning ops, you are set to go and you can use queue.\n",
"\n",
"Except that you cannot use TF with *just this knowledge*.\n",
"\n",
"In fact, TF contains a few utilities for managing threads and queue that are used in the rest of the framework, and since the framework heavily relies on some ways of accessing data, you have to understand few more concepts before using queues effectively.\n",
"\n",
"## Coordinator\n",
"\n",
"The next thing to know next is the `tf.train.Coordinator`. A coordinator is basically a syncing primitive: instead of using mutex, locks, etc, you can use a coordinator.\n",
"\n",
"As an example, look at this code that runs a couple of threads and coordinates them."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PR1: 0\n",
"PR2: 0\n",
"PR1: 1\n",
"PR1: 2\n",
"PR2: 1\n",
"PR1: 3\n",
"PR1: 4\n",
"PR2: 2\n",
"PR1: 5\n",
"PR1: 6\n",
"PR2: 3\n",
"PR1: 7\n",
"Done\n"
]
}
],
"source": [
"def make_printer(N, end=7):\n",
" def printer(coord):\n",
" '''This function prints an incremental number every N second\n",
" and asks to stop when the end number is reached'''\n",
" # Counter\n",
" i = 0\n",
" # Sleep for a random amount of time, introducing some\n",
" # delay between prints of different threads\n",
" sleep(randint(0, 5) / 5)\n",
" # Check if we can continue looping\n",
" while not coord.should_stop():\n",
" print('PR{}: {}'.format(N, i))\n",
" i += 1\n",
" sleep(N)\n",
" if i > end:\n",
" coord.request_stop()\n",
" return printer\n",
"\n",
"# Create a coordinator\n",
"# This allows our threads to communicate\n",
"coord = tf.train.Coordinator()\n",
"\n",
"# Start two threads, one printing every sec, one printing every 2 secs\n",
"printers = [Thread(target=make_printer(n), args=(coord,)) for n in [1, 2]]\n",
"\n",
"for p in printers:\n",
" p.start() # Start the thread\n",
"\n",
"coord.join(printers) # Wait for threads to finish\n",
"print('Done')"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Besides possible problems when printing (threads writing at the same time), it's pretty clear what the codes do: Coordinator will keep a shared state which is used by threads to know when they have to stop.\n",
"\n",
"Also, after starting the threads, coordinator can wait for all of them to finish (remember to pass the threads to wait for as an argument to `join()`).\n",
"\n",
"Things gets now a bit more complicated: you can indeed roll your own threading code, using coordinator so that one (or more) threads can produce some data, pushing them in a queue, and one (or more) threads can read those data from the queue. Let's try"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Produced 0.0\n",
"Consumed 0.0\n",
"Produced 1.0\n",
"Produced 2.0\n",
"Produced 3.0\n",
"Consumed 1.0\n",
"Produced 4.0\n",
"Produced 5.0\n",
"Produced 6.0\n",
"Produced 7.0\n",
"Consumed 2.0\n",
"Produced 8.0\n",
"Consumed 3.0\n",
"Produced 9.0\n",
"Consumed 4.0\n",
"Produced 10.0\n",
"Consumed 5.0\n",
"Stopping production\n",
"Done\n"
]
}
],
"source": [
"tf.reset_default_graph() # So we have a clean graph every cell\n",
"\n",
"def producer(sess, coord, enq, x, end=10):\n",
" '''This function enqueues an incremental number every second\n",
" and asks to stop when the end number is reached'''\n",
" i = float(0)\n",
" while not coord.should_stop():\n",
" print('Produced', i)\n",
" # Run the enqueue operation\n",
" # We use the placeholder x to provide the value from outside the graph\n",
" enq.run(session=sess, feed_dict={x: i})\n",
" i += 1.0\n",
" sleep(0.5)\n",
" if i > end:\n",
" print('Stopping production')\n",
" coord.request_stop()\n",
"\n",
"def consumer(sess, coord, deq):\n",
" '''This function dequeues and print until possible'''\n",
" while not coord.should_stop():\n",
" print('Consumed', sess.run(deq))\n",
" sleep(randint(2, 5)) # Wait some time\n",
"\n",
"with tf.Session() as sess:\n",
" q = tf.FIFOQueue(5, 'float')\n",
" x = tf.placeholder('float')\n",
" \n",
" # Operation that enqueues some data\n",
" enq = q.enqueue(x)\n",
" deq = q.dequeue()\n",
"\n",
" coord = tf.train.Coordinator()\n",
"\n",
" tf.group(\n",
" tf.global_variables_initializer(),\n",
" tf.local_variables_initializer()\n",
" ).run()\n",
" \n",
" p = Thread(target=producer, args=(sess, coord, enq, x))\n",
" c = Thread(target=consumer, args=(sess, coord, deq))\n",
" \n",
" p.start()\n",
" c.start()\n",
"\n",
" coord.join([p, c]) # Wait for threads to finish\n",
" print('Done')"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"This should be rather clear, but let me explain it anyway.\n",
"\n",
"It is important to create the enqueue/dequeue operations once. In the `producer` function, I do not perform this operation at every cycle:\n",
"\n",
" queue.enqueue(i).run()\n",
"\n",
"the reason is that `queue.enqueue()` creates a new operation in the graph every time you call it, and if you perform this while running, TensorFlow will complain that you are trying to execute the enqueue operation from a different graph than the one containing the queue itself. Therefore, we first create an enqueue operation, `enq`, which depends on a placeholder. This allows us to feed new data every time we need to push something new in the queue.\n",
"\n",
"A similar argument goes for the dequeue operation, pre-defined and passed in the thread function.\n",
"\n",
"Once this is clear, the code should be easy to understand: we define the function to use operations in the graph, run the functions as separated threads, passing the queue and the operations as arguments, and wait them to complete.\n",
"\n",
"Note how, when producer gets to the end, it will request to stop and consumer will stop as well, even if there are still values in the queue.\n",
"\n",
"## QueueRunner\n",
"\n",
"Well, in TensorFlow the usage of threads to push data into queues is rather common. For this reason, there is another helper object called `QueueRunner`. Its job is to run threads that enqueue data onto a queue. You just have to build the queue and the ops, as we did earlier, and it will take care of the threads creation. Coordinator is still use to coordinate those threads.\n",
"\n",
"Let's code again the example above, but using `QueueRunner` this time."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Consumed 1.0\n",
"Consumed 2.0\n",
"Consumed 3.0\n",
"Consumed 4.0\n",
"Consumed 5.0\n",
"Consumed 6.0\n",
"Consumed 7.0\n",
"Consumed 8.0\n",
"Consumed 9.0\n",
"Consumed 10.0\n",
"Done\n"
]
}
],
"source": [
"tf.reset_default_graph() # So we have a clean graph every cell\n",
"\n",
"with tf.Session() as sess:\n",
" q = tf.FIFOQueue(5, 'float')\n",
"\n",
" # Instead of using a placeholder, we keep a variable that is\n",
" # incremented at every use\n",
" count = tf.Variable(0.0)\n",
" # This op increments the variable when executed\n",
" inc = tf.assign_add(count, 1)\n",
"\n",
" # Operation that enqueues some data\n",
" enq = q.enqueue(tf.add(inc, 0)) # Adding zero forces the creation of a new variable \n",
" deq = q.dequeue()\n",
"\n",
" coord = tf.train.Coordinator()\n",
"\n",
" tf.group(\n",
" tf.global_variables_initializer(),\n",
" tf.local_variables_initializer()\n",
" ).run()\n",
" \n",
" # This will create one threads that enqueue data\n",
" qr = tf.train.QueueRunner(q, [enq])\n",
"\n",
" # This will start the thread and return its handler\n",
" enq_threads = qr.create_threads(sess, coord=coord, start=True)\n",
" \n",
" # Run the consumer thread\n",
" while not coord.should_stop():\n",
" i = sess.run(deq)\n",
" print('Consumed', i)\n",
" sleep(0.5) # Wait some time\n",
" if i >= 10:\n",
" coord.request_stop()\n",
" \n",
" coord.join(enq_threads)\n",
" print('Done')"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"An interesting thing to note here is how zero is added to the incremented value: `q.enqueue(tf.add(inc, 0))`. This is because `enqueue` will enqueue a reference to a variable instead of the variable value. Adding 0 forces the creation of a new tensor, which is in turn enqueued.\n",
"\n",
"For more info, check these [interesting](http://stackoverflow.com/questions/40668712/enqueue-and-increment-variable-in-tensor-flow) [questions](http://stackoverflow.com/questions/41308515/force-copy-of-tensor-when-enqueuing) on stackoverflow.\n",
"\n",
"Well, the idea behind `QueueRunner` is clear: it will create a number of threads, each of wich will run the enqueueing operation in the specified session. If a coordinator is provided, it will be used.\n",
"\n",
"Good, everything is set, we now can get to the point.\n",
"\n",
"## How Queues are used in TensorFlow\n",
"\n",
"QueueRunner is used by some TF methods that use queues internally. You have to be careful about how threads are managed and how you enqueue data in those cases. It is not *hard to do*, but you shall know what to do. From the doc, it's not always very clear.\n",
"\n",
"Take, for example, the handy `tf.train.slice_input_producer`. The idea is that, given a tensor, this operation will slice it and return one element at time. In the [doc](https://www.tensorflow.org/versions/master/api_docs/python/tf/train/slice_input_producer), you can read\n",
"\n",
"> Produces a slice of each Tensor in tensor_list.\n",
"> Implemented using a Queue -- a QueueRunner for the Queue is added to the current Graph's QUEUE_RUNNER collection.\n",
"\n",
"My first question was: what the heck is a QUEUE_RUNNER collection? Well, put simply, it's just a way to group together graph nodes that share some proprierties. In some cases, you may want to handle all these nodes as a group. For example, when you create a variable, there is the `trainable=` kwarg. When set to True (default), the variable will be included into the `TRAINABLE_VARIABLES` collection. Having this distintion may be useful, for example, when you want to save some variables of your model: you might want to save only trainable variables (e.g. network weights), ignoring other variables which are relevant only during the execution.\n",
"\n",
"Same thing is for the `tf.GraphKeys.QUEUE_RUNNERS` collection: it is intended to collect all the queue runners, so that you can start them when required.\n",
"\n",
"In this way, you can build a program using queues in a modular way: you create a queue and some operations with it, then you create a `QueueRunner` that enqueue something into it and you just need something that consumes values from that queue.\n",
"\n",
"### The case of slice_input_producer\n",
"\n",
"We will now see an example of this pattern using `tf.train.slice_input_producer()`.\n",
"\n",
"The use-case is the following: we have a small dataset that fits into a GPU, and we want to get high performance training. Right now, our code uses `feed_dict`, but [we know](https://www.tensorflow.org/performance/performance_guide#utilize_queues_for_reading_data) that using queues could lead to better performances.\n",
"\n",
"We read [somewhere](https://www.tensorflow.org/programmers_guide/reading_data#preloaded_data) that one strategy for implementing data reading is to put all your data into the graph, using a constant. This is fine in our case, because the dataset is rather small and out GPU has plenty of memory. That page states that\n",
"\n",
"> This is only used for small data sets that can be loaded entirely in memory. There are two approaches:\n",
"> - Store the data in a constant.\n",
"> - Store the data in a variable, that you initialize and then never change.\n",
">\n",
"> [...]\n",
">\n",
"> Either way, tf.train.slice_input_producer can be used to produce a slice at a time\n",
"\n",
"But how do we use the `tf.train.slice_input_producer`? The following code is not working:\n",
"\n",
"```\n",
"# Create some data\n",
"dataset = np.arange(15).reshape((5, 3)).astype(np.float32)\n",
"# Make it constant in the graph\n",
"const_data = tf.constant(dataset)\n",
"with tf.Session() as sess:\n",
" # Input producer gets a list and returns a list\n",
" x, = tf.train.slice_input_producer([const_data], shuffle=False)\n",
" print(x.eval()) # Do we get the first row of the data? Nein!\n",
"```\n",
"\n",
"Why is it not working? Because queues and QueueRunners are used underneath, and we need to start the threads."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 1. 2.]\n",
"[ 3. 4. 5.]\n",
"[ 6. 7. 8.]\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"\n",
"dataset = np.arange(15).reshape([5, 3]).astype(np.float32)\n",
"const_data = tf.constant(dataset)\n",
"\n",
"row, = tf.train.slice_input_producer([const_data], num_epochs=1, shuffle=False)\n",
"\n",
"with tf.Session() as sess:\n",
" tf.group(\n",
" tf.global_variables_initializer(),\n",
" tf.local_variables_initializer()\n",
" ).run()\n",
"\n",
" coord = tf.train.Coordinator()\n",
" threads = tf.train.start_queue_runners(sess=sess, coord=coord)\n",
"\n",
" for i in range(3): # Try a value >5\n",
" print(sess.run(row))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, now it works as expected! What about the `num_epochs` parameter? It is set to 1, meaning that we can loop over the data just once. If you try setting the number of iterations to 10 (or any value greater than the number of rows), the program will stop with an `OutOfRangeError`: `num_epochs` specifies how many iterations on the data we can do before getting that exception.\n",
"\n",
"If `num_epoch` is not specified, it will go on forever... But this means that we have to stop the threads, using the coordinator, or we will get an error (in particular, the session will be closed and an operation over a closed session will be attempted, causing an error)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 1. 2.]\n",
"[ 3. 4. 5.]\n",
"[ 6. 7. 8.]\n",
"[ 9. 10. 11.]\n",
"[ 12. 13. 14.]\n",
"[ 0. 1. 2.]\n",
"[ 3. 4. 5.]\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"\n",
"dataset = np.arange(15).reshape([5, 3]).astype(np.float32)\n",
"const_data = tf.constant(dataset)\n",
"\n",
"row, = tf.train.slice_input_producer([const_data], shuffle=False)\n",
"\n",
"with tf.Session() as sess:\n",
" tf.group(\n",
" tf.global_variables_initializer(),\n",
" tf.local_variables_initializer()\n",
" ).run()\n",
"\n",
" coord = tf.train.Coordinator()\n",
" threads = tf.train.start_queue_runners(sess=sess, coord=coord)\n",
"\n",
" for i in range(7):\n",
" print(sess.run(row))\n",
" \n",
" # We are done reading\n",
" coord.request_stop()\n",
" coord.join(threads) # Wait for the threads to actually stop"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are able to start the queue runner inside `slice_input_producer()` because it was added to QUEUE_RUNNER collection, and `start_queue_runners` acted upon them.\n",
"\n",
"Also, note how we get, named `row`, a dequeue operation that we can use to retrieve the data we want.\n",
"\n",
"Just to show another case in which QueueRunner is used, we expand this example by including a `tf.train.batch` node to retrieve multiple values at once. Again, you can read in the [doc](https://www.tensorflow.org/versions/master/api_docs/python/tf/train/batch) that\n",
"\n",
"> This function is implemented using a queue. A QueueRunner for the queue is added to the current Graph's QUEUE_RUNNER collection.\n",
"\n",
"Luckly, we already know how to setup things, and we know that `start_queue_runners()` will act on every queue runner in the collection, including the one present in the batch operation"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 1. 2.]\n",
" [ 3. 4. 5.]\n",
" [ 6. 7. 8.]]\n",
"[[ 9. 10. 11.]\n",
" [ 12. 13. 14.]\n",
" [ 0. 1. 2.]]\n",
"[[ 3. 4. 5.]\n",
" [ 6. 7. 8.]\n",
" [ 9. 10. 11.]]\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"\n",
"dataset = np.arange(15).reshape([5, 3]).astype(np.float32)\n",
"const_data = tf.constant(dataset)\n",
"\n",
"row, = tf.train.slice_input_producer([const_data], shuffle=False)\n",
"\n",
"# If the input is just one element, it returns a tensor and not a list\n",
"batch = tf.train.batch([row], batch_size=3, shapes=[3])\n",
"\n",
"with tf.Session() as sess:\n",
" tf.group(\n",
" tf.global_variables_initializer(),\n",
" tf.local_variables_initializer()\n",
" ).run()\n",
"\n",
" coord = tf.train.Coordinator()\n",
" threads = tf.train.start_queue_runners(sess=sess, coord=coord)\n",
"\n",
" for i in range(3):\n",
" print(sess.run(batch))\n",
" \n",
" # We are done reading\n",
" coord.request_stop()\n",
" coord.join(threads) # Wait for the threads to actually stop"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "TensorFlow",
"language": "python",
"name": "tensorflow"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}