@nuzrub
Last active February 26, 2021 20:00
TensorFlow Object Detection API simplified sample code for inference
```python
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os
from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder

center_net_path = './centernet_resnet50_v1_fpn_512x512_coco17_tpu-8/'
pipeline_config = center_net_path + 'pipeline.config'
model_path = center_net_path + 'checkpoint/'
label_map_path = './mscoco_label_map.pbtxt'
image_path = './test.jpg'

# Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
detection_model = model_builder.build(model_config=model_config, is_training=False)

# Restore checkpoint (expect_partial silences warnings about unused training-only values)
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(os.path.join(model_path, 'ckpt-0')).expect_partial()


def get_model_detection_function(model):
    """Wrap preprocess -> predict -> postprocess in a single tf.function."""

    @tf.function
    def detect_fn(image):
        image, shapes = model.preprocess(image)
        prediction_dict = model.predict(image, shapes)
        detections = model.postprocess(prediction_dict, shapes)
        return detections, prediction_dict, tf.reshape(shapes, [-1])

    return detect_fn


detect_fn = get_model_detection_function(detection_model)

# Load the label map and build a {class_id: {'id': ..., 'name': ...}} index
label_map = label_map_util.load_labelmap(label_map_path)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=label_map_util.get_max_label_map_index(label_map),
    use_display_name=True)
category_index = label_map_util.create_category_index(categories)
label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)

# Load the image as RGB (drops any alpha channel) and add a batch dimension
image = np.array(Image.open(image_path).convert('RGB'))
input_tensor = tf.convert_to_tensor(np.expand_dims(image, 0), dtype=tf.float32)
detections, predictions_dict, shapes = detect_fn(input_tensor)

# Model class indices start at 0, label map ids start at 1
label_id_offset = 1
image_np_with_detections = image.copy()

# Use keypoints if available in detections
keypoints, keypoint_scores = None, None
if 'detection_keypoints' in detections:
    keypoints = detections['detection_keypoints'][0].numpy()
    keypoint_scores = detections['detection_keypoint_scores'][0].numpy()


def get_keypoint_tuples(eval_config):
    """Return the keypoint edges from the eval config as (start, end) tuples."""
    tuple_list = []
    kp_list = eval_config.keypoint_edge
    for edge in kp_list:
        tuple_list.append((edge.start, edge.end))
    return tuple_list


viz_utils.visualize_boxes_and_labels_on_image_array(
    image_np_with_detections,
    detections['detection_boxes'][0].numpy(),
    (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
    detections['detection_scores'][0].numpy(),
    category_index,
    use_normalized_coordinates=True,
    max_boxes_to_draw=200,
    min_score_thresh=.30,
    agnostic_mode=False,
    keypoints=keypoints,
    keypoint_scores=keypoint_scores,
    keypoint_edges=get_keypoint_tuples(configs['eval_config']))

plt.figure(figsize=(12, 16))
plt.imshow(image_np_with_detections)
plt.savefig('./output.png')
plt.show()
```
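The script above hands everything to viz_utils for drawing. If you would rather inspect the detections as text, a small NumPy sketch like the following can filter and label them. It assumes the arrays have already been pulled out of `detections` with `[0].numpy()` as in the script, applies the same +1 label offset, and keeps scores above a threshold in the spirit of `min_score_thresh=.30`. `summarize_detections` and the two-entry `CATEGORY_INDEX` are hypothetical names for illustration, not part of the Object Detection API.

```python
import numpy as np

# Hypothetical stand-in for the category_index built in the script:
# {label_map_id: {'id': label_map_id, 'name': display_name}}
CATEGORY_INDEX = {
    1: {'id': 1, 'name': 'person'},
    18: {'id': 18, 'name': 'dog'},
}

LABEL_ID_OFFSET = 1  # model class indices are 0-based, label map ids are 1-based


def summarize_detections(boxes, classes, scores, category_index,
                         min_score=0.30, label_id_offset=LABEL_ID_OFFSET):
    """Return (name, score, box) tuples for detections scoring above min_score.

    boxes:   (N, 4) array of normalized [ymin, xmin, ymax, xmax]
    classes: (N,) array of 0-based class indices as output by the model
    scores:  (N,) array of confidences
    """
    results = []
    for box, cls, score in zip(boxes, classes, scores):
        if score <= min_score:
            continue  # below the display threshold
        label_id = int(cls) + label_id_offset
        name = category_index.get(label_id, {}).get('name', 'N/A')
        results.append((name, float(score), box.tolist()))
    # Highest-confidence detections first
    results.sort(key=lambda r: r[1], reverse=True)
    return results


boxes = np.array([[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.9, 0.9]])
classes = np.array([0, 17])
scores = np.array([0.95, 0.25])
print(summarize_detections(boxes, classes, scores, CATEGORY_INDEX))
# [('person', 0.95, [0.1, 0.1, 0.5, 0.5])]
```

In the full script the call would look like `summarize_detections(detections['detection_boxes'][0].numpy(), detections['detection_classes'][0].numpy(), detections['detection_scores'][0].numpy(), category_index)`.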
@nuzrub

nuzrub commented Dec 1, 2020 via email

@IvanGarcia7

I'm trying to get the object predictions before the class of each region is inferred. I think your code can get me there, but I'm not sure.

@nuzrub

nuzrub commented Dec 1, 2020 via email

@IvanGarcia7

I need to get the elements, or regions, before class inference: that is, all the regions proposed by the model before it detects the class in each of them. I need access to all the proposed regions so that I can modify them or insert new ones, and then get the class of each one.

@nuzrub

nuzrub commented Dec 1, 2020 via email

@IvanGarcia7

I will check it. I think that could work.

Thanks for helping me, nuzrub ;)

@nuzrub

nuzrub commented Dec 1, 2020 via email

@IvanGarcia7

In this case I am working with CenterNet, but I will keep it in mind in case I need other models.
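CenterNet has no separate region-proposal stage in the Faster R-CNN sense: it predicts a per-class heatmap of object centers and decodes boxes from that heatmap's local maxima. The closest thing to "regions before class inference" is therefore the set of heatmap peaks. The NumPy sketch below shows the usual 3x3 max-pool peak extraction on a generic `(H, W, C)` array. It assumes you have already obtained a sigmoid-activated center heatmap from the model; which entry of `prediction_dict` holds it is not shown in this gist, so check that yourself. `heatmap_peaks` is a hypothetical helper, not an Object Detection API function.

```python
import numpy as np


def heatmap_peaks(heatmap, k=5, min_score=0.1):
    """Return up to k (y, x, class_id, score) peaks from an (H, W, C) heatmap.

    A cell counts as a peak when it equals the maximum of its 3x3
    neighborhood in the same channel, the max-pool trick used by
    CenterNet-style decoders.
    """
    heatmap = np.asarray(heatmap, dtype=float)
    h, w, _ = heatmap.shape
    # Pad with -inf so border cells are compared only against real neighbors
    padded = np.pad(heatmap, ((1, 1), (1, 1), (0, 0)), constant_values=-np.inf)
    # 3x3 max filter assembled from the nine shifted views of the padded array
    neighborhood_max = np.max(
        [padded[dy:dy + h, dx:dx + w, :] for dy in range(3) for dx in range(3)],
        axis=0)
    is_peak = (heatmap == neighborhood_max) & (heatmap >= min_score)
    ys, xs, cs = np.nonzero(is_peak)
    scores = heatmap[ys, xs, cs]
    order = np.argsort(-scores)[:k]  # strongest peaks first
    return [(int(ys[i]), int(xs[i]), int(cs[i]), float(scores[i])) for i in order]


hm = np.zeros((4, 4, 2))
hm[1, 1, 0] = 0.9
hm[1, 2, 0] = 0.5  # suppressed: its neighbor (1, 1) is higher
hm[3, 3, 1] = 0.7
print(heatmap_peaks(hm, k=5))
# [(1, 1, 0, 0.9), (3, 3, 1, 0.7)]
```

Editing, dropping, or adding entries in this list before boxes are decoded is one way to approach the goal described above, though how to feed modified centers back into the Object Detection API's own postprocessing is not covered here.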

@nuzrub

nuzrub commented Dec 1, 2020 via email

@Player1-DON

First, I would like to say thank you for the easiest-to-follow tutorial on Medium/TowardsDataScience.

In line 18, label_map_path = './coco_labelmap.pbtxt' should be '...mscoco_label_map.pbtxt'. The label map file must have been updated.

@nuzrub

nuzrub commented Feb 26, 2021

Thanks for enjoying it. I fixed the label map path, thanks :)
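Since a wrong label map filename was the one thing that needed fixing here, it can be worth confirming that the label map file is present and readable before building the model. The sketch below is a rough regex-based reader, not the Object Detection API's protobuf parser (`label_map_util.load_labelmap`). It assumes flat `item { ... }` entries with `id` and `display_name` fields, as in `mscoco_label_map.pbtxt`, and returns a dict shaped like the `category_index` the script builds. `parse_label_map` is a hypothetical helper name.

```python
import re

# One flat `item { ... }` block; nested braces inside an item are NOT supported
ITEM_RE = re.compile(r'item\s*\{(.*?)\}', re.DOTALL)
ID_RE = re.compile(r'\bid\s*:\s*(\d+)')
NAME_RE = re.compile(r'\bdisplay_name\s*:\s*["\']([^"\']*)["\']')


def parse_label_map(text):
    """Return {id: {'id': id, 'name': display_name}} from label map text."""
    index = {}
    for body in ITEM_RE.findall(text):
        id_match = ID_RE.search(body)
        name_match = NAME_RE.search(body)
        if id_match and name_match:  # skip entries missing either field
            item_id = int(id_match.group(1))
            index[item_id] = {'id': item_id, 'name': name_match.group(1)}
    return index


SAMPLE = """
item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
"""
print(parse_label_map(SAMPLE))
# {1: {'id': 1, 'name': 'person'}}
```

Against the real file, something like `parse_label_map(open('./mscoco_label_map.pbtxt').read())` raises `FileNotFoundError` immediately if the path is wrong, which is quicker to diagnose than a failure partway through the script.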
