Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Last active October 28, 2017 14:22
Show Gist options
  • Save bzamecnik/3c2b5279a5949d694421d7cfbe813557 to your computer and use it in GitHub Desktop.
Save bzamecnik/3c2b5279a5949d694421d7cfbe813557 to your computer and use it in GitHub Desktop.
Proof of concept of using Keras with StagingArea - data is fed separately via a tf.Variable and in keras Callback
# Is it possible to utilize Keras callbacks to encapsulate the logic? Yes.
#
# We decouple feeding inputs from StagingArea.put() - both can be called in
# a separate Session.run(). Thus it's not needed to hack Keras inputs too much.
# Instead in one run() we assign a numpy array to a Variable (via feed_dict)
# and in another run() we perform StagingArea.put().
#
# We make a callback PrefetchCallback which perform the initial assign and put()
# in its on_epoch_begin() method. Then in each on_batch_begin() it just runs an
# assign. Then get() and put() is ran by Keras in the training function.
#
# It is able to slice the input array to batches and also for the last batch
# it provides a dummy value which is discarded, so that we can leave get() + put()
# uniform over all batches.
#
# Requires patches Keras: https://github.com/bzamecnik/keras/commit/8593309c371ce716fd039e33ed5ae4079096ee0f
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import numpy as np
import keras.backend as K
from keras.callbacks import Callback
K.set_session(None)
tf.reset_default_graph()
num_classes = 10
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')
batch_size = 64
steps_per_epoch = len(x_train) // batch_size
features_shape = (batch_size, 784)
labels_shape = (batch_size, num_classes)
# for feeding inputs to the the StagingArea
# Let's try to decouple feeding data to StagingArea.put()
# from the training batch session.run()
# https://www.tensorflow.org/api_guides/python/reading_data#Preloaded_data
features_batch_next_value = tf.placeholder(dtype=tf.float32, shape=(batch_size, 784))
features_batch_next = tf.Variable(features_batch_next_value, trainable=False, collections=[])
labels_batch_next_value = tf.placeholder(dtype=tf.float32, shape=(batch_size, num_classes))
labels_batch_next = tf.Variable(labels_batch_next_value, trainable=False, collections=[])
assign_next_batch = tf.group(features_batch_next.initializer, labels_batch_next.initializer)
# for prefetching to GPU
area = tf.contrib.staging.StagingArea(
dtypes=[tf.float32, tf.float32],
shapes=[features_shape, labels_shape])
area_put = area.put([features_batch_next.value(), labels_batch_next.value()])
area_get_features, area_get_labels = area.get()
area_size = area.size()
area_clear = area.clear()
image = Input(tensor=area_get_features)
x = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=digit)
class PrefetchCallback(Callback):
def __init__(self, x, y, batch_size):
self.x = x
self.y = y
self.batch_size = batch_size
self.steps_per_epoch = len(x) // batch_size
# 1 batch prefetched to the pipeline
self.prefetch_count = 1
def _slice_batch(self, i):
start = i * self.batch_size
end = start + self.batch_size
return (self.x[start:end], self.y[start:end])
def _assign_batch(self, session, data):
x_batch, y_batch = data
session.run(assign_next_batch, feed_dict={
features_batch_next_value: x_batch,
labels_batch_next_value: y_batch})
def on_epoch_begin(self, epoch, logs=None):
# print('initial size:', K.get_session().run(area_size))
sess = K.get_session()
self._assign_batch(sess, self._slice_batch(0))
sess.run(area_put)
# print('size after first put():', K.get_session().run(area_size))
def on_batch_begin(self, batch, logs=None):
sess = K.get_session()
if batch < self.steps_per_epoch - self.prefetch_count:
data = self._slice_batch(batch + self.prefetch_count)
else:
# a dummy value for the last batch which is not used anyway
data = (np.zeros((batch_size, self.x.shape[1])), np.zeros((batch_size, self.y.shape[1])))
self._assign_batch(sess, data)
# print('size before batch ', batch, ':', sess.run(area_size))
def on_epoch_end(self, epoch, logs=None):
sess = K.get_session()
sess.run(area_clear)
# print('size at the end (should be 0):', sess.run(area_size))
model.compile(optimizer='sgd', loss='categorical_crossentropy',
target_tensors=[area_get_labels], fetches=[area_put])
prefetch_callback = PrefetchCallback(x_train, y_train, batch_size)
model.fit(steps_per_epoch=steps_per_epoch, callbacks=[prefetch_callback])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment