Last active: July 3, 2019 06:33
-
-
Save kwotsin/292eb12600be02b75bf69ff8010d07ce to your computer and use it in GitHub Desktop.
Tests the inference speed of a MobileNet that was briefly trained on the flowers dataset (the training was not done rigorously).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from scipy.misc import imread, imresize | |
from mobilenet_preprocessing import preprocess_for_eval | |
import numpy as np | |
from scipy.misc import imread, imresize | |
import matplotlib.pyplot as plt | |
import time | |
plt.style.use('ggplot')
# Wall-clock reference taken right after the imports; main() reports the
# elapsed time relative to this value.
start_time = time.time()

#======================== CONFIGURATION ========================
# Side length used for both height and width when resizing/preprocessing,
# and the path of the image to classify.
image_size = 299
image = './dandelion_unseen.jpg'

# Command-line switches, declared through tf.app.flags.
flags = tf.app.flags
flags.DEFINE_boolean(
    'quantize', False,
    'Whether or not to use the quantized model. '
    'The original frozen model will be used by default.')
flags.DEFINE_boolean(
    'preprocess', True,
    'Choose whether to preprocess the image before predicting from the graph.')
FLAGS = flags.FLAGS
def _load_graph_def(graph_filename):
    """Parse a serialized ``tf.GraphDef`` from the binary file *graph_filename*.

    Args:
        graph_filename: Path to a frozen/quantized ``.pb`` graph file.

    Returns:
        The parsed ``tf.GraphDef``.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_filename, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


def _read_resized_rgb():
    """Read ``image`` from disk and resize it to ``image_size`` x ``image_size``.

    Returns:
        The RGB array produced by ``imresize``, shaped
        (image_size, image_size, 3).
    """
    # NOTE(review): scipy.misc.imread/imresize were removed in newer SciPy
    # releases (1.2 / 1.3); this script needs an older SciPy with Pillow.
    # mode='RGB' forces exactly three channels so grayscale or RGBA files get
    # the same layout as the decode_jpeg(channels=3) path in main().
    img = imread(image, mode='RGB')
    # imresize only uses the (height, width) part of its size argument.
    return imresize(img, (image_size, image_size))


def _show_prediction(label, probability):
    """Display the resized input image titled with *label* and *probability*.

    The title also states whether the quantized or the frozen graph was used
    (per ``FLAGS.quantize``). Blocks in ``plt.show()`` until the window closes.
    """
    img_plot = plt.imshow(_read_resized_rgb())
    graph_kind = "QUANTIZED GRAPH" if FLAGS.quantize else "FROZEN GRAPH"
    plt.title("%s\nPrediction: %s (Probability: %.3f)" % (graph_kind, label, probability))
    # Hide the axis ticks; pixel coordinates carry no information here.
    img_plot.axes.get_xaxis().set_ticks([])
    img_plot.axes.get_yaxis().set_ticks([])
    plt.show()


def main():
    """Classify ``image`` with a frozen (or quantized) MobileNet graph.

    Loads ``./quantized_model_mobilenet.pb`` when ``--quantize`` is set, else
    ``./frozen_model_mobilenet.pb``. With ``--preprocess`` (the default) the
    image is decoded and passed through ``preprocess_for_eval`` inside
    TensorFlow. Otherwise the resized pixels are fed as raw float32 values.
    Prints the top-5 labels with their probabilities and the elapsed time
    since ``start_time``, then shows the image titled with the top prediction.
    """
    labels_dict = {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}

    # Raw path: resized pixels as float32 with a leading batch axis of one.
    # The preprocessed path fills `img` later, from inside the session.
    img = None if FLAGS.preprocess else np.expand_dims(
        _read_resized_rgb().astype(np.float32), 0)

    if FLAGS.quantize:
        graph_filename = "./quantized_model_mobilenet.pb"
    else:
        graph_filename = "./frozen_model_mobilenet.pb"
    graph_def = _load_graph_def(graph_filename)

    with tf.Graph().as_default() as graph:
        # With the default name, imported nodes get the 'import/' prefix used
        # in the tensor lookups below.
        tf.import_graph_def(graph_def)

        if FLAGS.preprocess:
            # Decode/preprocess in TensorFlow; the result is evaluated to a
            # numpy array below and then fed to the imported graph.
            raw_img = tf.image.decode_jpeg(tf.read_file(image), channels=3)
            preprocessed_img = tf.expand_dims(
                preprocess_for_eval(raw_img, image_size, image_size), 0)

        input_node = graph.get_tensor_by_name('import/Placeholder_only:0')
        output_node = graph.get_tensor_by_name('import/MobileNet/Predictions/Softmax:0')

        with tf.Session(graph=graph) as sess:
            if FLAGS.preprocess:
                img = sess.run(preprocessed_img)
            predictions = sess.run(output_node, feed_dict={input_node: img})[0]
            # Indices of the five highest scores, best first.
            top_5_predictions = predictions.argsort()[-5:][::-1]
            top_5_probabilities = predictions[top_5_predictions]
            prediction_names = [labels_dict[i] for i in top_5_predictions]

            # Single-argument print(...) behaves the same on Python 2 and 3.
            for name, probability in zip(prediction_names, top_5_probabilities):
                print('Prediction: %s, Probability: %s \n' % (name, probability))

            # start_time is taken at module load, so this includes image and
            # graph loading and session setup, not just the forward pass.
            print('RUN TIME: %s' % (time.time() - start_time))

    _show_prediction(prediction_names[0], top_5_probabilities[0])


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.