Created
March 27, 2019 09:51
-
-
Save IlievskiV/1427743b9f4dce4a32d0d9326dacc565 to your computer and use it in GitHub Desktop.
Loading data in TensorFlow using the low-level queuing mechanism
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading

import tensorflow as tf
class Feeder:
    """Push batches of examples into a TensorFlow FIFO queue from a background thread.

    The constructor builds the graph pieces (placeholder, queue, enqueue and
    dequeue ops) and the session/coordinator used to drive them.  Consumers read
    from ``self.sentence``; ``start_threads`` launches the producer loop.

    NOTE(review): ``enqueue`` calls ``self.get_next_example()``, which is not
    defined in this class as shown -- it is assumed to be supplied by a subclass
    or elsewhere in the file; confirm before use.
    """

    def __init__(self):
        """Create the session, the coordinator and the queue ops."""
        self.sess = tf.Session()  # session to run operations
        self.coordinator = tf.train.Coordinator()  # coordinator for threads
        # Rank-2 int32 input; both dimensions are left unspecified.
        self.placeholder = tf.placeholder(tf.int32, shape=(None, None), name='sentence')
        self.queue = tf.FIFOQueue(5, tf.int32, name='input_queue')  # hold 5 elements
        # The ops must be built from the attributes created above; the bare names
        # ``queue`` / ``placeholder`` do not exist in this scope.
        self.enqueue_op = self.queue.enqueue(self.placeholder)  # push op
        self.sentence = self.queue.dequeue()  # pop op

    def enqueue(self, batch_size=5, num_batches=5):
        """Produce batches and push them into the queue until asked to stop.

        Each pass of the outer loop fetches ``batch_size * num_batches`` examples,
        slices them into consecutive batches of at most ``batch_size`` examples,
        and runs the enqueue op once per batch.  The loop repeats until
        ``self.coordinator.should_stop()`` returns true.

        Args:
            batch_size: Number of examples per batch.
            num_batches: Number of batches prepared per pass.
        """
        while not self.coordinator.should_stop():
            # 1 and 2: read data from disk and preprocess
            examples = [self.get_next_example() for _ in range(batch_size * num_batches)]
            # 3: make batches -- each batch is a plain list of examples, so that it
            # matches the rank-2 placeholder (no extra wrapping list).
            batches = [
                examples[start:start + batch_size]
                for start in range(0, len(examples), batch_size)
            ]
            # 4: insert the batches in the queue, feeding the placeholder directly.
            for batch in batches:
                self.sess.run(self.enqueue_op, feed_dict={self.placeholder: batch})
            # Repeat the same process over and over

    def start_threads(self):
        """Start the producer loop (``self.enqueue``) in a daemon thread."""
        thread = threading.Thread(name='background', target=self.enqueue)
        # Daemon thread: it will not keep the interpreter alive on exit.
        thread.daemon = True
        thread.start()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment