Skip to content

Instantly share code, notes, and snippets.

@minakhan01
Created June 21, 2018 17:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save minakhan01/2bda0dd0ff62e50ff39cf998395cf15a to your computer and use it in GitHub Desktop.
AIY mobile net
#!/usr/bin/env python3
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to run generic MobileNet based classification model."""
import argparse
import time
import io
import sys
from PIL import Image
from PIL import ImageDraw
from picamera import Color
from picamera import PiCamera
from aiy.vision import inference
from aiy.vision.models import utils
def read_labels(label_path):
    """Return the class labels from *label_path*, one label per line.

    Surrounding whitespace (including the trailing newline) is stripped
    from every line.
    """
    with open(label_path) as label_file:
        return [line.strip() for line in label_file]
def get_message(processed_result, threshold, top_k):
    """Build the human-readable status line for one inference result.

    Args:
      processed_result: list of pre-formatted label strings (may be empty).
      threshold: classification score threshold, echoed when nothing matched.
      top_k: maximum label count, echoed when nothing matched.

    Returns:
      A display string: the joined labels, or a "nothing detected" notice.
    """
    if not processed_result:
        return 'Nothing detected when threshold=%.2f, top_k=%d' % (
            threshold, top_k)
    return 'Detecting:\n %s' % ('\n'.join(processed_result))
def process(result, labels, out_tensor_name, threshold, top_k):
    """Turn a raw inference result into formatted labels, best first.

    Args:
      result: inference result holding exactly one output tensor.
      labels: class labels; must match the tensor's depth.
      out_tensor_name: key of the output tensor inside ``result.tensors``.
      threshold: scores at or below this value are dropped.
      top_k: at most this many labels are returned.

    Returns:
      List of ' label (score)' strings sorted by descending score.
    """
    # MobileNet based classification model returns one result vector.
    assert len(result.tensors) == 1
    tensor = result.tensors[out_tensor_name]
    assert tensor.shape.depth == len(labels)
    candidates = [(i, score) for i, score in enumerate(tensor.data)
                  if score > threshold]
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    return [' %s (%.2f)' % (labels[i], score)
            for i, score in candidates[:top_k]]
def main():
    """Run a MobileNet-based classifier on one image or a live camera feed.

    With ``--image_inference`` the image named by ``--input_image`` (or stdin
    when it is ``-``) is classified once and optionally re-saved to
    ``--output_image``.  Otherwise frames are pulled from the PiCamera and
    classified continuously, printing (and optionally annotating) the result.

    Fixes relative to the original: removed an unused ``ImageDraw.Draw``
    object, an unused ``last_time`` in the single-image branch, and a block
    of dead commented-out code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        required=True,
        help='Path to converted model file that can run on VisionKit.')
    parser.add_argument(
        '--label_path',
        required=True,
        help='Path to label file that corresponds to the model.')
    parser.add_argument(
        '--input_height', type=int, required=True, help='Input height.')
    parser.add_argument(
        '--input_width', type=int, required=True, help='Input width.')
    parser.add_argument(
        '--input_layer', required=True, help='Name of input layer.')
    parser.add_argument(
        '--output_layer', required=True, help='Name of output layer.')
    parser.add_argument(
        '--image_inference', action='store_true', default=False)
    parser.add_argument(
        '--input_image', '-i', dest='input')
    parser.add_argument(
        '--output_image', '-o', dest='output')
    parser.add_argument(
        '--num_frames',
        type=int,
        default=-1,
        help='Sets the number of frames to run for, otherwise runs forever.')
    parser.add_argument(
        '--input_mean', type=float, default=128.0, help='Input mean.')
    parser.add_argument(
        '--input_std', type=float, default=128.0, help='Input std.')
    parser.add_argument(
        '--input_depth', type=int, default=3, help='Input depth.')
    parser.add_argument(
        '--threshold', type=float, default=0.1,
        help='Threshold for classification score (from output tensor).')
    parser.add_argument(
        '--top_k', type=int, default=3, help='Keep at most top_k labels.')
    parser.add_argument(
        '--preview',
        action='store_true',
        default=False,
        help='Enables camera preview in addition to printing result to terminal.')
    parser.add_argument(
        '--show_fps',
        action='store_true',
        default=False,
        help='Shows end to end FPS.')
    args = parser.parse_args()

    model = inference.ModelDescriptor(
        name='mobilenet_based_classifier',
        input_shape=(1, args.input_height, args.input_width, args.input_depth),
        input_normalizer=(args.input_mean, args.input_std),
        compute_graph=utils.load_compute_graph(args.model_path))
    labels = read_labels(args.label_path)

    if args.image_inference:
        with inference.ImageInference(model) as image_inference:
            # '-' means: read the raw image bytes from stdin.
            image = Image.open(
                io.BytesIO(sys.stdin.buffer.read())
                if args.input == '-' else args.input)
            result = image_inference.run(image)
            processed_result = process(result, labels, args.output_layer,
                                       args.threshold, args.top_k)
            print(get_message(processed_result, args.threshold, args.top_k))
            if args.output:
                image.save(args.output)
    else:
        with PiCamera(sensor_mode=4, resolution=(1640, 1232), framerate=30) as camera:
            if args.preview:
                camera.start_preview()
            with inference.CameraInference(model) as camera_inference:
                last_time = time.time()
                for i, result in enumerate(camera_inference.run()):
                    # num_frames < 0 never equals i, so the loop runs forever.
                    if i == args.num_frames:
                        break
                    processed_result = process(result, labels, args.output_layer,
                                               args.threshold, args.top_k)
                    # End-to-end FPS across the whole loop body, not just inference.
                    cur_time = time.time()
                    fps = 1.0 / (cur_time - last_time)
                    last_time = cur_time
                    message = get_message(processed_result, args.threshold, args.top_k)
                    if args.show_fps:
                        message += '\nWith %.1f FPS.' % fps
                    print(message)
                    if args.preview:
                        camera.annotate_foreground = Color('black')
                        camera.annotate_background = Color('white')
                        # PiCamera text annotation only supports ascii.
                        camera.annotate_text = '\n %s' % message.encode(
                            'ascii', 'backslashreplace').decode('ascii')
            if args.preview:
                camera.stop_preview()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment