Skip to content

Instantly share code, notes, and snippets.

@elgehelge
Forked from peterroelants/mnist_inference.py
Last active January 30, 2018 15:37
Show Gist options
  • Save elgehelge/91114702c0ec0607dafccb8b78ba2735 to your computer and use it in GitHub Desktop.
Save elgehelge/91114702c0ec0607dafccb8b78ba2735 to your computer and use it in GitHub Desktop.
Mnist Inference
"""Script to illustrate inference of a trained tf.estimator.Estimator.
NOTE: This is dependent on mnist_estimator.py which defines the model.
mnist_estimator.py can be found at:
https://gist.github.com/elgehelge/faf200e2b36edfb1b1a77ec65f74ecab
"""
import numpy as np
import skimage.io
import tensorflow as tf
from mnist_estimator import get_estimator
# Command-line flag: directory holding the trained model that infer() loads.
# (This script only reads from it; the model is written by the training
# script in mnist_estimator.py.)
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
    name='saved_model_dir', default='./mnist_training',
    help='Directory containing the trained model to load for inference.')
# MNIST sample images hosted on imgur, one per digit: IMAGE_URLS[d] is the
# image of digit d. Only the imgur ids differ, so the URLs are built from a
# single template.
IMAGE_URLS = list(map('https://i.imgur.com/{}.png'.format, (
    'SdYYBDt',  # 0
    'Wy7mad6',  # 1
    'nhBZndj',  # 2
    'V6XeoWZ',  # 3
    'EdxBM1B',  # 4
    'zWSDIuV',  # 5
    'Y28rZho',  # 6
    '6qsCz2W',  # 7
    'BVorzCP',  # 8
    'vt5Edjb',  # 9
)))
def infer(argv=None):
    """Run the sample images through the trained estimator and print results.

    The estimator comes from ``get_estimator`` with empty hyperparameters and a
    RunConfig whose ``model_dir`` is ``FLAGS.saved_model_dir``, so predictions
    are made with the model stored in that directory. Every prediction yielded
    by ``Estimator.predict`` is printed to stdout unchanged.

    Args:
        argv: Leftover command-line arguments handed over by ``tf.app.run``;
            not used.
    """
    del argv  # Unused.
    # NOTE: tf.contrib (HParams, contrib.learn.RunConfig) only exists in
    # TensorFlow 1.x.
    hparams = tf.contrib.training.HParams()  # No hyperparameters needed.
    config = tf.contrib.learn.RunConfig().replace(
        model_dir=FLAGS.saved_model_dir)
    estimator = get_estimator(config, hparams)
    # predict() is a lazy generator; it stops once test_inputs' dataset is
    # exhausted.
    for prediction in estimator.predict(input_fn=test_inputs):
        print(prediction)
def test_inputs():
    """Build the prediction input pipeline over the MNIST sample images.

    Used as the ``input_fn`` of ``Estimator.predict``. The images returned by
    load_images() are embedded in the graph as a float32 constant and served
    one at a time (batches of size 1).

    Returns:
        A 1-tuple ``(features,)`` of tensors yielding the next batch of sample
        images each time it is evaluated. After the last image the iterator
        raises end-of-input, which ends ``Estimator.predict``.
    """
    with tf.name_scope('Test_data'):
        images = tf.constant(load_images(), dtype=tf.float32)
        # Wrapped in a tuple so each element is (image,); predict() takes the
        # first entry of a returned tuple as the features.
        dataset = tf.data.Dataset.from_tensor_slices((images,))
        # A one-shot iterator needs no explicit initialization because the
        # dataset is built entirely from a constant.
        return dataset.batch(1).make_one_shot_iterator().get_next()
def load_images(urls=None):
    """Load MNIST sample images and stack them into a single array.

    Args:
        urls: Optional sequence of image URLs (or local file paths) accepted
            by ``skimage.io.imread``. Defaults to IMAGE_URLS.

    Returns:
        Numpy float64 array of shape (len(urls), 28, 28, 1) where entry i
        holds the pixels of urls[i] in its single channel.
    """
    if urls is None:
        urls = IMAGE_URLS
    # Size the batch from the URL list instead of a hard-coded 10 so that
    # adding or removing sample images cannot cause an index/shape error.
    images = np.zeros((len(urls), 28, 28, 1))
    for idx, url in enumerate(urls):
        # NOTE(review): assumes each file decodes to a 28x28 single-channel
        # (grayscale) image; an RGB/RGBA file would not fit the [..., 0]
        # slot -- confirm if new URLs are added.
        images[idx, :, :, 0] = skimage.io.imread(url)
    return images
# Script entry point ######################################
if __name__ == '__main__':
    # tf.app.run parses the command-line flags (e.g. --saved_model_dir) into
    # FLAGS, then calls infer() with the remaining arguments.
    tf.app.run(infer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment