Skip to content

Instantly share code, notes, and snippets.

@vishal-keshav
Created March 19, 2019 10:32
Show Gist options
  • Save vishal-keshav/a0a1c0b526a9fd3f0bf6356cac88a23d to your computer and use it in GitHub Desktop.
Save vishal-keshav/a0a1c0b526a9fd3f0bf6356cac88a23d to your computer and use it in GitHub Desktop.
How to run image-classification inference with a pre-trained, frozen TensorFlow .pb graph file
import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
from imagenet_classes import class_names
from scipy.misc import imread, imresize
# Classify one image with a frozen MobileNet v1 (1.0, 224x224) TensorFlow 1.x
# graph: print every op in the graph, the output of the first conv layer, and
# the top-k predicted ImageNet labels with their probabilities.
#
# NOTE(review): scipy.misc.imread / imresize were removed from newer SciPy
# releases (imread in 1.2, imresize in 1.3) and require Pillow; this script
# needs an older SciPy, or a port to another image-loading library.

# Directory containing the frozen graph file.
dir_name = 'mobilenet_v1_1.0_224'
# Image to classify.
image_path = 'file1.jpg'
# Number of highest-probability classes to print.
top_k = 5

# Load the serialized GraphDef. The .pb file is a binary protobuf, so it must be
# opened in binary mode: FastGFile's default text mode ('r') yields str on
# Python 3, which GraphDef.ParseFromString rejects. The file is closed as soon
# as it has been parsed.
with gfile.FastGFile(dir_name + "/mobilenet_v1_1.0_224_frozen.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Preprocess: read as RGB, resize to 224x224, add a leading batch axis, then
# map 8-bit pixel values to [-1, 1] via x / 127.5 - 1.
image = imread(image_path, mode='RGB')
image = imresize(image, (224, 224)).reshape(1, 224, 224, 3).astype(np.float32)
image = image / 127.5 - 1.0

with tf.Graph().as_default() as graph:
    # name="" imports ops without the default "import/" prefix, so the tensor
    # names below can be looked up verbatim.
    tf.import_graph_def(graph_def, name="")

    # Dump every op and its output tensors; useful for finding the names of
    # tensors to fetch from an unfamiliar frozen graph.
    for op in graph.get_operations():
        print("Operation Name :" + op.name)
        print("Tensor Stats :" + str(op.values()))

    l_input = graph.get_tensor_by_name('input:0')
    intermediate = graph.get_tensor_by_name(
        'MobilenetV1/MobilenetV1/Conv2d_0/Relu6:0')
    l_output = graph.get_tensor_by_name('MobilenetV1/Predictions/Reshape_1:0')

    with tf.Session(graph=graph) as sess:
        # A frozen graph is expected to hold its weights as constants, so no
        # variable initialization is run. Fetch both tensors in one run so the
        # layers they share are evaluated only once.
        inter_out, op_prob = sess.run([intermediate, l_output],
                                      feed_dict={l_input: image})

print(inter_out)

# Indices of the top_k highest-probability classes, best first.
preds = np.argsort(op_prob[0])[::-1][:top_k]
for p in preds:
    # The p - 1 offset implies the model emits one more score than class_names
    # has entries, with index 0 presumably a "background" class. Guard p == 0
    # so it is not silently mapped to class_names[-1].
    label = class_names[p - 1] if p > 0 else 'background'
    print(label, op_prob[0][p])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment