Created October 11, 2019 20:28
Save omarabid59/240dabd4540377ac2acd89f83bac09e8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import cv2 | |
import numpy as np | |
def load_image(image_name, image_height=416, image_width=416):
    """Read an image file and preprocess it into a single-image batch.

    The image is resized to ``image_height`` x ``image_width`` pixels, its
    channel axis is reversed (OpenCV loads BGR, so this yields RGB), pixel
    values are divided by 255, and a leading batch axis is added.

    Args:
        image_name: path of the image file, read with ``cv2.imread``.
        image_height: target height in pixels.
        image_width: target width in pixels.

    Returns:
        Float array of shape (1, image_height, image_width, 3) with values
        in [0, 1] (``uint8 / 255.`` produces float64).

    Raises:
        OSError: if OpenCV cannot read ``image_name``.
    """
    image = cv2.imread(image_name)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message rather than
        # inside cv2.resize.
        raise OSError("cv2.imread could not read image file: {!r}".format(image_name))
    # cv2.resize expects dsize as (width, height), not (height, width).
    image = cv2.resize(image, (image_width, image_height))
    # Reverse the channel axis (BGR -> RGB) and scale to [0, 1].
    image = image[:, :, ::-1] / 255.
    return np.expand_dims(image, axis=0)
def load_model(PATH_TO_FROZEN_GRAPH):
    """Load a frozen TensorFlow graph file into a new ``tf.Graph``.

    The file is read as bytes, parsed into a ``GraphDef`` protobuf, and
    imported with an empty name scope so tensor names match the file
    exactly (e.g. ``input_data:0``).

    Args:
        PATH_TO_FROZEN_GRAPH: path of the serialized ``GraphDef`` file.

    Returns:
        A fresh ``tf.Graph`` containing the imported nodes.
    """
    frozen_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as pb_file:
        frozen_def.ParseFromString(pb_file.read())

    loaded_graph = tf.Graph()
    with loaded_graph.as_default():
        # name='' keeps original node names instead of prefixing "import/".
        tf.import_graph_def(frozen_def, name='')
    return loaded_graph
def run_inference(image_exp, graph=None):
    """Run the frozen detector on a preprocessed image batch.

    Args:
        image_exp: array fed to the graph's ``input_data:0`` tensor, such as
            the value returned by ``load_image``.
        graph: the ``tf.Graph`` returned by ``load_model``. When omitted, a
            module-level ``graph`` variable is used, which is what the
            original implementation read.

    Returns:
        Tuple ``(bbs, scores, classes)`` holding the values of
        ``output/box_output:0``, ``output/score_output:0`` and
        ``output/label_output:0``, each with an extra leading axis added by
        ``np.expand_dims``.

    Raises:
        ValueError: if no graph is given and no module-level ``graph`` exists.
        KeyError: if the graph lacks the input tensor or an output tensor.
    """
    if graph is None:
        # Backward compatibility: the original body read a module-level
        # ``graph``; main() only bound a local of that name, so the script
        # raised NameError. Passing the graph explicitly is preferred.
        graph = globals().get('graph')
        if graph is None:
            raise ValueError(
                "run_inference() needs the tf.Graph returned by load_model(); "
                "pass it as the 'graph' argument.")

    output_names = ('output/box_output',
                    'output/score_output',
                    'output/label_output')
    # Look the tensors up directly: get_tensor_by_name raises a KeyError
    # naming the missing tensor, rather than silently skipping it and
    # failing later on output_dict. The input tensor is fed, not fetched.
    fetches = {name: graph.get_tensor_by_name(name + ':0')
               for name in output_names}
    image_tensor = graph.get_tensor_by_name('input_data:0')

    with tf.Session(graph=graph) as sess:
        output_dict = sess.run(fetches, feed_dict={image_tensor: image_exp})

    bbs = np.expand_dims(output_dict['output/box_output'], axis=0)
    scores = np.expand_dims(output_dict['output/score_output'], axis=0)
    classes = np.expand_dims(output_dict['output/label_output'], axis=0)
    return bbs, scores, classes


def main(model_file='/tmp/yolov3/exp_10/frozen_model.pb',
         image_name='test_image.jpg'):
    """Load the frozen model, preprocess one image and run detection on it.

    Args:
        model_file: path of the frozen graph file to load.
        image_name: path of the image to run detection on.

    Returns:
        The ``(bbs, scores, classes)`` tuple produced by ``run_inference``.
    """
    graph = load_model(model_file)
    image_exp = load_image(image_name)
    # Pass the graph explicitly; it is a local here and is not visible to
    # run_inference otherwise.
    return run_inference(image_exp, graph)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.