Skip to content

Instantly share code, notes, and snippets.

@tawnkramer
Created March 8, 2019 06:07
Show Gist options
  • Save tawnkramer/c0a6f96be645cd000215ecddef17e217 to your computer and use it in GitHub Desktop.
Save tawnkramer/c0a6f96be645cd000215ecddef17e217 to your computer and use it in GitHub Desktop.
"""Demo: classify frames from an OpenCV camera stream on a Google Coral Edge TPU."""
import argparse
import io
import time
import numpy as np
import cv2
import edgetpu.classification.engine
def _parse_args():
    """Parse and return the command-line arguments (--model and --label)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', help='File path of Tflite model.', required=True)
    parser.add_argument(
        '--label', help='File path of label file.', required=True)
    return parser.parse_args()


def _load_labels(path):
    """Read a label file of ``<int id> <name>`` lines into an {id: name} dict.

    Blank lines are skipped. Any other line must hold an integer id followed
    by whitespace and a name; otherwise ValueError is raised.
    """
    labels = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines in the label file
            label_id, name = line.split(maxsplit=1)
            labels[int(label_id)] = name
    return labels


def _classify(engine, frame, width, height, channels):
    """Run top-1 classification on a BGR frame; return (results, seconds)."""
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (width, height))  # cv2.resize takes (width, height)
    # Flatten to 1-D; the explicit size also validates the channel count.
    tensor = rgb.reshape(width * height * channels)
    start = time.perf_counter()  # monotonic, high-resolution timer
    results = engine.ClassifyWithInputTensor(tensor, top_k=1)
    return results, time.perf_counter() - start


def main():
    """Classify frames from camera 0 on the Edge TPU and display them.

    The top-1 label, its score and the inference time are drawn over the
    video. A new prediction replaces the displayed one only when its score
    exceeds ``min_confidence`` and the previous text has been shown for
    more than ``time_to_show_prediction`` seconds. Press ``q`` in the video
    window (or Ctrl-C) to quit.

    Raises:
        SystemExit: if the camera cannot be opened or stops delivering frames.
    """
    args = _parse_args()
    labels = _load_labels(args.label)
    engine = edgetpu.classification.engine.ClassificationEngine(args.model)
    # The Edge TPU API documents this shape as [1, height, width, channels];
    # unpacking it as (width, height) would distort non-square model inputs.
    _, height, width, channels = engine.get_input_tensor_shape()

    font = cv2.FONT_HERSHEY_SIMPLEX
    bottom_left_corner_of_text = (10, 470)  # near the bottom of a 640x480 frame
    font_scale = 0.6
    font_color = (255, 255, 255)
    line_type = 2
    time_to_show_prediction = 3.0  # seconds a prediction stays on screen
    min_confidence = 0.2
    max_failed_reads = 30  # consecutive failed reads before giving up

    annotate_text = ''
    # -inf so the first confident prediction is shown immediately.
    annotate_text_time = float('-inf')
    failed_reads = 0

    camera = cv2.VideoCapture(0)
    try:
        # A VideoCapture object is always truthy; isOpened() is the real check.
        if not camera.isOpened():
            raise SystemExit('Could not open camera 0.')
        camera.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        camera.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        while True:
            ret, frame = camera.read()
            if not ret:
                # Tolerate a few transient failures, but don't spin forever.
                failed_reads += 1
                if failed_reads >= max_failed_reads:
                    raise SystemExit('Camera stopped delivering frames.')
                continue
            failed_reads = 0

            results, elapsed = _classify(engine, frame, width, height, channels)
            now = time.monotonic()
            if (results
                    and results[0][1] > min_confidence
                    and now - annotate_text_time > time_to_show_prediction):
                top = results[0]
                annotate_text = "%s %.2f %.2fms" % (
                    labels.get(top[0], str(top[0])), top[1], elapsed * 1000.0)
                annotate_text_time = now
            cv2.putText(frame, annotate_text,
                        bottom_left_corner_of_text,
                        font,
                        font_scale,
                        font_color,
                        line_type)
            cv2.imshow('frame', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        camera.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment