Last active
October 28, 2017 14:22
-
-
Save bzamecnik/3c2b5279a5949d694421d7cfbe813557 to your computer and use it in GitHub Desktop.
Proof of concept of using Keras with StagingArea - data is fed separately via a tf.Variable and in keras Callback
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Is it possible to utilize Keras callbacks to encapsulate the logic? Yes.
#
# We decouple feeding inputs from StagingArea.put() - each can be called in
# a separate Session.run(). Thus there is no need to hack Keras inputs too much.
# Instead, in one run() we assign a numpy array to a Variable (via feed_dict)
# and in another run() we perform StagingArea.put().
#
# We make a callback, PrefetchCallback, which performs the initial assign and
# put() in its on_epoch_begin() method. Then in each on_batch_begin() it just
# runs an assign. The get() and put() are then run by Keras in the training
# function.
#
# It slices the input arrays into batches, and for the last batch it provides
# a dummy value which is discarded, so that get() + put() stay uniform over
# all batches.
#
# Requires a patched Keras: https://github.com/bzamecnik/keras/commit/8593309c371ce716fd039e33ed5ae4079096ee0f
import tensorflow as tf | |
from keras.datasets import mnist | |
from keras.models import Model | |
from keras.layers import Dense, Input | |
from keras.utils import to_categorical | |
import numpy as np | |
import keras.backend as K | |
from keras.callbacks import Callback | |
# Drop any existing Keras session and start from an empty default graph so
# that every op created below lives in one fresh graph.
K.set_session(None)
tf.reset_default_graph()

num_classes = 10
# Length of one flattened MNIST image.
num_features = 784

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Flatten each image to a vector and scale pixel values to [0, 1].
x_train = x_train.reshape(len(x_train), num_features).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')

batch_size = 64
# Integer division: samples that do not fill a whole batch are not used.
steps_per_epoch = len(x_train) // batch_size

features_shape = (batch_size, num_features)
labels_shape = (batch_size, num_classes)

# For feeding inputs to the StagingArea.
# Let's try to decouple feeding data to StagingArea.put()
# from the training batch session.run().
# https://www.tensorflow.org/api_guides/python/reading_data#Preloaded_data
#
# Each variable takes its value from a placeholder through its initializer,
# so running `assign_next_batch` with a feed_dict stores the next batch in
# the variables. `collections=[]` keeps the variables out of every graph
# collection; they are only ever set by running their initializers.
features_batch_next_value = tf.placeholder(
    dtype=tf.float32, shape=features_shape)
features_batch_next = tf.Variable(
    features_batch_next_value, trainable=False, collections=[])
labels_batch_next_value = tf.placeholder(
    dtype=tf.float32, shape=labels_shape)
labels_batch_next = tf.Variable(
    labels_batch_next_value, trainable=False, collections=[])
assign_next_batch = tf.group(
    features_batch_next.initializer, labels_batch_next.initializer)

# For prefetching to GPU.
area = tf.contrib.staging.StagingArea(
    dtypes=[tf.float32, tf.float32],
    shapes=[features_shape, labels_shape])
# put() reads the batch currently stored in the variables; get() yields the
# tensors that the model consumes as its input and target.
area_put = area.put([features_batch_next.value(), labels_batch_next.value()])
area_get_features, area_get_labels = area.get()
area_size = area.size()
area_clear = area.clear()

# The model reads its input directly from the StagingArea output tensor.
image = Input(tensor=area_get_features)
x = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=digit)
class PrefetchCallback(Callback):
    """Keras callback that feeds batches of ``x``/``y`` into the StagingArea.

    The callback relies on the module-level graph ops ``assign_next_batch``,
    ``area_put`` and ``area_clear`` and on the placeholders
    ``features_batch_next_value`` / ``labels_batch_next_value``.

    At the start of an epoch it assigns batch 0 and runs ``area_put`` once.
    Before every training batch it assigns the batch that lies
    ``prefetch_count`` steps ahead; once no such batch is left it assigns a
    zero-filled dummy batch instead. At the end of the epoch it runs
    ``area_clear``.

    Args:
        x: array of input features, sliced along the first axis.
        y: array of targets, sliced along the first axis like ``x``.
        batch_size: number of samples per batch.
    """

    def __init__(self, x, y, batch_size):
        # Let the Keras base class set up its own attributes.
        super(PrefetchCallback, self).__init__()
        self.x = x
        self.y = y
        self.batch_size = batch_size
        # Only full batches are used; a trailing partial batch is dropped.
        self.steps_per_epoch = len(x) // batch_size
        # 1 batch prefetched to the pipeline
        self.prefetch_count = 1

    def _slice_batch(self, i):
        """Return the ``(features, labels)`` slices of the i-th batch."""
        start = i * self.batch_size
        end = start + self.batch_size
        return (self.x[start:end], self.y[start:end])

    def _dummy_batch(self):
        """Return a zero-filled ``(features, labels)`` batch.

        Each array has ``batch_size`` rows and otherwise the per-sample shape
        and dtype of ``x`` / ``y``.
        """
        return tuple(
            np.zeros((self.batch_size,) + tuple(data.shape[1:]),
                     dtype=data.dtype)
            for data in (self.x, self.y))

    def _assign_batch(self, session, data):
        """Store ``data`` in the batch variables by feeding the placeholders."""
        x_batch, y_batch = data
        session.run(assign_next_batch, feed_dict={
            features_batch_next_value: x_batch,
            labels_batch_next_value: y_batch})

    def on_epoch_begin(self, epoch, logs=None):
        """Assign the first batch and put it into the StagingArea."""
        sess = K.get_session()
        self._assign_batch(sess, self._slice_batch(0))
        sess.run(area_put)

    def on_batch_begin(self, batch, logs=None):
        """Assign the batch ``prefetch_count`` steps ahead of ``batch``."""
        sess = K.get_session()
        if batch < self.steps_per_epoch - self.prefetch_count:
            data = self._slice_batch(batch + self.prefetch_count)
        else:
            # a dummy value for the last batch which is not used anyway
            data = self._dummy_batch()
        self._assign_batch(sess, data)

    def on_epoch_end(self, epoch, logs=None):
        """Clear the StagingArea so the next epoch starts with it empty."""
        sess = K.get_session()
        sess.run(area_clear)
# The targets come from the StagingArea instead of from fit() arguments.
# Passing `area_put` as an extra fetch (this needs the patched Keras) runs
# put() together with every training step.
model.compile(
    optimizer='sgd',
    loss='categorical_crossentropy',
    target_tensors=[area_get_labels],
    fetches=[area_put])

# No arrays are passed to fit(): the callback feeds the data, so only the
# number of steps per epoch has to be given.
prefetch_callback = PrefetchCallback(x_train, y_train, batch_size)
model.fit(steps_per_epoch=steps_per_epoch, callbacks=[prefetch_callback])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment