Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Created March 1, 2019 20:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bzamecnik/368fc4b43847c449b3c297fc0056b2ba to your computer and use it in GitHub Desktop.
Save bzamecnik/368fc4b43847c449b3c297fc0056b2ba to your computer and use it in GitHub Desktop.
Example of basic MNIST Keras model with tf.Dataset
# Example of basic MNIST Keras model with tf.Dataset
# More up-to-date version of: https://github.com/keras-team/keras/blob/master/examples/mnist_dataset_api.py
"""
MNIST classification with TensorFlow's Dataset API.
Introduced in TensorFlow 1.3, the Dataset API is now the
standard method for loading data into TensorFlow models.
A Dataset is a sequence of elements, which are themselves
composed of tf.Tensor components. For more details, see:
https://www.tensorflow.org/programmers_guide/datasets
To use this with Keras, we make a dataset out of elements
of the form (input batch, output batch). From there, we
create a one-shot iterator and take the pair of tensors
returned by its get_next() method. These tensors are then
provided to the network instead of plain numpy arrays.
See also the mnist_tfrecord.py example.
"""
import numpy as np
import os
import tempfile
import keras
from keras import backend as K
from keras import layers
from keras.datasets import mnist
import tensorflow as tf
# This example is built on tf.data, so refuse to run on any other Keras backend.
if K.backend() != 'tensorflow':
    raise RuntimeError('This example can only run with the TensorFlow backend,'
                       ' because it requires the Dataset API, which is not'
                       ' supported on other platforms.')
# Input-pipeline settings: examples per step and shuffle-buffer capacity.
batch_size, shuffle_size = 128, 1024
# Training length and the number of MNIST digit classes (0-9).
epochs, num_classes = 5, 10
def build_model():
    """Build and compile a small convolutional classifier for MNIST.

    The model takes images of shape ``(28, 28)`` (no channel axis; one is
    added internally) and ends in a ``num_classes``-way softmax.

    Returns:
        A compiled ``keras.models.Model`` trained with RMSprop on
        categorical cross-entropy (so targets must be one-hot encoded),
        reporting accuracy as a metric.
    """
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = layers.Input(shape=(28, 28))
    # Conv2D needs an explicit channel axis; K.expand_dims appends one at the
    # end, (28, 28) -> (28, 28, 1) (assumes the default channels_last format).
    x = layers.Lambda(K.expand_dims)(inputs)
    x = layers.Conv2D(32, (3, 3), activation='relu', padding='valid')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Flatten()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    output = layers.Dense(num_classes, activation='softmax')(x)

    model = keras.models.Model(inputs=inputs, outputs=output)
    model.compile(optimizer=keras.optimizers.RMSprop(lr=2e-3, decay=1e-5),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
def make_dataset(x, y, shuffle=False):
    """Wrap in-memory arrays in an infinite, batched tf.data input pipeline.

    Args:
        x: Images; presumably the uint8 ``(N, 28, 28)`` array returned by
            ``mnist.load_data()`` — confirm for other callers.
        y: Integer class labels aligned with ``x``.
        shuffle: If True, shuffle examples with a buffer of ``shuffle_size``
            elements, drawing a new order on every pass over the data.

    Returns:
        ``(inputs, targets)``, the tensors produced by a one-shot iterator's
        ``get_next()``. ``inputs`` is a float32 batch scaled by 1/255 and
        ``targets`` the matching one-hot batch with ``num_classes`` columns.
        Every batch has exactly ``batch_size`` rows and the stream repeats
        indefinitely.
    """
    def preprocess(image, label):
        """Scale pixel values by 1/255 and one-hot encode the label."""
        image = tf.cast(image, tf.float32) / 255
        label = tf.one_hot(tf.cast(label, tf.uint8), num_classes)
        return image, label

    # NOTE: This stores the provided numpy arrays in the TF graph as
    # constants! It's only useful for small data.
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        # Shuffle *before* repeat() so each pass over the data contains every
        # example exactly once (reshuffled per pass by default) instead of
        # blending examples across epoch boundaries. Shuffling before map()
        # also buffers the raw (presumably uint8) images rather than their
        # float32 copies.
        dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.map(preprocess)
    dataset = dataset.repeat()
    # Keras does not support tensors with dynamic batch size, so force a
    # static batch dimension. Because repeat() comes first the stream is
    # infinite and no examples are actually dropped.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # NOTE(review): deprecated since TF 1.13 in favour of
    # tf.compat.v1.data.make_one_shot_iterator(dataset); kept for TF 1.x.
    iterator = dataset.make_one_shot_iterator()
    inputs, targets = iterator.get_next()
    return inputs, targets
# Load the raw MNIST splits as numpy arrays.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

model = build_model()
model.summary()

# Turn each split into symbolic (inputs, targets) batch tensors.
inputs_train, targets_train = make_dataset(x_train, y_train, shuffle=True)
inputs_test, targets_test = make_dataset(x_test, y_test, shuffle=False)

# Batches needed to cover each split once, rounding up with integer
# ceil-division: 60000 -> 469 training steps, 10000 -> 79 validation steps.
steps_per_epoch = (len(x_train) + batch_size - 1) // batch_size
validation_steps = (len(x_test) + batch_size - 1) // batch_size
# NOTE(review): 79 * 128 = 10112 > 10000 and the test stream repeats, so each
# validation/evaluation pass sees 112 examples twice. The same one-shot
# iterator feeds validation and the final evaluate(), so each pass resumes
# where the previous one stopped.

# Upstream Keras >= 2.2.0 accepts tensors directly as training/validation
# inputs and targets; tf.keras would take the tf.data.Dataset itself.
model.fit(x=inputs_train,
          y=targets_train,
          epochs=epochs,
          steps_per_epoch=steps_per_epoch,
          validation_data=(inputs_test, targets_test),
          validation_steps=validation_steps)

loss, acc = model.evaluate(x=inputs_test, y=targets_test,
                           steps=validation_steps)
print('\nTest accuracy: {0}'.format(acc))
# Once trained, the model can be used with numpy arrays as well as other tensors.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment