Last active
March 1, 2017 16:56
-
-
Save post2web/e6c897a6b0456335a34f83685decea31 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Gist 1: feed a FIFOQueue from Python via tf.py_func + tf.train.QueueRunner.
# time (sleep/time) and numpy (np.random) are used by the code below.
import time

import numpy as np
import tensorflow as tf

batch_size = 2              # examples per dequeue_many() batch
get_single_xy_timeout = 10  # seconds each simulated example load sleeps
train_timeout = 3           # seconds the dummy training loop sleeps per step
n_epochs = 5                # number of train_op runs (one batch each) in the main loop
n_threads = 5               # parallel enqueue threads started by the QueueRunner
class DataFetcher:
    """Dummy stand-in for loading and preprocessing one (X, Y) example.

    The object is stateful: ``counter`` records how many examples this
    instance has produced.
    """

    def __init__(self):
        # Number of get_single_xy() calls made so far on this instance.
        self.counter = 0

    def get_single_xy(self):
        """Produce one example after a simulated delay.

        Bumps and prints ``counter``, sleeps ``get_single_xy_timeout``
        seconds, then returns ``[x, y]``: a 2x2 float32 array and a float32
        scalar. The dtypes have to match the queue that receives them.
        """
        self.counter = self.counter + 1
        print(self.counter)
        time.sleep(get_single_xy_timeout)
        features = np.random.rand(2, 2).astype(np.float32)
        label = np.random.rand(1).astype(np.float32)[0]
        return [features, label]
# Shared example source; tf.py_func below calls its get_single_xy method.
data = DataFetcher()

# Create the queue: up to 15 (x, y) pairs, x of shape [2, 2] and y a scalar.
queue = tf.FIFOQueue(
    capacity=15,
    dtypes=[tf.float32, tf.float32],
    shapes=[[2, 2], []],
)

# Wrap the Python loader as a graph op producing one (x, y) pair per run.
python_data_op = tf.py_func(data.get_single_xy, inp=[], Tout=[tf.float32, tf.float32])
# Enqueues (adds) one element to this queue.
enqueue_op = queue.enqueue(python_data_op)
# Dequeues (removes) one element from this queue (not used by the loop below).
dequeue_op = queue.dequeue()
# Dequeues and concatenates `n` elements from this queue.
X, Y = queue.dequeue_many(n=batch_size)

# Dummy train operation.
train_op = tf.reduce_mean(tf.reduce_mean(X) * Y)

# Queue runner that runs `n_threads` copies of enqueue_op in parallel threads.
qr = tf.train.QueueRunner(queue, [enqueue_op] * n_threads)

init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

# Create a coordinator and launch the queue runner threads.
coord = tf.train.Coordinator()
threads = qr.create_threads(sess, coord=coord, start=True)
try:
    # Each iteration runs train_op once, i.e. consumes one dequeued batch.
    for step in range(n_epochs):
        if coord.should_stop():
            break
        start_time = time.time()
        result = sess.run(train_op)
        time.sleep(train_timeout)
        print('Result', result, 'Time:', time.time() - start_time)
except Exception as e:
    # Report exceptions to the coordinator.
    coord.request_stop(e)
finally:
    # Terminate as usual. It is innocuous to request stop twice.
    coord.request_stop()
    try:
        # coord.join() re-raises any exception reported via request_stop(),
        # so the session must be closed in a nested finally to avoid leaking it.
        coord.join(threads)
    finally:
        sess.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/tensorflow/tensorflow/issues/2514#issuecomment-221934925
# Gist 2: switch between a train and a test queue with QueueBase.from_list.
# time (sleep/time) and numpy (np.random) are used by the code below.
import time
from threading import Thread

import numpy as np
import tensorflow as tf

batch_size = 2              # examples per dequeue_many() batch
get_single_xy_timeout = 10  # seconds each simulated example load sleeps
train_timeout = 3           # seconds the dummy loop sleeps per step
n_steps = 5                 # train_op runs per phase (train, then test)
n_threads = 3               # producer threads started per queue
q_capacity = 3              # capacity of each FIFOQueue
class Dataset():
    """Dummy example source simulating loading and preprocessing of (X, Y).

    Also carries the dtypes/shapes used to build the input queues and
    their placeholders, so data and graph agree on one definition.
    """

    def __init__(self):
        # Number of examples produced so far (the loader can keep state).
        self.counter = 0
        self.x_dtype = tf.float32
        self.x_shape = [2, 2]
        self.y_dtype = tf.float32
        self.y_shape = []
        self.dtypes = [self.x_dtype, self.y_dtype]
        self.shapes = [self.x_shape, self.y_shape]

    def get_single_xy(self, train=True):
        """Return one ``[x, y]`` example after sleeping ``get_single_xy_timeout`` s.

        Args:
            train: when True, ``y`` is random in [0, 1); otherwise ``y`` is the
                constant 100000., which makes test batches easy to recognize.

        Returns:
            ``[x, y]`` where ``x`` has shape ``self.x_shape`` and ``y`` is a
            scalar, both float32 so they match the declared ``tf.float32``
            dtypes instead of numpy's default float64.
        """
        self.counter += 1
        time.sleep(get_single_xy_timeout)
        # Build x from the declared shape rather than a duplicated literal so
        # the data cannot drift from the queue/placeholder shapes. np.asarray
        # also covers a scalar shape ([]), where rand() returns a plain float.
        x = np.asarray(np.random.rand(*self.x_shape), dtype=np.float32)
        if train:
            y = np.float32(np.random.rand())
        else:
            y = np.float32(100000.)
        return [x, y]
dataset = Dataset()


def _make_input_queue():
    """Build one FIFOQueue plus the two placeholders and the op that enqueues them."""
    fifo = tf.FIFOQueue(
        capacity=q_capacity,
        dtypes=dataset.dtypes,
        shapes=dataset.shapes,
    )
    x_in = tf.placeholder(dataset.x_dtype, dataset.x_shape)
    y_in = tf.placeholder(dataset.y_dtype, dataset.y_shape)
    return fifo, x_in, y_in, fifo.enqueue([x_in, y_in])


# Independent queues for training and test examples.
q_train, X_train, Y_train, enqueue_train = _make_input_queue()
q_test, X_test, Y_test, enqueue_test = _make_input_queue()

# Feeding q_selector=0 reads from q_train, 1 reads from q_test.
q_selector = tf.placeholder(tf.int32, [])
q = tf.QueueBase.from_list(q_selector, [q_train, q_test])
dequeue_op = q.dequeue()
X, Y = q.dequeue_many(n=batch_size)

# Dummy train operation.
train_op = tf.reduce_mean(tf.reduce_mean(X) * Y)

init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)
# Coordinator used to tell the producer threads to stop.
coordinator = tf.train.Coordinator()

# Ops that close each queue and cancel enqueues blocked on a full queue.
# Built before any thread starts, so the graph is not modified while
# producers are inside sess.run().
close_queue_ops = [
    q_train.close(cancel_pending_enqueues=True),
    q_test.close(cancel_pending_enqueues=True),
]


def train_pusher():
    """Producer loop: keep feeding training examples into q_train."""
    with coordinator.stop_on_exception():
        while not coordinator.should_stop():
            x, y = dataset.get_single_xy(train=True)
            sess.run(enqueue_train, {X_train: x, Y_train: y})


def test_pusher():
    """Producer loop: keep feeding test examples into q_test."""
    with coordinator.stop_on_exception():
        while not coordinator.should_stop():
            x, y = dataset.get_single_xy(train=False)
            sess.run(enqueue_test, {X_test: x, Y_test: y})


def _run_steps(selector):
    """Run up to n_steps train_op steps on the queue picked by `selector`."""
    for step in range(n_steps):
        # If something went wrong in a producer thread, stop early.
        if coordinator.should_stop():
            break
        start_time = time.time()
        result = sess.run(train_op, {q_selector: selector})
        time.sleep(train_timeout)
        print('Result', result, 'Time:', time.time() - start_time)


threads = [Thread(target=train_pusher) for i in range(n_threads)]
threads += [Thread(target=test_pusher) for i in range(n_threads)]
for t in threads:
    t.start()

try:
    _run_steps(0)  # batches from q_train
    _run_steps(1)  # batches from q_test
except Exception as e:
    # Report exceptions to the coordinator.
    coordinator.request_stop(e)
finally:
    # Terminate as usual. It is innocuous to request stop twice.
    coordinator.request_stop()
    try:
        # A producer blocked in an enqueue on a full queue never re-checks
        # should_stop(). Closing the queues with cancel_pending_enqueues=True
        # makes those enqueues fail so every (non-daemon) thread can exit.
        # The threads are then actually waited for; a zero-timeout join would
        # not wait and would close the session under running threads.
        for close_op in close_queue_ops:
            sess.run(close_op)
        coordinator.join(threads)
    finally:
        sess.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Gist 3: feed a FIFOQueue from plain Python threads through placeholders.
# All of the modules below are used by this snippet.
import random
import threading
import time

import numpy as np
import tensorflow as tf

batch_size = 2              # examples per dequeue_many() batch
get_single_xy_timeout = 10  # seconds each simulated example load sleeps
train_timeout = 3           # seconds the dummy loop sleeps per step
n_epochs = 20               # number of train_op runs (one batch each)
def get_single_xy():
    """Simulate loading and preprocessing a single example.

    Sleeps ``get_single_xy_timeout`` seconds, then returns the tuple
    ``(x, y)``: a random 2x2 numpy array and a random float in [0, 1).
    """
    time.sleep(get_single_xy_timeout)
    return np.random.rand(2, 2), random.random()
# Placeholders the feeder threads fill with one (x, y) pair per enqueue.
X = tf.placeholder(tf.float32, [2, 2])
Y = tf.placeholder(tf.float32)

# The queue: up to 15 pairs, x of shape [2, 2] and y a scalar.
queue = tf.FIFOQueue(15, [tf.float32, tf.float32], shapes=[[2, 2], []])

enqueue_op = queue.enqueue([X, Y])        # add one (x, y) element
dequeue_op = queue.dequeue()              # remove one element
Xs, Ys = queue.dequeue_many(batch_size)   # remove and stack batch_size elements

# Dummy train operation; the result has one entry per example in the batch.
train_op = tf.reduce_mean(Xs) * Ys

init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)
# A coordinator for the feeder threads.
coord = tf.train.Coordinator()

# Op that closes the queue and cancels enqueues blocked on a full queue.
# Built before any thread starts running sess.run().
close_queue_op = queue.close(cancel_pending_enqueues=True)


def enqueue_thread():
    """Feeder loop: load one (x, y) pair and enqueue it, until told to stop."""
    # Context manager to request stop when an Exception is raised.
    with coord.stop_on_exception():
        while not coord.should_stop():
            x, y = get_single_xy()
            sess.run(enqueue_op, feed_dict={X: x, Y: y})


available_threads = 5
# Keep handles to the threads so they can be joined on shutdown.
threads = [threading.Thread(target=enqueue_thread) for _ in range(available_threads)]
for t in threads:
    t.start()

try:
    for epoch in range(n_epochs):
        # Stop early if a feeder thread reported an error.
        if coord.should_stop():
            break
        start_time = time.time()
        sess.run(train_op)
        time.sleep(train_timeout)
        print('Time:', time.time() - start_time)
except Exception as e:
    # Report exceptions to the coordinator.
    coord.request_stop(e)
finally:
    # The feeder threads loop until should_stop() is true and are not daemon
    # threads. Without an explicit stop, queue close and join, the process
    # never exits once the queue fills up.
    coord.request_stop()
    try:
        sess.run(close_queue_op)
        coord.join(threads)
    finally:
        sess.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment