Created March 21, 2017 20:55
-
-
Save Dref360/2c8280fc497df690ebb35646f7021a62 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
import keras | |
import keras.backend as K | |
from keras.models import Model | |
import keras.callbacks as cbks | |
import threading | |
from keras.layers import Layer, Input, InputLayer, Dense | |
class FIFOModel(Model):
    """Keras ``Model`` that trains on batches flowing through a TensorFlow queue.

    The model is meant to be built on tensors dequeued from a queue (passed via
    ``Input(tensor=...)``), so its train function feeds no input or target
    arrays: it only feeds sample weights (all ones) and, when needed, the
    learning-phase flag. Background threads keep the queue filled by running an
    enqueue op on batches drawn from a Python generator.
    """

    def fit_queuer(self, enqueue_op, enqueue_placeholder, generator,
                   steps_per_epoch,
                   epochs=1,
                   verbose=1,
                   callbacks=None,
                   validation_data=None,
                   validation_steps=None,
                   class_weight=None,
                   max_q_size=10,
                   workers=1,
                   pickle_safe=False,
                   initial_epoch=0):
        """Train the model while feeder threads push generator batches into a queue.

        Args:
            enqueue_op: TensorFlow op that enqueues the values fed to
                ``enqueue_placeholder``.
            enqueue_placeholder: pair ``(x_placeholder, y_placeholder)``; each
                ``(x, y)`` produced by ``generator`` is fed into them.
            generator: iterator yielding ``(x, y)`` batches. It is shared by all
                feeder threads; access to it is serialised with a lock.
            steps_per_epoch: number of batches trained per epoch.
            epochs: training runs epochs ``initial_epoch .. epochs - 1``.
            verbose: if truthy, a step-based progress bar callback is added.
            callbacks: optional list of extra Keras callbacks.
            validation_data: ``(x, y)`` / ``(x, y, sample_weight)`` tuple or a
                generator. It is only checked and attached to the callbacks;
                this method does not evaluate on it.
            validation_steps: required when ``validation_data`` is a generator.
            class_weight, max_q_size, pickle_safe: accepted for
                ``fit_generator``-style compatibility; unused here.
            workers: number of feeder threads.
            initial_epoch: epoch index to start from.

        Returns:
            The ``History`` callback (``self.history``).

        Raises:
            ValueError: for malformed ``validation_data`` / missing
                ``validation_steps``.
            Any exception raised in a feeder thread is re-raised here by
            ``coord.join``.
        """
        do_validation = bool(validation_data)
        val_gen = False
        self._make_train_function()
        if do_validation:
            self._make_test_function()
            # python 2 has 'next', 3 has '__next__'
            # avoid any explicit version checks
            val_gen = (hasattr(validation_data, 'next') or
                       hasattr(validation_data, '__next__'))
            if val_gen and not validation_steps:
                raise ValueError('When using a generator for validation data, '
                                 'you must specify a value for '
                                 '`validation_steps`.')

        out_labels = self.metrics_names
        callback_metrics = out_labels + ['val_' + n for n in out_labels]

        # prepare callbacks
        self.history = cbks.History()
        callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
        if verbose:
            callbacks += [cbks.ProgbarLogger(count_mode='steps')]
        callbacks = cbks.CallbackList(callbacks)

        # it's possible to callback a different model than self:
        if hasattr(self, 'callback_model') and self.callback_model:
            callback_model = self.callback_model
        else:
            callback_model = self
        callbacks.set_model(callback_model)
        callbacks.set_params({
            'epochs': epochs,
            'steps': steps_per_epoch,
            'verbose': verbose,
            'do_validation': do_validation,
            'metrics': callback_metrics,
        })
        callbacks.on_train_begin()

        if do_validation and not val_gen:
            if len(validation_data) == 2:
                val_x, val_y = validation_data
                val_sample_weight = None
            elif len(validation_data) == 3:
                val_x, val_y, val_sample_weight = validation_data
            else:
                raise ValueError('validation_data should be a tuple '
                                 '`(val_x, val_y, val_sample_weight)` '
                                 'or `(val_x, val_y)`. Found: ' +
                                 str(validation_data))
            val_x, val_y, val_sample_weights = self._standardize_user_data(
                val_x, val_y, val_sample_weight)
            for cbk in callbacks:
                cbk.validation_data = val_x + [val_y, val_sample_weights]

        # BaseLogger averages metrics weighted by logs['size'], so it must be > 0.
        batch_size = self._queue_batch_size()
        # Fetch the session once in the main thread instead of once per enqueue
        # inside every feeder thread.
        sess = K.get_session()

        # Feeders enqueue exactly as many batches as the training loop will
        # dequeue, so on normal completion no feeder is left blocked on a full
        # queue and ``coord.join`` returns promptly.
        remaining = [max(0, epochs - initial_epoch) * steps_per_epoch]
        gen_lock = threading.Lock()

        def feeding_func(coord, gen):
            # Exceptions (including generator exhaustion) are recorded by the
            # coordinator and re-raised by ``coord.join`` in the main thread.
            with coord.stop_on_exception():
                while not coord.should_stop():
                    # Python generators are not thread-safe: claim a slot and
                    # pull the next batch while holding the lock.
                    with gen_lock:
                        if remaining[0] <= 0:
                            return
                        remaining[0] -= 1
                        x, y = next(gen)
                    sess.run(enqueue_op,
                             feed_dict={enqueue_placeholder[0]: x,
                                        enqueue_placeholder[1]: y})

        coord = tf.train.Coordinator()
        threads = [threading.Thread(target=feeding_func, args=(coord, generator))
                   for _ in range(workers)]
        for thread in threads:
            # Daemon threads cannot keep the process alive if training aborts
            # while a feeder is blocked inside ``enqueue``.
            thread.daemon = True
            thread.start()

        try:
            for epoch in range(initial_epoch, epochs):
                callbacks.on_epoch_begin(epoch)
                for step in range(steps_per_epoch):
                    # should_stop() is only set here by a failing feeder.
                    # NOTE(review): if the feeder fails while this thread is
                    # already blocked in dequeue, this check cannot help.
                    if coord.should_stop():
                        break
                    batch_logs = {'batch': step, 'size': batch_size}
                    callbacks.on_batch_begin(step, batch_logs)
                    outs = self.train_on_batch()
                    if not isinstance(outs, list):
                        outs = [outs]
                    for label, value in zip(out_labels, outs):
                        batch_logs[label] = value
                    callbacks.on_batch_end(step, batch_logs)
                if coord.should_stop():
                    # Skip epoch-end hooks; coord.join below re-raises the error.
                    break
                callbacks.on_epoch_end(epoch, {})
        finally:
            coord.request_stop()
        coord.join(threads)
        callbacks.on_train_end()
        return self.history

    def train_on_batch(self, **kwargs):
        """Run one optimisation step on the next batch taken from the queue.

        Inputs and targets come from the dequeue op inside the graph. Only one
        all-ones sample-weight vector per output is fed, plus the learning-phase
        flag (``1.``) when the model needs it. ``kwargs`` is accepted for
        signature compatibility and ignored.

        Returns:
            The training loss alone, or ``[loss] + metric values`` when the
            train function has more than one output.
        """
        self._make_train_function()
        batch_size = self._queue_batch_size()
        ins = [np.ones([batch_size]) for _ in self._feed_sample_weights]
        if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
            ins.append(1.)
        outputs = self.train_function(ins)
        if len(outputs) == 1:
            return outputs[0]
        return outputs

    def _queue_batch_size(self):
        """Return the static batch dimension of the first model input.

        Raises:
            ValueError: if that dimension is not statically known.
        """
        batch_size = K.int_shape(self.inputs[0])[0]
        if batch_size is None:
            raise ValueError('FIFOModel needs a static batch size on its input '
                             '(e.g. `Input(batch_shape=...)`).')
        return batch_size

    def _optimizer_takes_constraints(self):
        """Return True if ``self.optimizer.get_updates`` has a ``constraints`` argument."""
        import inspect
        get_updates = self.optimizer.get_updates
        try:
            arg_names = inspect.signature(get_updates).parameters
        except AttributeError:  # Python 2 has no inspect.signature
            arg_names = inspect.getargspec(get_updates).args
        return 'constraints' in arg_names

    def _make_train_function(self):
        """Lazily build ``self.train_function``.

        Unlike the stock Keras version, the function's inputs are only the
        sample-weight placeholders (plus the learning phase when it is a
        tensor). No input or target placeholders are used, because data comes
        from the queue.

        Raises:
            RuntimeError: if the model has not been compiled.
        """
        if not hasattr(self, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')
        if self.train_function is None:
            # Copy: ``+=`` below would otherwise append the learning phase to
            # the model's own ``_feed_sample_weights`` list.
            inputs = list(self._feed_sample_weights)
            if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                inputs += [K.learning_phase()]
            params = self._collected_trainable_weights
            if self._optimizer_takes_constraints():
                # Older Keras: get_updates(params, constraints, loss).
                training_updates = self.optimizer.get_updates(
                    params,
                    self.constraints,
                    self.total_loss)
            else:
                # Keras >= 2.0.7 dropped the `constraints` argument.
                training_updates = self.optimizer.get_updates(
                    loss=self.total_loss,
                    params=params)
            updates = self.updates + training_updates
            # Gets loss and metrics. Updates weights at each call.
            self.train_function = K.function(inputs,
                                             [self.total_loss] + self.metrics_tensors,
                                             updates=updates,
                                             **self._function_kwargs)
with K.get_session() as sess:
    # Graph-side input pipeline: a queue of fixed-shape (x, y) batch pairs,
    # filled through the two placeholders and read by a single dequeue op.
    shp = [10, 200]
    shp1 = [10, 10]
    inp = K.placeholder(shp)
    inp1 = K.placeholder(shp1)
    queue = tf.FIFOQueue(20, [tf.float32, tf.float32], [shp, shp1])
    x1, y1 = queue.dequeue()
    enqueue = queue.enqueue([inp, inp1])

    # The model reads its input directly from the dequeued tensor.
    inputLayer = Input(batch_shape=shp, tensor=x1)
    d = Dense(10)(inputLayer)
    d1 = Dense(10)(d)
    model = FIFOModel(inputLayer, d1)
    # `target_tensors` (Keras >= 2.0.7) uses the dequeued y1 as the target
    # instead of creating a target placeholder.
    model.compile('rmsprop', 'mse', target_tensors=[y1])
    # Initialize the Dense weights created above before training starts.
    sess.run(tf.global_variables_initializer())

    def genera():
        """Yield the same dummy (x, y) batch forever."""
        while True:
            yield np.arange(10 * 200).reshape(shp), np.arange(10 * 10).reshape(shp1)

    model.fit_queuer(enqueue, [inp, inp1], genera(), 1000, 10,
                     max_q_size=10, workers=1)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
Thanks for an example of using Keras with TF queues.
After the changes to constraint handling in Keras 2.0.7, it's broken. Fixes:
- remove `constraints` from the `optimizer.get_updates` call;
- change `targets=y1` to `target_tensors=[y1]`.
Then it still fails anyway. Putting
`sess.run(tf.global_variables_initializer())`
after `model.compile()`
helped. Fixed code: