Skip to content

Instantly share code, notes, and snippets.

@runninbear5
Created April 10, 2020 19:31
Show Gist options
  • Save runninbear5/61454ed2b0b8acd20bf81c94e9218e93 to your computer and use it in GitHub Desktop.
Save runninbear5/61454ed2b0b8acd20bf81c94e9218e93 to your computer and use it in GitHub Desktop.
Inference code for running a TensorFlow Lite model on a Google Coral Edge TPU using a Raspberry Pi camera.
from edgetpu.detection.engine import DetectionEngine
from PIL import Image
from picamera.array import PiRGBArray
from picamera import PiCamera
import argparse
import time
import cv2
# Command-line interface: the path to the detection model is mandatory; the
# score threshold below which detections are discarded is optional.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-m",
    "--model",
    required=True,
    help="path to TensorFlow Lite object detection model",
)
parser.add_argument(
    "-c",
    "--confidence",
    type=float,
    default=0.3,
    help="minimum probability to filter weak detections",
)
# Expose the parsed options as a plain dict (args["model"], args["confidence"]).
args = vars(parser.parse_args())
# Configure the Pi camera and the reusable array buffer its frames are
# captured into. Both share one size constant so they cannot drift apart.
FRAME_SIZE = (320, 240)  # (width, height) in pixels
print("[INFO] starting Pi camera...")
camera = PiCamera()
camera.resolution = FRAME_SIZE
camera.framerate = 30
rawCapture = PiRGBArray(camera, size=FRAME_SIZE)
# Create the display window up front. cv2.namedWindow returns None, so its
# result is not kept.
cv2.namedWindow("detection")
# Load the detection model from the --model path, then initialise the loop
# state: a frame counter and reference timestamp for the FPS readout, plus the
# flag that the "c" key toggles to switch detection on and off.
print("[INFO] loading Coral model...")
model = DetectionEngine(args["model"])
count, start, detect = 0, time.time(), True
def _label_y(box_top, offset=15):
    """Return the text baseline y for a label belonging to a box whose top edge is *box_top*.

    The label normally sits *offset* pixels above the box. If that would leave
    it within *offset* pixels of the top of the frame (where it would be
    clipped), it is placed *offset* pixels below the box's top edge instead.
    """
    above = box_top - offset
    return above if above > offset else box_top + offset


def _draw_detections(canvas, detections):
    """Draw each detection's bounding box and class id onto *canvas* in place.

    Each detection's ``bounding_box`` is flattened and unpacked as
    (startX, startY, endX, endY); the loop requests them with
    relative_coord=False, so they are used directly as pixel positions.
    """
    for det in detections:
        start_x, start_y, end_x, end_y = det.bounding_box.flatten().astype("int")
        cv2.rectangle(canvas, (start_x, start_y), (end_x, end_y), (0, 255, 0), 2)
        # Previously the clamped label position was computed but ignored, so
        # labels for boxes near the top edge were drawn off-frame.
        cv2.putText(
            canvas,
            str(det.label_id),
            (start_x, _label_y(start_y)),
            cv2.FONT_HERSHEY_PLAIN,
            1,
            (0, 255, 0),
            2,
        )


# Capture frames continuously, optionally run detection, display the result and
# report FPS. "q" quits; "c" toggles detection. The camera and window are
# released even if the loop exits via an exception or Ctrl-C.
try:
    for frame in camera.capture_continuous(
        rawCapture, format="bgr", use_video_port=True
    ):
        # Frames arrive in BGR order, which OpenCV draws on and displays as-is.
        orig = frame.array
        if detect:
            # The detector is fed a PIL image in RGB order; convert only when
            # detection is enabled, so toggling it off saves the work.
            image = Image.fromarray(cv2.cvtColor(orig, cv2.COLOR_BGR2RGB))
            results = model.detect_with_image(
                image,
                threshold=args["confidence"],
                keep_aspect_ratio=True,
                relative_coord=False,
            )
            _draw_detections(orig, results)

        # Show the annotated frame and poll the keyboard for 1 ms.
        cv2.imshow("detection", orig)
        key = cv2.waitKey(1) & 0xFF

        # Print throughput roughly once per second, then restart the window.
        count += 1
        elapsed = time.time() - start
        if elapsed > 1:
            print("FPS: ", count / elapsed)
            count = 0
            start = time.time()

        # Empty the capture buffer so the next frame can be written into it.
        rawCapture.truncate(0)

        if key == ord("q"):
            break
        if key == ord("c"):
            detect = not detect
finally:
    cv2.destroyAllWindows()
    camera.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment