Attempting direct inference for a MediaPipe object detection model
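The script below loads the exported TFLite model, preprocesses a test image, invokes the serving_default signature directly, runs combined non-max suppression on the raw outputs, and plots any surviving detections with matplotlib.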
import os
from pprint import pprint

import numpy as np
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
class_map = {
    1: "box",
    2: "green_light",
    3: "left_arrow",
    4: "no_light",
    5: "person",
    6: "red_light",
    7: "right_arrow",
    8: "tree",
    9: "unknown_arrow",
}
# Load the TFLite model and allocate tensors
interpreter = tf.lite.Interpreter(model_path="mediapipe/exported_model/model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
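
# Sanity check: dump the tensor metadata with the already-imported pprint to
# confirm the shapes the rest of the script assumes.
pprint(input_details)
pprint(output_details)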
# Read and preprocess an image
_, input_height, input_width, _ = input_details[0]["shape"]  # 256, 256
img = tf.io.read_file("label_data/test/downloads_images/1709570998863_90_35.png")
img = tf.io.decode_png(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
# Image is now normalised to [0, 1]
original_image = img
resized_img = tf.image.resize(img, (input_height, input_width))
resized_img = resized_img[tf.newaxis, :]
resized_img = tf.cast(resized_img, dtype=tf.float32)
# Shape: (1, 256, 256, 3)
# Range: [0, 1]

signature_fn = interpreter.get_signature_runner("serving_default")
output = signature_fn(inputs=resized_img)
# detection_boxes has shape (1, num_boxes, 4); add a dimension to the end to
# make it (1, num_boxes, 1, 4)
output["detection_boxes"] = output["detection_boxes"][:, :, tf.newaxis, :]
""" {'detection_boxes': array([[[ 0.5667457 , -0.21096665, -1.5467784 , -1.8802471 ], | |
[-0.0475191 , 0.05186775, -2.129962 , -1.7360085 ], | |
[-0.18870918, 0.20574829, -2.4253688 , -1.2779199 ], | |
...]], | |
dtype=float32), | |
'detection_scores': array([[[0.00132673, 0.00542115, 0.00166283, ..., 0.00414128, | |
0.00341194, 0.00249445], | |
[0.00151299, 0.00264928, 0.00144521, ..., 0.00185226, | |
0.00273689, 0.00230351], | |
[0.00126078, 0.00119116, 0.00167455, ..., 0.00088799, | |
0.00250108, 0.00245593], | |
...]], dtype=float32)} """ | |
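# NOTE (unconfirmed observation): the raw detection_boxes above are not
# confined to [0, 1], which suggests they may be anchor-relative encodings
# that still need decoding, rather than normalised corner coordinates.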
# Run non-max suppression on the output
boxes = output["detection_boxes"]
scores = output["detection_scores"]
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
    tf.image.combined_non_max_suppression(
        boxes,
        scores,
        max_output_size_per_class=5,
        max_total_size=5,
        iou_threshold=0.5,
        score_threshold=0.5,
    )
)
detections = {
    "detection_boxes": nmsed_boxes,
    "detection_scores": nmsed_scores,
    "detection_classes": nmsed_classes,
    "num_detections": valid_detections,
}
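
# With score_threshold=0.5 and raw scores around 0.002 (see the dump above),
# NMS will likely keep nothing; check the count before trying to draw.
print("valid detections:", int(valid_detections[0]))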
def visualise(img, detections):
    plt.imshow(img)
    img_width, img_height = img.shape[1], img.shape[0]
    for i in range(int(detections["num_detections"][0])):
        bbox = detections["detection_boxes"][0][i].numpy()
        score = detections["detection_scores"][0][i].numpy()
        # nmsed_classes are 0-based float indices into the score columns;
        # class_map keys start at 1, so an offset may be needed depending on
        # whether the model reserves index 0 for a background class. The
        # fallback label below is only for display.
        class_id = int(detections["detection_classes"][0][i].numpy())
        class_name = class_map.get(class_id, f"id_{class_id}")
        print(img_width, img_height)
        print(bbox)
        # Boxes are (ymin, xmin, ymax, xmax), assumed normalised to [0, 1]
        y1, x1, y2, x2 = bbox
        origin_x, origin_y = x1 * img_width, y1 * img_height
        width, height = (x2 - x1) * img_width, (y2 - y1) * img_height
        print(origin_x, origin_y, width, height)
        rect = plt.Rectangle(
            (origin_x, origin_y),
            width,
            height,
            fill=False,
            edgecolor="red",
            linewidth=2,
        )
        plt.gca().add_patch(rect)
        plt.text(origin_x, origin_y, f"{class_name} {score:.2f}", color="red")
visualise(original_image, detections)
plt.show()