@psycharo-zz
Created May 8, 2017 09:01
custom multi-threading runner for tensorflow
import threading
import numpy as np
import tensorflow as tf
class FeedingRunner(object):
    """Takes care of feeding/dequeueing data into the queue.

    Based on tf.train.QueueRunner.
    """
    def __init__(self, generator, dtypes, shapes, names, num_threads,
                 queue_capacity):
        """
        Args:
          generator: generator that yields tuples of input values,
            one value per placeholder
          dtypes: a list of dtypes of the inputs
          shapes: a list of shapes of the inputs
          names: a list of names of the inputs
          num_threads: how many feeding threads to run
          queue_capacity: number of samples to keep in the examples queue
        """
        assert len(dtypes) == len(shapes) == len(names)
        self._generator = generator
        self._num_threads = num_threads
        self._queue = tf.FIFOQueue(queue_capacity, dtypes)
        self._dtypes = dtypes
        self._shapes = shapes
        self._names = names
        self._placeholders = [tf.placeholder(dtype, shape)
                              for dtype, shape in zip(dtypes, shapes)]
        self._enqueue_op = self._queue.enqueue(self._placeholders)
        self._dequeue_op = self._queue.dequeue()
        # dequeue returns a list when there are multiple tensors and a
        # single tensor otherwise; normalize to a list
        if not isinstance(self._dequeue_op, list):
            self._dequeue_op = [self._dequeue_op]
        self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
        # dequeue loses static shape information, so restore it and give
        # each output tensor its requested name
        self._inputs = []
        for i, value in enumerate(self._dequeue_op):
            value.set_shape(self._shapes[i])
            self._inputs.append(tf.identity(value, self._names[i]))

    def _run(self, sess, coord):
        """Runs the cycle that feeds data into the queue."""
        try:
            for values in self._generator:
                if coord and coord.should_stop():
                    break
                feed_dict = {key: value
                             for key, value in zip(self._placeholders, values)}
                sess.run(self._enqueue_op, feed_dict)
        except Exception as e:
            if coord:
                coord.request_stop(e)
            else:
                raise

    def _close_on_stop(self, sess, cancel_op, coord):
        """Closes the queue when the Coordinator requests a stop.

        Args:
          sess: A Session.
          cancel_op: The Operation to run.
          coord: A Coordinator.
        """
        coord.wait_for_stop()
        try:
            sess.run(cancel_op)
        except Exception as e:
            tf.logging.vlog(1, 'Ignored exception: %s', str(e))

    def create_threads(self, sess, coord=None, daemon=False, start=False):
        """Creates the feeding threads and, if a coordinator is given,
        a thread that closes the queue on stop."""
        threads = [threading.Thread(target=self._run, args=(sess, coord))
                   for _ in range(self._num_threads)]
        if coord:
            threads.append(threading.Thread(target=self._close_on_stop,
                                            args=(sess, self._cancel_op, coord)))
        for t in threads:
            if coord:
                coord.register_thread(t)
            if daemon:
                t.daemon = True
            if start:
                t.start()
        return threads

    @property
    def queue(self):
        return self._queue

    @property
    def inputs(self):
        return self._inputs
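
# Usage sketch (not part of the original gist): one way FeedingRunner could be
# wired to a Session and a tf.train.Coordinator. The generator, shapes, and
# names below are made-up placeholders; a plain Python generator is not
# thread-safe, hence num_threads=1 here.
def _demo_feeding_runner():
    def batch_generator():
        while True:
            yield (np.random.rand(32, 4).astype(np.float32),
                   np.random.randint(0, 2, size=32).astype(np.int64))

    runner = FeedingRunner(batch_generator(),
                           dtypes=[tf.float32, tf.int64],
                           shapes=[(32, 4), (32,)],
                           names=['features', 'labels'],
                           num_threads=1,
                           queue_capacity=16)
    features, labels = runner.inputs
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = runner.create_threads(sess, coord=coord, daemon=True,
                                        start=True)
        for _ in range(10):
            print(sess.run([features, labels]))
        coord.request_stop()
        coord.join(threads)
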
class RandomDataIterator(object):
    """Iterator that uniformly samples random batches from the dataset."""
    def __init__(self, filenames, readers, batch_size, replace=False):
        """
        Args:
          filenames: list of filename sequences, one per reader, all of
            the same length (parallel across readers)
          readers: list of (thread-safe) functions taking a filename and
            returning data
          batch_size: int > 0
          replace: bool, whether to sample with replacement
        """
        assert len(filenames) >= 1
        assert len(filenames) == len(readers)
        self.filenames = filenames
        self.readers = readers
        self.batch_size = batch_size
        self.replace = replace

    def __iter__(self):
        return self

    def __next__(self):
        idxs = np.random.choice(len(self.filenames[0]), self.batch_size,
                                replace=self.replace)
        batch = []
        for fid, reader in enumerate(self.readers):
            batch.append([reader(self.filenames[fid][idx])
                          for idx in idxs])
        return batch
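
# End-to-end sketch (not part of the original gist): feeding batches sampled
# by RandomDataIterator through a FeedingRunner. The file lists and the
# `load_image` / `load_label` readers are hypothetical stand-ins that ignore
# the actual file contents.
if __name__ == '__main__':
    def load_image(path):
        # stand-in reader: real code would decode the file at `path`
        return np.zeros((64, 64, 3), np.float32)

    def load_label(path):
        return np.int64(0)

    image_files = ['img_%d.png' % i for i in range(100)]
    label_files = ['lbl_%d.txt' % i for i in range(100)]

    iterator = RandomDataIterator(filenames=[image_files, label_files],
                                  readers=[load_image, load_label],
                                  batch_size=8)
    runner = FeedingRunner(iterator,
                           dtypes=[tf.float32, tf.int64],
                           shapes=[(8, 64, 64, 3), (8,)],
                           names=['images', 'labels'],
                           num_threads=2,
                           queue_capacity=4)
    images, labels = runner.inputs

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = runner.create_threads(sess, coord=coord, daemon=True,
                                        start=True)
        for _ in range(5):
            batch_images, batch_labels = sess.run([images, labels])
            print(batch_images.shape, batch_labels.shape)
        coord.request_stop()
        coord.join(threads)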