Skip to content

Instantly share code, notes, and snippets.

@nknytk
Last active June 19, 2021 16:21
Show Gist options
  • Save nknytk/90930f8e53a322f0a4f86792893cd5da to your computer and use it in GitHub Desktop.
Save nknytk/90930f8e53a322f0a4f86792893cd5da to your computer and use it in GitHub Desktop.
Raspberry Pi 3BでのONNX Runtime動作速度を32bit版と64bit版で比較する
import os
import model
def main():
    """Benchmark each model wrapper from ``model`` in a fixed order.

    Every wrapper is instantiated just before it is measured, and the
    per-model averages are printed by :func:`check_speed`.
    """
    model_classes = (
        model.MobileNetV2,
        model.ResNet50,
        model.EfficientNetLite4,
        model.SSDMobileNetV1,
        model.TinyYOLOV3,
        model.SimpleConv_6,
        model.UltraFace320,
    )
    for model_class in model_classes:
        check_speed(model_class())
def check_speed(model_instance, image_dir='images'):
    """Run ``model_instance`` over every file in ``image_dir`` and print mean stage timings.

    ``model_instance.infer(path)`` must return a dict containing the keys
    ``'imread_msec'``, ``'preprocess_msec'`` and ``'inference_msec'``.
    Regular files are visited in sorted order so runs are reproducible;
    subdirectories are skipped.

    Args:
        model_instance: Object exposing ``infer(path) -> dict`` as described above.
        image_dir: Directory holding the test images (default ``'images'``).

    Returns:
        dict of the mean milliseconds per stage (same keys as above), or None
        when ``image_dir`` contains no files.
    """
    model_name = model_instance.__class__.__name__
    image_paths = [
        os.path.join(image_dir, name)
        for name in sorted(os.listdir(image_dir))
        if os.path.isfile(os.path.join(image_dir, name))
    ]
    # Without this guard an empty directory ends in ZeroDivisionError below.
    if not image_paths:
        print('model:{} no images found in {}'.format(model_name, image_dir))
        return None

    imread_times = []
    preprocess_times = []
    inference_times = []
    for path in image_paths:
        result = model_instance.infer(path)
        imread_times.append(result['imread_msec'])
        preprocess_times.append(result['preprocess_msec'])
        inference_times.append(result['inference_msec'])

    data_length = len(image_paths)
    means = {
        'imread_msec': sum(imread_times) / data_length,
        'preprocess_msec': sum(preprocess_times) / data_length,
        'inference_msec': sum(inference_times) / data_length,
    }
    msg = 'model:{} imread:{:.02f}ms preprocessing:{:.02f}ms inference:{:.02f}ms'.format(
        model_name,
        means['imread_msec'],
        means['preprocess_msec'],
        means['inference_msec'],
    )
    print(msg)
    return means
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
import os
from time import time
import cv2
import numpy
import onnxruntime
# Raise ONNX Runtime's default log threshold so its non-error messages do not
# clutter the benchmark output. 3 is presumably ERROR on ORT's 0 (verbose) to
# 4 (fatal) scale -- confirm against the ONNX Runtime logging docs.
_ORT_LOG_SEVERITY = 3
onnxruntime.set_default_logger_severity(_ORT_LOG_SEVERITY)
class ImageModelBase:
    """Base wrapper around an ONNX Runtime session for single-image benchmarks.

    Subclasses set ``onnx_file_name`` (a file in the ``models`` directory next to
    this module) and override :meth:`preprocess` to build the model's input tensor.
    """

    onnx_file_name = ''
    # Per-channel mean and standard deviation (RGB order) that preprocess_imagenet
    # applies to pixels already scaled to [0, 1].
    imagenet_mean_vec = numpy.array([0.485, 0.456, 0.406])
    imagenet_stddev_vec = numpy.array([0.229, 0.224, 0.225])

    def __init__(self, input_name: str = None, output_name: str = None):
        """Open ``models/<onnx_file_name>`` in an ONNX Runtime inference session.

        Args:
            input_name: Graph input to feed; defaults to the session's first input.
            output_name: Graph output to fetch; defaults to the session's first output.
        """
        onnx_file_path = os.path.abspath(
            os.path.join(os.path.dirname(__file__), 'models', self.onnx_file_name))
        self.sess = onnxruntime.InferenceSession(onnx_file_path)
        self.input_name = self.sess.get_inputs()[0].name if input_name is None else input_name
        self.output_name = self.sess.get_outputs()[0].name if output_name is None else output_name

    def infer(self, img_path: str) -> dict:
        """Read, preprocess and run one image, timing each stage separately.

        Returns:
            dict with ``'result'`` (the first fetched output) and the wall-clock
            durations in milliseconds under ``'imread_msec'``,
            ``'preprocess_msec'`` and ``'inference_msec'``.

        Raises:
            ValueError: if OpenCV cannot read ``img_path``.
        """
        t0 = time()
        img = cv2.imread(img_path)
        t1 = time()
        # cv2.imread reports failure by returning None instead of raising; fail
        # here with a clear message rather than deep inside preprocess().
        if img is None:
            raise ValueError('could not read image: {}'.format(img_path))
        data_input = self.preprocess(img)
        t2 = time()
        result = self.sess.run([self.output_name], {self.input_name: data_input})[0]
        t3 = time()
        return {
            'result': result,
            'imread_msec': (t1 - t0) * 1000,
            'preprocess_msec': (t2 - t1) * 1000,
            'inference_msec': (t3 - t2) * 1000,
        }

    def preprocess(self, img: numpy.ndarray) -> numpy.ndarray:
        """Default preprocessing: feed the decoded image unchanged as a batch of one."""
        return img.reshape(1, *img.shape)

    def preprocess_imagenet(self, img_data: numpy.ndarray, target_size: tuple = None,
                            resize_mode: str = 'crop', channel_first: bool = True) -> numpy.ndarray:
        """Optionally resize, convert BGR to RGB and apply ImageNet normalisation.

        Args:
            img_data: HxWx3 BGR image, as produced by cv2.imread in :meth:`infer`.
            target_size: ``(width, height)`` passed to :meth:`resize`; None keeps the size.
            resize_mode: Mode passed to :meth:`resize`.
            channel_first: Return CHW instead of HWC.

        Returns:
            float32 array, CHW if ``channel_first`` else HWC, without a batch axis.
        """
        if target_size is not None:
            img_data = self.resize(img_data, target_size[0], target_size[1], resize_mode)
        img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)
        norm_img_data = self._normalize_imagenet(img_data)
        if channel_first:
            norm_img_data = norm_img_data.transpose(2, 0, 1)
        return norm_img_data

    def _normalize_imagenet(self, rgb_img: numpy.ndarray) -> numpy.ndarray:
        """Scale RGB pixels to [0, 1], then standardise each channel with the ImageNet statistics."""
        # Broadcasting over the trailing channel axis replaces a per-channel Python
        # loop; arithmetic is done in float64 and rounded to float32 once at the end.
        return ((rgb_img / 255 - self.imagenet_mean_vec) / self.imagenet_stddev_vec).astype(numpy.float32)

    def resize(self, img_data: numpy.ndarray, target_width: int, target_height: int, mode: str) -> numpy.ndarray:
        """Resize an HxWxC image to exactly ``target_width`` x ``target_height``.

        Modes:
            'crop': keep the aspect ratio, scale to cover the target and centre-crop the overflow.
            'letterbox': keep the aspect ratio, scale to fit inside the target and centre it
                on a uint8 canvas filled with 127.
            anything else: stretch to the target size, ignoring the aspect ratio.
        """
        h, w = img_data.shape[:2]
        if mode == 'crop':
            resize_width, resize_height = self._cover_size(w, h, target_width, target_height)
            resized_img = cv2.resize(img_data, (resize_width, resize_height))
            # One of the two offsets is always 0: only the overflowing side is cropped.
            top = (resize_height - target_height) // 2
            left = (resize_width - target_width) // 2
            return resized_img[top:top + target_height, left:left + target_width]
        if mode == 'letterbox':
            resize_width, resize_height = self._fit_size(w, h, target_width, target_height)
            resized_img = cv2.resize(img_data, (resize_width, resize_height))
            canvas = numpy.full((target_height, target_width, img_data.shape[2]), 127, dtype=numpy.uint8)
            top = (target_height - resize_height) // 2
            left = (target_width - resize_width) // 2
            canvas[top:top + resize_height, left:left + resize_width] = resized_img
            return canvas
        return cv2.resize(img_data, (target_width, target_height))

    @staticmethod
    def _cover_size(width: int, height: int, target_width: int, target_height: int) -> tuple:
        """Return the aspect-preserving ``(width, height)`` that just covers the target box."""
        aspect_ratio = width / height
        if target_width / target_height > aspect_ratio:
            # Target is wider than the image: match the width, the height overflows.
            # max() guards against float truncation leaving one row too few to crop.
            return target_width, max(target_height, int(target_width / aspect_ratio))
        # Target is taller (or equal): match the height, the width overflows.
        return max(target_width, int(target_height * aspect_ratio)), target_height

    @staticmethod
    def _fit_size(width: int, height: int, target_width: int, target_height: int) -> tuple:
        """Return the aspect-preserving ``(width, height)`` that just fits inside the target box."""
        aspect_ratio = width / height
        if target_width / target_height > aspect_ratio:
            # Target is wider than the image: match the height, pad left and right.
            return min(target_width, max(1, int(target_height * aspect_ratio))), target_height
        # Target is taller (or equal): match the width, pad top and bottom.
        return target_width, min(target_height, max(1, int(target_width / aspect_ratio)))
class MobileNetV2(ImageModelBase):
    """MobileNetV2 classifier (``mobilenetv2-7.onnx``), fed a 1x3x224x224 float32 batch."""

    onnx_file_name = 'mobilenetv2-7.onnx'

    def preprocess(self, img: numpy.ndarray) -> numpy.ndarray:
        """Centre-crop to 224x224, ImageNet-normalise as CHW, and prepend a batch axis."""
        chw = self.preprocess_imagenet(img, (224, 224))
        return chw.reshape((1,) + chw.shape)
class ResNet50(ImageModelBase):
    """ResNet-50 v2 classifier (``resnet50-v2-7.onnx``), fed a 1x3x224x224 float32 batch."""

    onnx_file_name = 'resnet50-v2-7.onnx'

    def preprocess(self, img: numpy.ndarray) -> numpy.ndarray:
        """Centre-crop to 224x224, ImageNet-normalise as CHW, and prepend a batch axis."""
        normalized = self.preprocess_imagenet(img, (224, 224))
        batch_shape = (1,) + normalized.shape
        return normalized.reshape(batch_shape)
class EfficientNetLite4(ImageModelBase):
    """EfficientNet-Lite4 classifier (``efficientnet-lite4-11.onnx``), fed channels-last (NHWC)."""

    onnx_file_name = 'efficientnet-lite4-11.onnx'

    def preprocess(self, img: numpy.ndarray) -> numpy.ndarray:
        """Centre-crop to 224x224, ImageNet-normalise keeping HWC, and prepend a batch axis."""
        # NOTE(review): the ONNX model zoo example for this model appears to
        # normalise with (x - 127) / 128 instead of ImageNet mean/std. That only
        # affects prediction quality, not the timings measured here -- confirm
        # before using the outputs.
        hwc = self.preprocess_imagenet(img, (224, 224), channel_first=False)
        return hwc.reshape((1,) + hwc.shape)
class TinyYOLOV3(ImageModelBase):
    """Tiny YOLOv3 detector (``tiny-yolov3-11.onnx``).

    The graph takes two inputs, the image batch (``input_1``) and the original
    image size (``image_shape``), so :meth:`infer` is overridden to feed both.
    """

    onnx_file_name = 'tiny-yolov3-11.onnx'

    def preprocess(self, img: numpy.ndarray) -> tuple:
        """Build both model inputs from a decoded BGR image.

        Returns:
            ``(image_batch, original_size)``:
            image_batch is a 1x3x416x416 float32 batch, letterboxed and ImageNet-normalised;
            original_size is a 1x2 float32 array ``[height, width]`` of the source image.
        """
        resized_img = self.preprocess_imagenet(img, (416, 416), resize_mode='letterbox')
        # img.shape is (height, width, channels). The ONNX model zoo example feeds
        # image_shape as [height, width]; the previous [width, height] order was swapped.
        original_size = numpy.array([img.shape[0], img.shape[1]], dtype=numpy.float32).reshape(1, 2)
        return resized_img.reshape(1, *resized_img.shape), original_size

    def infer(self, img_path: str) -> dict:
        """Read, preprocess and run one image, timing each stage separately.

        Returns:
            dict with ``'result'`` (the first fetched output) and the wall-clock
            durations in milliseconds under ``'imread_msec'``,
            ``'preprocess_msec'`` and ``'inference_msec'``.

        Raises:
            ValueError: if OpenCV cannot read ``img_path``.
        """
        t0 = time()
        img = cv2.imread(img_path)
        t1 = time()
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise ValueError('could not read image: {}'.format(img_path))
        data_input, image_size = self.preprocess(img)
        t2 = time()
        result = self.sess.run([self.output_name], {'input_1': data_input, 'image_shape': image_size})[0]
        t3 = time()
        return {
            'result': result,
            'imread_msec': (t1 - t0) * 1000,
            'preprocess_msec': (t2 - t1) * 1000,
            'inference_msec': (t3 - t2) * 1000,
        }
class SSDMobileNetV1(ImageModelBase):
    """SSD-MobileNetV1 detector (``ssd_mobilenet_v1_10.onnx``), fed the full-size image."""

    onnx_file_name = 'ssd_mobilenet_v1_10.onnx'

    def preprocess(self, img: numpy.ndarray) -> numpy.ndarray:
        """Convert BGR to RGB without resizing or rescaling, and prepend a batch axis."""
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return numpy.expand_dims(rgb, axis=0)
class SimpleConv_6(ImageModelBase):
    """LightOD ``simple_6_pascal_voc3.onnx`` detector, fed a 1x3x224x224 float32 batch."""

    onnx_file_name = 'simple_6_pascal_voc3.onnx'

    def preprocess(self, img: numpy.ndarray) -> numpy.ndarray:
        """Stretch to 224x224, convert BGR to RGB, move channels first, add a batch axis, cast to float32.

        Pixel values stay in the 0-255 range (no normalisation).
        """
        # An empty mode string selects ImageModelBase.resize's plain-resize fallback.
        stretched = self.resize(img, 224, 224, '')
        chw = cv2.cvtColor(stretched, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
        batch = chw.reshape((1,) + chw.shape)
        return batch.astype(numpy.float32)
class UltraFace320(ImageModelBase):
    """Ultra-lightweight face detector (``version-RFB-320.onnx``) at 320x240 input."""

    onnx_file_name = 'version-RFB-320.onnx'
    # Per-channel offset subtracted before dividing by 128.
    img_mean = numpy.array([127, 127, 127])

    def preprocess(self, img: numpy.ndarray) -> numpy.ndarray:
        """Return a 1x3x240x320 float32 batch with pixels mapped by (x - 127) / 128."""
        rgb = cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), (320, 240))
        centred = (rgb - self.img_mean) / 128
        batch = centred.transpose(2, 0, 1)[numpy.newaxis]
        return batch.astype(numpy.float32)
#!/bin/bash
# Prepare the benchmark: a Python virtual environment, the ONNX models and the test images.
#
# On 32-bit Raspberry Pi OS, install OpenCV and ONNX Runtime by following
#   https://github.com/nknytk/LightOD/tree/master/infer
# On x86_64 and aarch64, they are installed with pip below.
# NOTE(review): the pip step still runs on 32-bit systems; if no onnxruntime
# wheel is available there, `set -e` will presumably abort the script -- confirm.
set -e

# fetch URL [DEST]: download URL to DEST (default: the URL's basename) unless DEST
# already exists, so the script can be re-run without duplicate downloads.
# The transfer goes to DEST.part first, so a failed download never leaves a
# truncated file that a later run would mistake for a complete one.
fetch() {
    local url="$1"
    local dest="${2:-$(basename "$url")}"
    if [ -e "$dest" ]; then
        echo "skip: $dest already exists"
        return 0
    fi
    wget -O "$dest.part" "$url"
    mv "$dest.part" "$dest"
}

# prepare python virtual environment
python3 -m venv .venv
. .venv/bin/activate
pip install --upgrade pip wheel
pip install onnxruntime opencv-python

# download onnx models
# NOTE(review): the onnx/models repository may have been reorganised since these
# URLs were written (default branch / directory layout); confirm they still resolve.
mkdir -p models
cd models
fetch "https://github.com/onnx/models/raw/master/vision/classification/mobilenet/model/mobilenetv2-7.onnx"
fetch "https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet50-v2-7.onnx"
fetch "https://github.com/onnx/models/raw/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx"
fetch "https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/tiny-yolov3/model/tiny-yolov3-11.onnx"
fetch "https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.onnx"
fetch "https://github.com/onnx/models/raw/master/vision/body_analysis/ultraface/models/version-RFB-320.onnx"
fetch "https://github.com/nknytk/LightOD/raw/master/detector/trained/simple_6_pascal_voc3.onnx"
cd ..

# download test images
mkdir -p images
cd images
fetch "https://farm5.staticflickr.com/4019/4251501904_2816eabf1c_z.jpg" "0.jpg"
fetch "https://farm6.staticflickr.com/5456/8786671291_13d84ed8f0_z.jpg" "1.jpg"
fetch "https://farm9.staticflickr.com/8483/8252886800_2703b30c68_z.jpg" "2.jpg"
fetch "https://farm4.staticflickr.com/3249/3283303380_7b5e042381_z.jpg" "3.jpg"
fetch "https://farm8.staticflickr.com/7113/7590354298_eecff17e7c_z.jpg" "4.jpg"
fetch "https://farm4.staticflickr.com/3683/9048189946_e619a054da_z.jpg" "5.jpg"
fetch "https://farm4.staticflickr.com/3808/10258811356_9d3275304a_z.jpg" "6.jpg"
fetch "https://farm6.staticflickr.com/5542/9139927453_a7c0aeeef1_z.jpg" "7.jpg"
fetch "https://farm5.staticflickr.com/4135/4915451963_ba32b51c6e_z.jpg" "8.jpg"
fetch "https://farm3.staticflickr.com/2312/2223699835_6b722e5987_z.jpg" "9.jpg"
$ uname -a
Linux raspberrypi 5.10.17-v7+ #1421 SMP Thu May 27 13:59:01 BST 2021 armv7l GNU/Linux
$ python measure_speed.py
model:MobileNetV2 imread:33.56ms preprocessing:21.88ms inference:282.52ms
model:ResNet50 imread:33.13ms preprocessing:20.00ms inference:1371.63ms
model:EfficientNetLite4 imread:32.17ms preprocessing:20.14ms inference:1243.25ms
model:SSDMobileNetV1 imread:32.10ms preprocessing:1.79ms inference:731.86ms
model:TinyYOLOV3 imread:32.14ms preprocessing:55.47ms inference:885.98ms
model:SimpleConv_6 imread:32.08ms preprocessing:5.58ms inference:46.34ms
model:UltraFace320 imread:31.99ms preprocessing:25.40ms inference:146.10ms
$ uname -a
Linux raspberrypi 5.4.42-v8+ #1319 SMP PREEMPT Wed May 20 14:18:56 BST 2020 aarch64 GNU/Linux
$ python measure_speed.py
model:MobileNetV2 imread:35.25ms preprocessing:24.19ms inference:223.61ms
model:ResNet50 imread:400.81ms preprocessing:17.44ms inference:1372.13ms
model:EfficientNetLite4 imread:26.81ms preprocessing:16.47ms inference:949.22ms
model:SSDMobileNetV1 imread:27.03ms preprocessing:4.38ms inference:642.20ms
model:TinyYOLOV3 imread:26.85ms preprocessing:43.18ms inference:928.19ms
model:SimpleConv_6 imread:26.85ms preprocessing:4.72ms inference:45.00ms
model:UltraFace320 imread:26.82ms preprocessing:19.97ms inference:106.87ms
$ uname -a
Linux hp-envy 5.8.0-55-generic #62~20.04.1-Ubuntu SMP Wed Jun 2 08:55:04 UTC 2021 x86_64 x86_64 x86_64 GNU/Linux
$ python measure_speed.py
model:MobileNetV2 imread:3.98ms preprocessing:0.77ms inference:3.97ms
model:ResNet50 imread:3.38ms preprocessing:0.78ms inference:24.24ms
model:EfficientNetLite4 imread:3.38ms preprocessing:0.77ms inference:11.37ms
model:SSDMobileNetV1 imread:3.39ms preprocessing:1.05ms inference:19.40ms
model:TinyYOLOV3 imread:3.41ms preprocessing:3.21ms inference:17.78ms
model:SimpleConv_6 imread:3.41ms preprocessing:0.22ms inference:1.89ms
model:UltraFace320 imread:3.41ms preprocessing:2.01ms inference:7.26ms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment