Test the inference speed of a MobileNet that was trained for a while on the flowers dataset (the training was not done rigorously).
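The script loads a frozen (and optionally quantized) MobileNet graph, runs one prediction on a single image, and prints the top-5 classes together with the total run time. A typical invocation, assuming the script is saved as test_inference.py (the file name is an assumption, not from the gist), would be: python test_inference.py --quantize=True --preprocess=False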
import time

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.misc import imread, imresize

from mobilenet_preprocessing import preprocess_for_eval

plt.style.use('ggplot')
start_time = time.time()
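# Note (added): scipy.misc.imread and imresize are deprecated and removed in
# recent SciPy releases; imageio.imread plus a PIL or skimage resize are the
# usual replacements if this script is run on a newer environment.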
#======DEFINE SOME ARGUMENTS===========
flags = tf.app.flags
flags.DEFINE_boolean('quantize', False, 'Whether or not to use the quantized model. The original frozen model will be used by default.')
flags.DEFINE_boolean('preprocess', True, 'Choose whether to preprocess the image before predicting from the graph.')
FLAGS = flags.FLAGS
image_size = 299
#==============SETUP=================
image = './dandelion_unseen.jpg'
def main():
    # If preprocessing happens inside the graph, keep the file name for now;
    # otherwise load and resize the image into a float32 batch of one.
    if FLAGS.preprocess:
        img = image
    else:
        img = imread(image)
        img = imresize(img, (image_size, image_size, 3))
        img = img.astype(np.float32)
        img = np.expand_dims(img, 0)

    labels_dict = {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
    # Define the filename of the frozen graph; a "frozen" .pb is a GraphDef
    # whose variables have been converted to constants.
    if FLAGS.quantize:
        graph_filename = "./quantized_model_mobilenet.pb"
    else:
        graph_filename = "./frozen_model_mobilenet.pb"

    # Create a GraphDef object and read the serialized graph into it
    with tf.gfile.GFile(graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Construct the graph and import it from the GraphDef
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

        # Convert the image file into a tensor for preprocessing; it is
        # evaluated into a numpy array later, inside the session
        if FLAGS.preprocess:
            img = tf.convert_to_tensor(img)
            img = tf.read_file(img)
            img = tf.image.decode_jpeg(img, channels=3)
            preprocessed_img = preprocess_for_eval(img, image_size, image_size)
            preprocessed_img = tf.expand_dims(preprocessed_img, 0)

        # Define the input and output nodes we will feed and fetch;
        # import_graph_def prefixes every node name with "import/" by default
        input_node = graph.get_tensor_by_name('import/Placeholder_only:0')
        output_node = graph.get_tensor_by_name('import/MobileNet/Predictions/Softmax:0')
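        # Note (added): the two tensor names above are specific to this
        # particular export. If they do not match your graph, the available
        # names can be listed with, e.g.:
        #     print([op.name for op in graph.get_operations()])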
        with tf.Session() as sess:
            # Obtain the preprocessed image as a numpy array first
            if FLAGS.preprocess:
                img = sess.run(preprocessed_img)

            predictions = sess.run(output_node, feed_dict={input_node: img})[0]
            # argsort is ascending, so take the last five indices and reverse
            # them to get the top-5 classes by probability
            top_5_predictions = predictions.argsort()[-5:][::-1]
            top_5_probabilities = predictions[top_5_predictions]
            prediction_names = [labels_dict[i] for i in top_5_predictions]

            for i in range(len(prediction_names)):
                print('Prediction: %s, Probability: %s \n' % (prediction_names[i], top_5_probabilities[i]))
            print('RUN TIME: %s' % (time.time() - start_time))
    #======================SHOWING IMAGE======================
    # Finally, show the predicted image together with the top probability.
    img_plot = plt.imshow(imresize(imread(image), (image_size, image_size, 3)))
    if FLAGS.quantize:
        text = "QUANTIZED GRAPH\nPrediction: %s (Probability: %.3f)" % (prediction_names[0], top_5_probabilities[0])
    else:
        text = "FROZEN GRAPH\nPrediction: %s (Probability: %.3f)" % (prediction_names[0], top_5_probabilities[0])

    # Describe the prediction
    plt.title(text)

    # Remove the axis ticks
    img_plot.axes.get_xaxis().set_ticks([])
    img_plot.axes.get_yaxis().set_ticks([])
    plt.show()
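# Note (added sketch, not in the original gist): the RUN TIME printed above
# includes imports, graph loading and preprocessing, not just the forward
# pass. A purer latency benchmark would time repeated sess.run calls after a
# warm-up; the helper below is a minimal sketch of that idea, and its name
# and n_runs default are assumptions for illustration.
def time_inference(sess, input_node, output_node, img, n_runs=50):
    # Warm-up run so one-off initialization cost is excluded from the timing
    sess.run(output_node, feed_dict={input_node: img})
    start = time.time()
    for _ in range(n_runs):
        sess.run(output_node, feed_dict={input_node: img})
    # Average seconds per forward pass
    return (time.time() - start) / n_runs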
if __name__ == '__main__':
    main()