Attempting direct inference with a MediaPipe object detection model
from pprint import pprint

import matplotlib.pyplot as plt
import tensorflow as tf
class_map = {
    1: "box",
    2: "green_light",
    3: "left_arrow",
    4: "no_light",
    5: "person",
    6: "red_light",
    7: "right_arrow",
    8: "tree",
    9: "unknown_arrow",
}
# Load the TFLite model and allocate tensors
interpreter = tf.lite.Interpreter(model_path="mediapipe/exported_model/model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
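# Optional sanity check: dump the tensor metadata so the expected input
# shape/dtype and the output tensor names can be confirmed by eye.
pprint(input_details)
pprint(output_details)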
# Read and preprocess an image
_, input_height, input_width, _ = input_details[0]["shape"]  # 256, 256
img = tf.io.read_file("label_data/test/downloads_images/1709570998863_90_35.png")
img = tf.io.decode_png(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
# Image is now normalised [0, 1]
original_image = img
resized_img = tf.image.resize(img, (input_height, input_width))
resized_img = resized_img[tf.newaxis, :]
resized_img = tf.cast(resized_img, dtype=tf.float32)
# Shape: (1, 256, 256, 3)
# Range: [0, 1]
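# Guard against a silent shape mismatch before invoking the model (the input
# size is assumed to be a single 256x256 RGB image, per the comment above).
assert resized_img.shape == (1, input_height, input_width, 3)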
signature_fn = interpreter.get_signature_runner("serving_default")
output = signature_fn(inputs=resized_img)
# combined_non_max_suppression expects boxes of shape (batch, num_boxes, q, 4);
# q=1 means one class-agnostic box per anchor, so insert a size-1 axis:
# (1, num_boxes, 4) -> (1, num_boxes, 1, 4).
output["detection_boxes"] = output["detection_boxes"][:, :, tf.newaxis, :]
""" {'detection_boxes': array([[[ 0.5667457 , -0.21096665, -1.5467784 , -1.8802471 ],
[-0.0475191 , 0.05186775, -2.129962 , -1.7360085 ],
[-0.18870918, 0.20574829, -2.4253688 , -1.2779199 ],
...]],
dtype=float32),
'detection_scores': array([[[0.00132673, 0.00542115, 0.00166283, ..., 0.00414128,
0.00341194, 0.00249445],
[0.00151299, 0.00264928, 0.00144521, ..., 0.00185226,
0.00273689, 0.00230351],
[0.00126078, 0.00119116, 0.00167455, ..., 0.00088799,
0.00250108, 0.00245593],
...]], dtype=float32)} """
# Run non-max suppression on the output
boxes = output["detection_boxes"]
scores = output["detection_scores"]
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
    tf.image.combined_non_max_suppression(
        boxes,
        scores,
        max_output_size_per_class=5,
        max_total_size=5,
        iou_threshold=0.5,
        score_threshold=0.5,
    )
)
detections = {
    "detection_boxes": nmsed_boxes,
    "detection_scores": nmsed_scores,
    "detection_classes": nmsed_classes,
    "num_detections": valid_detections,
}
def visualise(img, detections):
    """Draw the NMS-filtered boxes and class labels on top of the image."""
    plt.imshow(img)
    img_width, img_height = img.shape[1], img.shape[0]
    for i in range(int(detections["num_detections"][0])):
        bbox = detections["detection_boxes"][0][i].numpy()
        score = detections["detection_scores"][0][i].numpy()
        class_id = int(detections["detection_classes"][0][i].numpy())
        class_name = class_map[class_id]
        print(img_width, img_height)
        print(bbox)
        # Boxes are (ymin, xmin, ymax, xmax) in normalised [0, 1] coordinates;
        # scale them back to pixel space for matplotlib.
        y1, x1, y2, x2 = bbox
        origin_x, origin_y = x1 * img_width, y1 * img_height
        width, height = (x2 - x1) * img_width, (y2 - y1) * img_height
        print(origin_x, origin_y, width, height)
        rect = plt.Rectangle(
            (origin_x, origin_y),
            width,
            height,
            fill=False,
            edgecolor="red",
            linewidth=2,
        )
        plt.gca().add_patch(rect)
        plt.text(origin_x, origin_y, f"{class_name} {score:.2f}", color="red")
visualise(original_image, detections)
plt.show()  # needed to actually display the figure when run as a script