Skip to content

Instantly share code, notes, and snippets.

@kazuph
Created March 3, 2020 15:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kazuph/da7e92b7314ee8686a9a8f697138828d to your computer and use it in GitHub Desktop.
Save kazuph/da7e92b7314ee8686a9a8f697138828d to your computer and use it in GitHub Desktop.
# Lint as: python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example using TF Lite to detect objects in a given image."""
import argparse
import time
from PIL import Image
from PIL import ImageDraw
import detect
import tflite_runtime.interpreter as tflite
import platform
# File name of the Edge TPU runtime shared library, keyed by the value that
# platform.system() reports for the host operating system.
_EDGETPU_LIB_BY_SYSTEM = {
    'Linux': 'libedgetpu.so.1',
    'Darwin': 'libedgetpu.1.dylib',
    'Windows': 'edgetpu.dll',
}

# Resolved once at import time; a system name that is not in the table above
# raises KeyError here.
EDGETPU_SHARED_LIB = _EDGETPU_LIB_BY_SYSTEM[platform.system()]
import cv2
import numpy as np
# Capture handle for camera device index 0. It is created when this module is
# loaded and is read from and released by main().
cap = cv2.VideoCapture(0)
def load_labels(path, encoding='utf-8'):
    """Loads labels from file (with or without index numbers).

    The layout is chosen by inspecting the first line only:

    * indexed: the text before the first space is all digits, and every line
      is read as ``<integer> <label>``;
    * plain: every line is a label and its zero-based line number is the index.

    Blank lines in an indexed file are skipped. In a plain file they are kept
    as empty labels so the labels after them keep their line-number index.

    Args:
      path: path to label file.
      encoding: label file encoding.

    Returns:
      Dictionary mapping integer indices to labels. Empty if the file is empty.

    Raises:
      ValueError: if a non-blank line of an indexed file has no space-separated
        label or its index is not an integer.
    """
    with open(path, 'r', encoding=encoding) as f:
        lines = f.readlines()

    if not lines:
        return {}

    if not lines[0].split(' ', maxsplit=1)[0].isdigit():
        return {index: line.strip() for index, line in enumerate(lines)}

    labels = {}
    for line in lines:
        # A stray blank line (commonly at the end of the file) would otherwise
        # fail the two-value unpack below.
        if not line.strip():
            continue
        index, label = line.split(' ', maxsplit=1)
        labels[int(index)] = label.strip()
    return labels
def make_interpreter(model_file):
    """Builds a TF Lite interpreter that uses the Edge TPU delegate.

    Args:
      model_file: model path, optionally followed by ``@<device>``. When a
        device part is present it is passed to the delegate as its ``device``
        option; otherwise the delegate is loaded with no options.

    Returns:
      A ``tflite.Interpreter`` for the model with the Edge TPU delegate attached.
    """
    parts = model_file.split('@')
    model_path = parts[0]
    delegate_options = {'device': parts[1]} if len(parts) > 1 else {}
    edgetpu_delegate = tflite.load_delegate(EDGETPU_SHARED_LIB, delegate_options)
    return tflite.Interpreter(
        model_path=model_path,
        experimental_delegates=[edgetpu_delegate])
def draw_objects(draw, objs, labels):
    """Draws the bounding box and label for each object.

    Args:
      draw: drawing target providing ``rectangle`` and ``text`` methods.
      objs: iterable of detections, each with ``bbox`` (having ``xmin``,
        ``ymin``, ``xmax``, ``ymax``), ``id`` and ``score`` attributes.
      labels: dictionary mapping an object id to its label; an id with no
        entry is drawn as the id itself.
    """
    for detection in objs:
        box = detection.bbox
        top_left = (box.xmin, box.ymin)
        bottom_right = (box.xmax, box.ymax)
        draw.rectangle([top_left, bottom_right], outline='red')

        # Caption sits 10 units inside the top-left corner of the box:
        # label on the first line, score with two decimals on the second.
        name = labels.get(detection.id, detection.id)
        caption = '%s\n%.2f' % (name, detection.score)
        draw.text((box.xmin + 10, box.ymin + 10), caption, fill='red')
def cv2pil(image):
    """Converts an OpenCV image array to a PIL image.

    The input is copied first. A 2-D (monochrome) array is used as is, a
    3-channel array is converted BGR -> RGB, and a 4-channel (with alpha)
    array is converted BGRA -> RGBA. Any other channel count is passed
    through without colour conversion.

    Args:
      image: image array as used by OpenCV.

    Returns:
      The PIL image built from the (possibly colour-converted) copy.
    """
    pixels = image.copy()
    if pixels.ndim != 2:
        channel_count = pixels.shape[2]
        if channel_count == 3:  # colour
            pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
        elif channel_count == 4:  # colour with alpha
            pixels = cv2.cvtColor(pixels, cv2.COLOR_BGRA2RGBA)
    return Image.fromarray(pixels)
def pil2cv(image):
    """Converts a PIL image to an OpenCV image array.

    The image is first turned into a ``uint8`` NumPy array. A 2-D (monochrome)
    array is returned as is, a 3-channel array is converted RGB -> BGR, and a
    4-channel (with alpha) array is converted RGBA -> BGRA. Any other channel
    count is returned without colour conversion.

    Args:
      image: PIL image (or anything ``np.array`` accepts).

    Returns:
      The ``uint8`` array in OpenCV channel order.
    """
    pixels = np.array(image, dtype=np.uint8)
    if pixels.ndim == 2:  # monochrome
        return pixels
    channel_count = pixels.shape[2]
    if channel_count == 3:  # colour
        return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
    if channel_count == 4:  # colour with alpha
        return cv2.cvtColor(pixels, cv2.COLOR_RGBA2BGRA)
    return pixels
def main():
    """Runs object detection on camera frames until Esc is pressed.

    Loads the COCO labels and the Edge TPU model from the ``models``
    directory, then loops: read a frame from the module-level ``cap``, run
    inference, print the inference time and the detected labels with their
    scores, and show the annotated frame in a window named ``Frame``.

    The loop ends when Esc is pressed or when a frame cannot be read. The
    capture handle and the OpenCV windows are released on every exit path,
    including exceptions.
    """
    try:
        labels = load_labels("models/coco_labels.txt")
        interpreter = make_interpreter(
            "models/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite")
        interpreter.allocate_tensors()

        print('----INFERENCE TIME----')
        print('Note: The first inference is slow because it includes',
              'loading the model into Edge TPU memory.')

        while True:
            ret, frame = cap.read()
            if not ret:
                # No frame was delivered (camera missing, unplugged or stream
                # ended); stop instead of crashing on a None frame.
                print('Failed to read a frame from the camera')
                break

            image = cv2pil(frame)
            # Image.LANCZOS is the same filter formerly exposed as
            # Image.ANTIALIAS, which newer Pillow releases no longer provide.
            scale = detect.set_input(
                interpreter, image.size,
                lambda size: image.resize(size, Image.LANCZOS))

            # Time only the inference call itself.
            start = time.perf_counter()
            interpreter.invoke()
            inference_time = time.perf_counter() - start

            # NOTE(review): 0.4 is presumably the minimum score threshold;
            # confirm against detect.get_output.
            objs = detect.get_output(interpreter, 0.4, scale)
            print('%.2f ms' % (inference_time * 1000))

            print('-------RESULTS--------')
            if not objs:
                print('No objects detected')
            for obj in objs:
                print(labels.get(obj.id, obj.id))
                print('score: ', obj.score)

            # Convert to RGB so the red boxes and captions can be drawn
            # whatever mode the frame arrived in.
            image = image.convert('RGB')
            draw_objects(ImageDraw.Draw(image), objs, labels)
            cv2.imshow('Frame', pil2cv(image))

            # 27 is the key code of Esc.
            if cv2.waitKey(1) == 27:
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
# Start the live detection loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment