Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Last active June 20, 2018 11:17
Show Gist options
  • Save bzamecnik/8a333fcfab9fb473fac4b7976338d9e7 to your computer and use it in GitHub Desktop.
Save bzamecnik/8a333fcfab9fb473fac4b7976338d9e7 to your computer and use it in GitHub Desktop.
GPU async memcpy in Keras 2.2.0 / TF 1.8 using tf.data.prefetch_to_device. It works!
# Works in Keras 2.1.0/2.2.0, TF 1.8!
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input, Conv2D, Flatten
from keras.utils import to_categorical
import numpy as np
import keras.backend as K
# --- Data preparation -------------------------------------------------------
# Number of digit classes; also the width of the one-hot label vectors.
num_classes = 10
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Number of samples in every batch produced by data_generator().
batch_size = 600
# Add a trailing channel axis (N, 28, 28) -> (N, 28, 28, 1) and divide by 255
# (MNIST pixels are 8-bit, so this maps them into [0, 1]). Using -1 lets the
# sample count be inferred instead of hard-coding 60000.
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255
# One-hot encode the labels as float32 with shape (N, num_classes); the width
# is taken from num_classes so the two can never drift apart.
y_train = to_categorical(y_train, num_classes).reshape(-1, num_classes).astype('float32')
def data_generator():
    """Yield consecutive fixed-size (features, labels) batches of the training set.

    Reads the module-level ``x_train``, ``y_train`` and ``batch_size``.
    Each yielded item is the pair
    ``(x_train[start:end], y_train[start:end])`` where ``end - start`` is
    exactly ``batch_size``.

    Only whole batches are produced: a trailing remainder shorter than
    ``batch_size`` is skipped, so every batch has the same leading
    dimension. With 60000 samples and a batch size of 600 this is 100
    batches, matching the previously hard-coded loop count.
    """
    num_batches = len(x_train) // batch_size
    for batch_index in range(num_batches):
        start = batch_index * batch_size
        end = start + batch_size
        yield x_train[start:end], y_train[start:end]
# --- Input pipeline ---------------------------------------------------------
# Static shape of one batch of features / labels. They are derived from
# batch_size and num_classes so the dataset signature below always agrees
# with what data_generator() yields.
features_shape = (batch_size, 28, 28, 1)
labels_shape = (batch_size, num_classes)
# Wrap the Python generator as a tf.data pipeline of (features, labels)
# float32 pairs and repeat it indefinitely so training never runs out of data.
dataset = tf.data.Dataset.from_generator(
    data_generator,
    (tf.float32, tf.float32),
    (tf.TensorShape(features_shape), tf.TensorShape(labels_shape))
).repeat()
# Stage batches on the first GPU with a buffer of one batch.
# NOTE: the TF docs require prefetch_to_device to be the last transformation
# applied to the dataset, so keep any further transformations above this line.
dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/gpu:0', 1))
# Named batch_iterator (not `iter`) to avoid shadowing the Python builtin.
batch_iterator = dataset.make_one_shot_iterator()
# Symbolic tensors that evaluate to the next batch each time the graph runs.
features_batch_next, labels_batch_next = batch_iterator.get_next()
print('features_batch_next', features_batch_next.shape)
print('labels_batch_next', labels_batch_next.shape)
# --- Model definition and training ------------------------------------------
# Single-channel 28x28 input image.
image = Input(shape=(28, 28, 1))
# Two identical convolution layers (32 filters, 19x19 kernels, 'same' padding,
# ReLU) stacked on top of the input.
conv_kwargs = {'padding': 'same', 'activation': 'relu'}
x = image
for _ in range(2):
    x = Conv2D(32, 19, **conv_kwargs)(x)
# Flatten the feature maps and classify with a softmax over the digit classes.
x = Flatten()(x)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=digit)
model.compile(optimizer='sgd', loss='categorical_crossentropy')
# Train directly from the iterator's output tensors; since the inputs are
# tensors rather than arrays, steps_per_epoch sets how many batches make up
# an epoch.
model.fit(features_batch_next, labels_batch_next, steps_per_epoch=10)
@bzamecnik
Copy link
Author

bzamecnik commented Jun 19, 2018

Rev 1:
Screenshot, 2018-06-20 at 0:30:23

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment