Created
October 5, 2019 09:15
-
-
Save joelbarmettlerUZH/99c4fd40740305b88ee8e46ba150a06d to your computer and use it in GitHub Desktop.
Call tensorflow object detection API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from object_detection.utils import ops as utils_ops | |
import tensorflow as tf | |
def run_inference_for_single_image(image, sess):
    """Run the detection graph on one image and return un-batched outputs.

    The tensors are looked up by name in the current default graph, so the
    detection model must already be loaded into it.

    Args:
        image: Value fed to ``image_tensor:0``. ``image.shape[1]`` and
            ``image.shape[2]`` are passed as the target mask size, so this is
            presumably a batch of one shaped (1, height, width, channels)
            -- TODO confirm against the caller.
        sess: Session object whose ``run`` method evaluates the fetches.

    Returns:
        The dict produced by ``sess.run``, modified in place:
        ``num_detections`` becomes a Python ``int``, ``detection_classes``
        becomes an ``np.int64`` array, and ``detection_boxes``,
        ``detection_scores`` and (when present) ``detection_masks`` are
        replaced by their first element along the leading axis.

    Raises:
        KeyError: If one of the required outputs is absent from the result.
    """
    graph = tf.get_default_graph()

    # Names of every tensor produced by any op in the graph.
    available = {
        produced.name
        for operation in graph.get_operations()
        for produced in operation.outputs
    }

    # Fetch only the outputs this particular model actually exposes.
    wanted = (
        'num_detections', 'detection_boxes', 'detection_scores',
        'detection_classes', 'detection_masks',
    )
    tensor_dict = {
        key: graph.get_tensor_by_name(key + ':0')
        for key in wanted
        if key + ':0' in available
    }

    if 'detection_masks' in tensor_dict:
        # Single-image processing: drop the batch axis before slicing.
        # NOTE(review): in TF1 graph mode the tf.* calls below presumably add
        # new ops to the default graph on every invocation -- confirm if this
        # function is called in a loop.
        boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
        masks = tf.squeeze(tensor_dict['detection_masks'], [0])

        # Keep only the first `valid_count` detections, then translate the
        # masks from box coordinates to image coordinates at image size.
        valid_count = tf.cast(tensor_dict['num_detections'][0], tf.int32)
        boxes = tf.slice(boxes, [0, 0], [valid_count, -1])
        masks = tf.slice(masks, [0, 0, 0], [valid_count, -1, -1])
        image_masks = utils_ops.reframe_box_masks_to_image_masks(
            masks, boxes, image.shape[1], image.shape[2])

        # Binarise at 0.5 and restore the batch axis to match the other
        # fetched tensors.
        binary_masks = tf.cast(tf.greater(image_masks, 0.5), tf.uint8)
        tensor_dict['detection_masks'] = tf.expand_dims(binary_masks, 0)

    image_tensor = graph.get_tensor_by_name('image_tensor:0')

    # Run inference.
    output_dict = sess.run(tensor_dict, feed_dict={image_tensor: image})

    # Strip the batch axis and convert types as appropriate.
    output_dict['num_detections'] = int(output_dict['num_detections'][0])
    class_ids = output_dict['detection_classes'][0]
    output_dict['detection_classes'] = class_ids.astype(np.int64)
    for key in ('detection_boxes', 'detection_scores'):
        output_dict[key] = output_dict[key][0]
    if 'detection_masks' in output_dict:
        output_dict['detection_masks'] = output_dict['detection_masks'][0]

    return output_dict
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment