Skip to content

Instantly share code, notes, and snippets.

@peterroelants
Created September 5, 2017 17:03
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save peterroelants/3a490905f5b022fea66e0553af51abb8 to your computer and use it in GitHub Desktop.
Save peterroelants/3a490905f5b022fea66e0553af51abb8 to your computer and use it in GitHub Desktop.
MNIST Inference
"""Script to illustrate inference of a trained tf.estimator.Estimator.
NOTE: This is dependent on mnist_estimator.py which defines the model.
mnist_estimator.py can be found at:
https://gist.github.com/peterroelants/9956ec93a07ca4e9ba5bc415b014bcca
"""
import numpy as np
import skimage.io
import tensorflow as tf
from mnist_estimator import get_estimator
# Command-line flag selecting the directory the trained Estimator model is
# loaded from (it is passed as ``model_dir`` to the RunConfig in ``infer``).
FLAGS = tf.app.flags.FLAGS
# Positional arguments work both with TF's original flags wrapper
# (flag_name, default_value, docstring) and the absl-based one
# (name, default, help).
tf.app.flags.DEFINE_string(
    'saved_model_dir', './mnist_training',
    'Directory of the trained model (Estimator model_dir) to load for '
    'inference.')
# Sample MNIST digit images hosted on imgur, ordered so that
# IMAGE_URLS[d] is an image of the digit d.
IMAGE_URLS = [
    'https://i.imgur.com/SdYYBDt.png',  # digit 0
    'https://i.imgur.com/Wy7mad6.png',  # digit 1
    'https://i.imgur.com/nhBZndj.png',  # digit 2
    'https://i.imgur.com/V6XeoWZ.png',  # digit 3
    'https://i.imgur.com/EdxBM1B.png',  # digit 4
    'https://i.imgur.com/zWSDIuV.png',  # digit 5
    'https://i.imgur.com/Y28rZho.png',  # digit 6
    'https://i.imgur.com/6qsCz2W.png',  # digit 7
    'https://i.imgur.com/BVorzCP.png',  # digit 8
    'https://i.imgur.com/vt5Edjb.png',  # digit 9
]
def infer(argv=None):
    """Restore the trained estimator, predict on the sample images, print.

    Args:
        argv: Command-line arguments handed over by ``tf.app.run``; not used.
    """
    # Restoring and predicting needs no hyperparameters.
    hparams = tf.contrib.training.HParams()
    # Load the model from the directory given by --saved_model_dir.
    config = tf.contrib.learn.RunConfig().replace(
        model_dir=FLAGS.saved_model_dir)
    model = get_estimator(config, hparams)
    for prediction in model.predict(input_fn=test_inputs):
        print(prediction)
def test_inputs():
    """Build the prediction input pipeline over the sample MNIST images.

    Returns:
        The features tensor produced by a one-shot iterator. Each evaluation
        yields the next batch of 1 image from ``load_images()``.
    """
    with tf.name_scope('Test_data'):
        images = tf.constant(load_images(), dtype=np.float32)
        # Slice the bare tensor rather than the one-element tuple
        # ``(images,)``. ``get_next()`` then returns the features tensor
        # itself instead of a 1-tuple. Newer Estimator versions interpret a
        # tuple returned by input_fn as ``(features, labels)`` and reject a
        # 1-tuple.
        dataset = tf.contrib.data.Dataset.from_tensor_slices(images)
        # Iterate over the images in batches of 1.
        return dataset.batch(1).make_one_shot_iterator().get_next()
def load_images(urls=None):
    """Load MNIST sample images and return them stacked in one array.

    Args:
        urls: Sequence of image URLs (or paths) readable by
            ``skimage.io.imread``. Defaults to ``IMAGE_URLS``. Each image must
            decode to a 28x28 single-channel array, because it is written
            into a ``[28, 28]`` slice.

    Returns:
        Numpy float array of shape ``(len(urls), 28, 28, 1)`` holding the
        images in the order given.
    """
    if urls is None:
        urls = IMAGE_URLS
    # Size the buffer from the URL list rather than a hard-coded 10. Adding
    # or removing sample images then cannot raise IndexError or leave
    # all-zero padding images at the end.
    images = np.zeros((len(urls), 28, 28, 1))
    for idx, url in enumerate(urls):
        images[idx, :, :, 0] = skimage.io.imread(url)
    return images
# Script entry point ######################################
if __name__ == "__main__":
    # Hand control to TensorFlow's app runner with infer as the main function.
    tf.app.run(main=infer)
@rzou15
Copy link

rzou15 commented Jan 14, 2020

Line 56 should drop the trailing comma, i.e. pass `images` instead of the one-element tuple `(images,)`.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment