-
-
Save elgehelge/91114702c0ec0607dafccb8b78ba2735 to your computer and use it in GitHub Desktop.
Mnist Inference
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Script to illustrate inference of a trained tf.estimator.Estimator. | |
NOTE: This is dependent on mnist_estimator.py which defines the model. | |
mnist_estimator.py can be found at: | |
https://gist.github.com/elgehelge/faf200e2b36edfb1b1a77ec65f74ecab | |
""" | |
import numpy as np | |
import skimage.io | |
import tensorflow as tf | |
from mnist_estimator import get_estimator | |
# Command-line flag naming the directory the trained model is loaded from.
# NOTE(review): the default presumably matches the model_dir written by
# mnist_estimator.py during training; confirm before changing either side.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
    name='saved_model_dir', default='./mnist_training',
    help='Directory containing the trained model to load for inference.')
# Hosted MNIST sample digits: the entry at index d is an image of digit d.
IMAGE_URLS = [
    'https://i.imgur.com/{}.png'.format(image_id)
    for image_id in (
        'SdYYBDt',  # 0
        'Wy7mad6',  # 1
        'nhBZndj',  # 2
        'V6XeoWZ',  # 3
        'EdxBM1B',  # 4
        'zWSDIuV',  # 5
        'Y28rZho',  # 6
        '6qsCz2W',  # 7
        'BVorzCP',  # 8
        'vt5Edjb',  # 9
    )
]
def infer(argv=None):
    """Load the trained estimator and print its prediction for each sample.

    Args:
        argv: Unused. Present because ``tf.app.run`` passes the remaining
            command-line arguments to its ``main`` function.
    """
    # Inference needs no hyperparameters, so an empty set is passed along.
    hparams = tf.contrib.training.HParams()
    # Point the run configuration at the directory holding the trained model.
    config = tf.contrib.learn.RunConfig().replace(
        model_dir=FLAGS.saved_model_dir)
    model = get_estimator(config, hparams)
    # Feed the sample images through the model and print every prediction.
    for prediction in model.predict(input_fn=test_inputs):
        print(prediction)
def test_inputs():
    """Build the prediction input pipeline over the MNIST sample images.

    Used as the ``input_fn`` for ``estimator.predict`` in ``infer``. The
    images come from ``load_images``; no labels are provided.

    Returns:
        A 1-tuple ``(features,)`` of tensors that yields the sample images
        one batch (of a single image) at a time.
    """
    with tf.name_scope('Test_data'):
        images = tf.constant(load_images(), dtype=np.float32)
        # Wrap in a 1-tuple: the dataset carries features only, no labels.
        dataset = tf.data.Dataset.from_tensor_slices((images,))
        # Batch size 1, so each printed prediction corresponds to one image.
        return dataset.batch(1).make_one_shot_iterator().get_next()
def load_images(urls=None):
    """Load MNIST sample images and stack them into a single array.

    Args:
        urls: Optional sequence of image locations accepted by
            ``skimage.io.imread`` (URLs or file paths). Defaults to
            ``IMAGE_URLS``. Each image must fit a 28x28 slot, e.g. a
            28x28 single-channel (grayscale) image.

    Returns:
        Numpy float64 array of shape (len(urls), 28, 28, 1) holding the
        images in the same order as ``urls``.
    """
    if urls is None:
        urls = IMAGE_URLS
    # Size the array from the URL list instead of a hard-coded 10, so it
    # always matches the number of images actually loaded.
    images = np.zeros((len(urls), 28, 28, 1))
    for idx, url in enumerate(urls):
        # NOTE(review): pixel values are stored unscaled (presumably 0-255
        # for 8-bit PNGs); confirm this matches the input scaling used when
        # the model was trained.
        images[idx, :, :, 0] = skimage.io.imread(url)
    return images
# Script entry point: hand control to tf.app.run, which invokes infer(argv).
if __name__ == "__main__":
    tf.app.run(infer)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment