Last active
December 12, 2018 01:26
-
-
Save kezunlin/01adb3c752072f36954ad1bb4f935c14 to your computer and use it in GitHub Desktop.
YOLOv3 inference for Linux and Windows
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
yolo.py: YOLOv3 inference through the darknet shared library (ctypes).
Supported platforms:
- Windows
- Linux
""" | |
from ctypes import * | |
import random | |
import os | |
import time | |
import cv2 | |
import numpy as np | |
def sample(probs):
    """Draw a random index, weighting each position by its entry in ``probs``.

    The weights are divided by their sum, a uniform number in [0, 1] is drawn,
    and the normalised weights are subtracted from it in order until the
    remainder is no longer positive.

    Args:
        probs: Sequence of numeric weights; they need not sum to 1.

    Returns:
        The index at which the remainder first reached zero or below, or
        ``len(probs) - 1`` when that never happens (which is ``-1`` for an
        empty sequence).
    """
    total = sum(probs)
    weights = [p / total for p in probs]
    remaining = random.uniform(0, 1)
    for index, weight in enumerate(weights):
        remaining = remaining - weight
        if remaining <= 0:
            return index
    # Only reached when rounding left a positive remainder (or no weights).
    return len(weights) - 1
def c_array(ctype, values):
    """Copy ``values`` into a newly allocated ctypes array of ``ctype``.

    Args:
        ctype: ctypes element type, e.g. ``c_float``.
        values: Sized sequence whose items can be stored in ``ctype``.

    Returns:
        An instance of ``ctype * len(values)`` holding the copied items.
    """
    array_type = ctype * len(values)
    filled = array_type()
    # Slice assignment copies element by element and enforces equal length.
    filled[:] = values
    return filled
class BOX(Structure):
    """ctypes layout of a bounding box: four consecutive C floats.

    Code in this file reads ``x``/``y`` as the box centre and ``w``/``h`` as
    its size. NOTE(review): presumably mirrors the native library's box
    struct -- confirm against its headers before changing order or types.
    """

    _fields_ = [
        ("x", c_float),
        ("y", c_float),
        ("w", c_float),
        ("h", c_float),
    ]
class DETECTION(Structure):
    """ctypes layout of one detection record returned by the native library.

    ``detect_core`` indexes ``prob`` once per class and reads ``bbox``; the
    remaining fields are declared so the record has the size and offsets the
    library uses. NOTE(review): presumably mirrors the native ``detection``
    struct -- confirm against the library headers.
    """

    _fields_ = [
        ("bbox", BOX),                  # box geometry
        ("classes", c_int),
        ("prob", POINTER(c_float)),     # indexed by class id in detect_core
        ("mask", POINTER(c_float)),
        ("objectness", c_float),
        ("sort_class", c_int),
    ]
class IMAGE(Structure):
    """ctypes layout of an image: width, height, channel count, float pixels.

    ``data`` points at the float pixel buffer; this file fills it either from
    the native loader or from a NumPy array laid out channel-first.
    """

    _fields_ = [
        ("w", c_int),
        ("h", c_int),
        ("c", c_int),
        ("data", POINTER(c_float)),
    ]
class METADATA(Structure):
    """ctypes layout of dataset metadata: class count plus C-string labels.

    ``names`` is indexed by class id; ``classes`` bounds that index.
    """

    _fields_ = [
        ("classes", c_int),
        ("names", POINTER(c_char_p)),
    ]
# When True, detect_core() prints network/image sizes and the box count.
DEBUG = False

# ---------------------------------------------------------------------------
# Native library selection. os.name is "posix" on Linux and "nt" on Windows.
# Paths are relative to the current working directory.
# ---------------------------------------------------------------------------
if os.name != "nt":
    lib = CDLL("./darknet.so", RTLD_GLOBAL)
else:
    # Defined on Windows only: detect_core() uses these in place of
    # meta.names and a zero per-class threshold.
    CLASS_NAMES = ["knife", "KK", "water"]
    CLASS_THRESHOLD = [0.0, 0.0, 0.05]
    # Other builds that have been used here:
    #   ./cuda80_yolo_cpp_dll.dll
    #   ./sdk/yolo/cuda90_yolo_cpp_dll.dll
    lib = CDLL("./sdk/yolo/cuda80_yolo_cpp_dll.dll", RTLD_GLOBAL)
def _bind(symbol, argtypes, **optional):
    """Look up ``symbol`` on ``lib`` and attach its ctypes call signature.

    ``restype`` is assigned only when passed as a keyword, so functions bound
    without it keep ctypes' default return handling.
    """
    func = getattr(lib, symbol)
    func.argtypes = argtypes
    if "restype" in optional:
        func.restype = optional["restype"]
    return func


# Network geometry. No module-level alias: callers use lib.network_width /
# lib.network_height directly.
_bind("network_width", [c_void_p], restype=c_int)
_bind("network_height", [c_void_p], restype=c_int)

set_gpu = _bind("cuda_set_device", [c_int])

make_image = _bind("make_image", [c_int, c_int, c_int], restype=IMAGE)

make_network_boxes = _bind(
    "make_network_boxes", [c_void_p], restype=POINTER(DETECTION))

free_detections = _bind("free_detections", [POINTER(DETECTION), c_int])

free_ptrs = _bind("free_ptrs", [POINTER(c_void_p), c_int])

reset_rnn = _bind("reset_rnn", [c_void_p])

# Both string parameters are c_char_p, i.e. byte strings.
load_net = _bind(
    "load_network", [c_char_p, c_char_p, c_int], restype=c_void_p)

load_meta = _bind("get_metadata", [c_char_p], restype=METADATA)

do_nms_obj = _bind(
    "do_nms_obj", [POINTER(DETECTION), c_int, c_int, c_float])

do_nms_sort = _bind(
    "do_nms_sort", [POINTER(DETECTION), c_int, c_int, c_float])

free_image = _bind("free_image", [IMAGE])

letterbox_image = _bind(
    "letterbox_image", [IMAGE, c_int, c_int], restype=IMAGE)

load_image_color = _bind(
    "load_image_color", [c_char_p, c_int, c_int], restype=IMAGE)

resize_image = _bind("resize_image", [IMAGE, c_int, c_int], restype=IMAGE)

rgbgr_image = _bind("rgbgr_image", [IMAGE])

# Two prediction entry points are bound. Linux uses network_predict_image.
# Windows uses network_predict on the resized pixel buffer, because
# network_predict_image was observed to give wrong results there.
network_predict_image = _bind(
    "network_predict_image", [c_void_p, IMAGE], restype=POINTER(c_float))

network_predict = _bind(
    "network_predict", [c_void_p, POINTER(c_float)],
    restype=POINTER(c_float))

get_network_boxes = _bind(
    "get_network_boxes",
    [c_void_p, c_int, c_int, c_float, c_float,
     POINTER(c_int), c_int, POINTER(c_int), c_int],
    restype=POINTER(DETECTION))
def detect_core(net, meta, im, is_free_im, thresh, hier_thresh, nms):
    """Run the network on one image and collect per-class detections.

    Args:
        net: Network handle returned by ``load_net``.
        meta: ``METADATA`` returned by ``load_meta``. ``meta.classes`` bounds
            the per-box class loop and, on posix, ``meta.names`` supplies the
            labels.
        im: ``IMAGE`` to analyse, built either with
            ``load_image_color(imagefile, 0, 0)`` or with
            ``opencv_image_to_darknet_image(bgr)``.
        is_free_im: True when ``im`` must be released here with
            ``free_image``; False when its pixels live in a buffer owned by
            the caller.
        thresh: Threshold forwarded to ``get_network_boxes``.
        hier_thresh: Hierarchical threshold forwarded to
            ``get_network_boxes``.
        nms: Threshold forwarded to ``do_nms_obj``; a falsy value skips
            non-maximum suppression.

    Returns:
        A list of ``(class_name, prob, (x, y, w, h))`` tuples sorted by
        descending ``prob``. On posix every class with ``prob > 0`` is kept
        and labelled from ``meta.names``; elsewhere ``CLASS_NAMES`` and
        ``CLASS_THRESHOLD`` supply the label and per-class cut-off.
    """
    on_posix = os.name == "posix"
    sized = None
    dets = None
    nboxes = 0
    try:
        net_w, net_h = lib.network_width(net), lib.network_height(net)
        sized = resize_image(im, net_w, net_h)
        if DEBUG:
            print(" [kzl]: net width,height,channel =( {0}, {1}, {2})".format(net_w, net_h, 3))
            print(" [kzl]: im width,height,channel =( {0}, {1}, {2})".format(im.w, im.h, im.c))
            print(" [kzl]: sized width,height =( {0}, {1}, {2})".format(sized.w, sized.h, sized.c))
            # Sample output:
            #   [kzl]: net width,height,channel =( 416, 416, 3)
            #   [kzl]: im width,height,channel =( 1280, 859, 3)
            #   [kzl]: sized width,height =( 416, 416, 3)

        ptr_map = None  # passed as a NULL pointer
        relative = 0    # 1 for relative positions
        letterbox = 0   # 1 for letterbox
        box_count = c_int(0)
        p_box_count = pointer(box_count)

        if on_posix:
            network_predict_image(net, im)    # OK on Linux, wrong results on Windows
        else:
            network_predict(net, sized.data)  # Windows: feed the resized buffer

        dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh,
                                 ptr_map, relative, p_box_count, letterbox)
        nboxes = box_count.value
        if DEBUG:
            print(" [kzl]: nboxes =( {0})".format(nboxes))

        if nms:
            do_nms_obj(dets, nboxes, meta.classes, nms)

        res = []
        for j in range(nboxes):
            det = dets[j]
            for i in range(meta.classes):
                if on_posix:
                    class_name = meta.names[i]
                    class_threshold = 0
                else:
                    # Windows uses the module-level tables, not meta.names.
                    class_name = CLASS_NAMES[i]
                    class_threshold = CLASS_THRESHOLD[i]
                prob = det.prob[i]
                if prob > class_threshold:
                    b = det.bbox
                    res.append((class_name, prob, (b.x, b.y, b.w, b.h)))
        # Negated key keeps equal-probability entries in discovery order.
        res.sort(key=lambda item: -item[1])
        return res
    finally:
        # Release native memory even when something above raised, so a
        # failure partway through does not leak these buffers.
        if dets is not None:
            free_detections(dets, nboxes)
        if sized is not None:
            free_image(sized)
        if is_free_im:
            free_image(im)
def detect_imagefile(net, meta, imagefile, thresh=.5, hier_thresh=.5, nms=.45):
    """Load an image from disk with the native loader and run detection.

    Args:
        net: Network handle returned by ``load_net``.
        meta: ``METADATA`` returned by ``load_meta``.
        imagefile: Path to the image, as ``bytes`` or as text. Text is
            encoded before the native call because ``load_image_color`` is
            declared with a ``c_char_p`` parameter, which takes byte strings.
        thresh: Threshold forwarded to ``detect_core``.
        hier_thresh: Hierarchical threshold forwarded to ``detect_core``.
        nms: Non-maximum-suppression threshold forwarded to ``detect_core``.

    Returns:
        The list produced by ``detect_core``.
    """
    if not isinstance(imagefile, bytes):
        # NOTE(review): UTF-8 is an assumption for non-ASCII paths; confirm
        # what the native loader expects, especially on Windows.
        imagefile = imagefile.encode("utf-8")
    im = load_image_color(imagefile, 0, 0)
    # The image comes from the native loader, so detect_core is told to
    # free it (is_free_im=True).
    return detect_core(net, meta, im, True, thresh, hier_thresh, nms)
def opencv_image_to_darknet_image(bgr):
    """Build an ``IMAGE`` whose pixels live in a NumPy buffer.

    The input is converted BGR -> RGB, reordered from height/width/channel to
    channel/height/width, flattened to float32 and divided by 255.

    Args:
        bgr: OpenCV image array. NOTE(review): the original comment gives
            BGR, hwc, values in [0, 255] -- confirm for your callers.

    Returns:
        ``(im, arr)``: ``im.data`` points into ``arr``, so the caller must
        keep ``arr`` referenced for as long as ``im`` is in use.
    """
    planar_rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    channels, height, width = planar_rgb.shape
    # .flat walks the transposed view in channel-first order; the division
    # yields a new contiguous float32 array.
    arr = np.ascontiguousarray(planar_rgb.flat, dtype=np.float32) / 255.0
    pixels = arr.ctypes.data_as(POINTER(c_float))
    return IMAGE(width, height, channels, pixels), arr
# 2018-11-20: default `thresh` lowered from 0.5 to 0.01
def detect_imagebuffer(net, meta, bgr, thresh=.01, hier_thresh=.5, nms=.45):
    """Run detection on an in-memory OpenCV image.

    Args:
        net: Network handle returned by ``load_net``.
        meta: ``METADATA`` returned by ``load_meta``.
        bgr: OpenCV image array, converted with
            ``opencv_image_to_darknet_image``.
        thresh: Threshold forwarded to ``detect_core``.
        hier_thresh: Hierarchical threshold forwarded to ``detect_core``.
        nms: Non-maximum-suppression threshold forwarded to ``detect_core``.

    Returns:
        The list produced by ``detect_core``.
    """
    # `pixel_buffer` is unused by name but must stay referenced until
    # detect_core returns: the image's data pointer points into it.
    image, pixel_buffer = opencv_image_to_darknet_image(bgr)
    # The pixels belong to Python, so detect_core must not free the image
    # (is_free_im=False).
    detections = detect_core(net, meta, image, False, thresh, hier_thresh, nms)
    return detections
def yolov3():
    """Demo: load the YOLOv3 model, detect on ``data/dog.jpg``, print results.

    All paths are relative to the current working directory. They are bytes
    literals because the native functions behind ``load_net``, ``load_meta``
    and ``detect_imagefile`` are declared with ``c_char_p`` parameters, which
    reject text strings on Python 3.
    """
    net = load_net(b"cfg/yolov3.cfg", b"yolov3.weights", 0)
    meta = load_meta(b"cfg/coco.data")
    r = detect_imagefile(net, meta, b"data/dog.jpg")
    print(r)
def xray(image_dir="/home/kezunlin/git/sdklite/data/reid_data/frames/2/",
         output_dir="./result_yolo/2/",
         use_imagebuffer=True):
    """Detect on every image in a directory, save annotated copies, time it.

    Each readable image is run through the network, its detections are drawn
    as red rectangles with the class label, and the result is written to
    ``output_dir`` as ``<6-digit index>_image_with_boxs.jpg``. The first
    processed image is left out of the timing average.

    Args:
        image_dir: Directory whose entries are treated as image files.
            Entries are processed in sorted order so output numbering is
            reproducible.
        output_dir: Directory for annotated images; created if missing.
        use_imagebuffer: True to detect on the decoded OpenCV image
            (``detect_imagebuffer``); False to let the native library load
            the file (``detect_imagefile``).
    """
    # A bare makedirs() raises when the directory is left over from a
    # previous run.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    image_filename_list = sorted(os.listdir(image_dir))

    set_gpu(0)
    # Bytes literals: the native functions take c_char_p, which rejects text
    # strings on Python 3.
    net = load_net(b"yolov3.cfg", b"yolov3.weights", 0)
    meta = load_meta(b"coco.data")

    total_cost = 0
    processed = 0
    for index, image_filename in enumerate(image_filename_list):
        image_filepath = os.path.join(image_dir, image_filename)
        print("image #{} from {}".format(index, image_filepath))

        # Decoded in both modes because the boxes are drawn on this array.
        bgr = cv2.imread(image_filepath)
        if bgr is None:
            print("skipped unreadable image {}".format(image_filepath))
            continue

        if use_imagebuffer:
            print("================detect_imagebuffer======================")
            since = time.time()
            rs = detect_imagebuffer(net, meta, bgr, thresh=0.5)
        else:
            print("================detect_imagefile======================")
            since = time.time()
            rs = detect_imagefile(net, meta, image_filepath.encode('ascii'), thresh=0.5)
        time_elapsed = time.time() - since
        print("%f real seconds" % time_elapsed)
        print("len = ", len(rs))

        for r in rs:
            print(r)  # e.g. ('knife', 0.9975, (522.5, 482.7, 155.0, 342.8))
            class_name, score, (cx, cy, w, h) = r
            if not isinstance(class_name, str):
                # Labels read from meta.names are bytes on Python 3.
                class_name = class_name.decode("utf-8", "replace")
            # Centre/size -> corner coordinates (x1, y1, x2, y2).
            x = cx - w / 2
            y = cy - h / 2
            box = [int(x), int(y), int(x + w), int(y + h)]
            print(box)
            print(score)
            print(class_name)
            bgr_color = (0, 0, 255)
            cv2.rectangle(bgr, (box[0], box[1]), (box[2], box[3]), bgr_color, 2)
            cv2.putText(bgr, class_name, (int(x), int(y-5)), cv2.FONT_HERSHEY_SIMPLEX, 1, bgr_color, 2)

        pad_index = "{0:06d}".format(index)
        filepath = os.path.join(output_dir, pad_index + "_image_with_boxs.jpg")
        cv2.imwrite(filepath, bgr)
        print("saved ", filepath)

        # The first processed image is excluded from the average.
        if processed > 0:
            total_cost += time_elapsed
        processed += 1

    timed_images = processed - 1
    if timed_images > 0:
        avg_cost = total_cost / (timed_images * 1.0)
        msg = "image size ={}, total cost={} s, avg cost ={} ms".format(timed_images, int(total_cost),
                                                                        int(avg_cost * 1000))
        print(msg)


# Recorded run:
#   linux, 2018-12-03:
#   image size =782, total cost=64 s, avg cost =82 ms
if __name__ == "__main__":
    # Script entry point: run the directory benchmark.
    # Call yolov3() here instead for the single-image demo.
    xray()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment