# YOLOv3 inference for Linux and Windows
"""
yolo.py for YOLOv3

Supported platforms:
- Windows
- Linux
"""
from ctypes import *
import os
import random
import time

import cv2
import numpy as np
def sample(probs):
    """Return a random index drawn with probability proportional to ``probs``.

    Args:
        probs: sequence of non-negative weights; they need not sum to 1.

    Returns:
        An index in ``range(len(probs))``.

    Raises:
        ValueError: if ``probs`` is empty or its weights do not sum to a
            positive number (no index can be drawn in either case).
    """
    # float() avoids integer division under Python 2 when all weights are ints.
    total = float(sum(probs))
    # len() rather than truthiness so numpy arrays are accepted too.
    if len(probs) == 0 or total <= 0:
        raise ValueError("probs must be non-empty with a positive sum")
    r = random.uniform(0, 1)
    for i, p in enumerate(probs):
        r -= p / total
        if r <= 0:
            return i
    # Floating-point rounding can leave r slightly above 0 after the loop.
    return len(probs) - 1
def c_array(ctype, values):
    """Return a new ctypes array of ``ctype`` holding a copy of ``values``.

    The array length equals ``len(values)``.
    """
    array_cls = ctype * len(values)
    buf = array_cls()
    # element-wise copy; every entry must be convertible to ``ctype``
    buf[:] = values
    return buf
class BOX(Structure):
    """ctypes mirror of the C box struct used by the loaded darknet library.

    Field order and types must match the C layout. detect_core copies the
    fields into its results as ``(x, y, w, h)``; the demo code in this file
    treats ``x``/``y`` as the box centre.
    """
    _fields_ = [("x", c_float),
                ("y", c_float),
                ("w", c_float),
                ("h", c_float)]
class DETECTION(Structure):
    """ctypes mirror of one detection in the array returned by get_network_boxes.

    Field order and types must match the C struct in the loaded library
    exactly: the struct size sets the stride used when indexing ``dets[j]``.
    """
    _fields_ = [("bbox", BOX),                  # box geometry, read by detect_core
                ("classes", c_int),             # presumably the length of `prob` — TODO confirm
                ("prob", POINTER(c_float)),     # per-class scores; detect_core reads prob[i] for i < meta.classes
                ("mask", POINTER(c_float)),     # not used in this file
                ("objectness", c_float),        # not used in this file
                ("sort_class", c_int)]          # not used in this file
class IMAGE(Structure):
    """ctypes mirror of the C image struct: width, height, channels, pixel buffer.

    opencv_image_to_darknet_image fills ``data`` with planar (chw) RGB floats
    scaled to [0, 1]. ``data`` may be owned by C (release with free_image) or
    may point into a numpy array, depending on how the IMAGE was built.
    """
    _fields_ = [("w", c_int),
                ("h", c_int),
                ("c", c_int),
                ("data", POINTER(c_float))]
class METADATA(Structure):
    """ctypes mirror of the metadata returned by get_metadata (load_meta).

    ``classes`` is the class count (detect_core iterates range(meta.classes));
    ``names`` is an array of C strings indexed by class id.
    """
    _fields_ = [("classes", c_int),
                ("names", POINTER(c_char_p))]
DEBUG = False  # set True to print image/network sizes in detect_core

# ==================================================
# Platform-specific configuration (os.name: "posix" on Linux, "nt" on Windows).
if os.name == "nt":
    # Per-class names and score thresholds indexed by class id; detect_core
    # uses these instead of meta.names on Windows.
    CLASS_NAMES = ["knife", "KK", "water"]
    CLASS_THRESHOLD = [0.0, 0.0, 0.05]
# Other Windows DLLs used previously: "./cuda80_yolo_cpp_dll.dll",
# "./sdk/yolo/cuda90_yolo_cpp_dll.dll". Paths are relative to the cwd.
lib = CDLL(
    "./sdk/yolo/cuda80_yolo_cpp_dll.dll" if os.name == "nt" else "./darknet.so",
    RTLD_GLOBAL,
)
# ==================================================
# ==================================================
# ctypes prototypes for the darknet C API.
# argtypes/restype must match the C signatures in the loaded library; where
# no restype is set, ctypes treats the return value as a C int.
# Network handles are passed around as opaque c_void_p values.
lib.network_width.argtypes = [c_void_p]
lib.network_width.restype = c_int
lib.network_height.argtypes = [c_void_p]
lib.network_height.restype = c_int
# select the CUDA device (xray() calls set_gpu(0))
set_gpu = lib.cuda_set_device
set_gpu.argtypes = [c_int]
# not used in this file; presumably (w, h, c) — TODO confirm
make_image = lib.make_image
make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE
make_network_boxes = lib.make_network_boxes
make_network_boxes.argtypes = [c_void_p]
make_network_boxes.restype = POINTER(DETECTION)
# (dets, nboxes): releases the array returned by get_network_boxes
free_detections = lib.free_detections
free_detections.argtypes = [POINTER(DETECTION), c_int]
free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]
reset_rnn = lib.reset_rnn
reset_rnn.argtypes = [c_void_p]
# (cfg path, weights path, int) -> network handle; paths must be bytes (c_char_p)
load_net = lib.load_network
load_net.argtypes = [c_char_p, c_char_p, c_int]
load_net.restype = c_void_p
# (.data path) -> METADATA; the next two lines configure the same function
# object, because CDLL caches the attribute after the first lookup
load_meta = lib.get_metadata
lib.get_metadata.argtypes = [c_char_p]
lib.get_metadata.restype = METADATA
# (dets, nboxes, classes, nms threshold): in-place NMS, used by detect_core
do_nms_obj = lib.do_nms_obj
do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]
do_nms_sort = lib.do_nms_sort
do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]
# releases an IMAGE whose data was allocated on the C side
free_image = lib.free_image
free_image.argtypes = [IMAGE]
letterbox_image = lib.letterbox_image
letterbox_image.argtypes = [IMAGE, c_int, c_int]
letterbox_image.restype = IMAGE
# (path, 0, 0) in this file: BGR,hwc,[0,255] file ===> RGB,chw,[0,1] IMAGE
load_image_color = lib.load_image_color
load_image_color.argtypes = [c_char_p, c_int, c_int]
load_image_color.restype = IMAGE
# (im, w, h) -> new IMAGE of size w x h (detect_core frees it with free_image)
resize_image = lib.resize_image
resize_image.argtypes = [IMAGE, c_int, c_int]
resize_image.restype = IMAGE
# not used in this file
rgbgr_image = lib.rgbgr_image
rgbgr_image.argtypes = [IMAGE]
# Prediction entry points.
# By default we use `network_predict_image` on linux.
# On windows we use `network_predict` on an image already resized to the
# network input size, because `network_predict_image` gives erroneous
# results there.
# for linux: (net, IMAGE) -> output buffer (ignored by detect_core)
network_predict_image = lib.network_predict_image
network_predict_image.argtypes = [c_void_p, IMAGE]
network_predict_image.restype = POINTER(c_float)
# for windows: (net, pixel data of a network-sized IMAGE) -> output buffer
network_predict = lib.network_predict
network_predict.argtypes = [c_void_p, POINTER(c_float)]
network_predict.restype = POINTER(c_float)
# (net, im.w, im.h, thresh, hier_thresh, map, relative, p_nboxes, letterbox)
# -> DETECTION array; its length is written to *p_nboxes and the array must
# be released with free_detections.
get_network_boxes = lib.get_network_boxes
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int), c_int]
get_network_boxes.restype = POINTER(DETECTION)
def detect_core(net, meta, im, is_free_im, thresh, hier_thresh, nms):
    """Run ``net`` on a darknet ``IMAGE`` and return detections above threshold.

    Args:
        net: network handle returned by ``load_net``.
        meta: ``METADATA`` returned by ``load_meta``; ``meta.classes`` sets
            how many class scores are read per box.
        im: the input image, built in one of two ways:
            case 1: ``im = load_image_color(imagefile, 0, 0)``
            case 2: ``im, arr = opencv_image_to_darknet_image(bgr)``
        is_free_im: if true, ``free_image(im)`` is called before returning,
            including when an error is raised (use for case 1).
        thresh, hier_thresh: thresholds forwarded to ``get_network_boxes``.
        nms: threshold forwarded to ``do_nms_obj``; a falsy value skips NMS.

    Returns:
        A list of ``(class_name, prob, (x, y, w, h))`` tuples, one per
        (box, class) pair whose score exceeds the class threshold, sorted by
        descending ``prob``. On Linux names come from ``meta.names`` with a
        threshold of 0; on Windows ``CLASS_NAMES``/``CLASS_THRESHOLD`` are used.
    """
    net_w, net_h = lib.network_width(net), lib.network_height(net)
    if DEBUG:
        print(" [kzl]: net width,height,channel =( {0}, {1}, {2})".format(net_w, net_h, 3))
        print(" [kzl]: im width,height,channel =( {0}, {1}, {2})".format(im.w, im.h, im.c))
    # Example DEBUG output for a 1280x859 input and a 416x416 network
    # (the "sized" line is printed on Windows only):
    #  [kzl]: net width,height,channel =( 416, 416, 3)
    #  [kzl]: im width,height,channel =( 1280, 859, 3)
    #  [kzl]: sized width,height =( 416, 416, 3)

    ptr_map = None
    relative = 0  # 1 for relative pos
    letterbox = 0  # 1 for letterbox
    nboxes_c = c_int(0)
    p_nboxes = pointer(nboxes_c)
    sized = None  # network-sized copy of `im`; only allocated on Windows
    dets = None
    nboxes = 0
    try:
        if os.name == "posix":
            # OK for linux, ERROR for windows. The original-size image is
            # passed directly, so no resized copy is needed on this path.
            network_predict_image(net, im)
        else:
            # on windows, resize to the network input size ourselves and
            # use network_predict on the resized pixels
            sized = resize_image(im, net_w, net_h)
            if DEBUG:
                print(" [kzl]: sized width,height =( {0}, {1}, {2})".format(sized.w, sized.h, sized.c))
            network_predict(net, sized.data)
        dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, ptr_map, relative, p_nboxes, letterbox)
        nboxes = p_nboxes[0]
        if DEBUG:
            print(" [kzl]: nboxes =( {0})".format(nboxes))
        if nms:
            do_nms_obj(dets, nboxes, meta.classes, nms)

        # Pick the per-platform name/threshold tables once, outside the loops.
        if os.name == "posix":
            class_names = meta.names
            class_thresholds = [0] * meta.classes
        else:
            # use CLASS_NAMES instead of meta.names on windows
            class_names = CLASS_NAMES
            class_thresholds = CLASS_THRESHOLD

        res = []
        for j in range(nboxes):
            det = dets[j]
            for i in range(meta.classes):
                prob = det.prob[i]
                if prob > class_thresholds[i]:
                    b = det.bbox
                    res.append((class_names[i], prob, (b.x, b.y, b.w, b.h)))
        # stable sort: equal scores keep their discovery order
        res.sort(key=lambda r: r[1], reverse=True)
    finally:
        # Release C-side allocations even if prediction or collection raised.
        if dets is not None:
            free_detections(dets, nboxes)
        if sized is not None:
            free_image(sized)
        if is_free_im:
            free_image(im)
    return res
def detect_imagefile(net, meta, imagefile, thresh=.5, hier_thresh=.5, nms=.45):
    """Detect objects in an image file loaded by darknet itself.

    Args:
        net, meta: handles from ``load_net`` / ``load_meta``.
        imagefile: path to the image. ``bytes`` is passed through unchanged;
            a text path is encoded as UTF-8 first, because load_image_color
            takes a ``c_char_p`` argument, which does not accept ``str`` on
            Python 3.
        thresh, hier_thresh, nms: forwarded to detect_core.

    Returns:
        The detection list produced by detect_core.
    """
    if not isinstance(imagefile, bytes):
        imagefile = imagefile.encode("utf-8")
    im = load_image_color(imagefile, 0, 0)  # BGR,hwc,[0,255] === > RGB,chw,[0,1]
    # image memory allocated in C, so detect_core must free(im)
    return detect_core(net, meta, im, True, thresh, hier_thresh, nms)
def opencv_image_to_darknet_image(bgr):
    """Wrap an OpenCV BGR ``hwc`` uint8-style image as a darknet ``IMAGE``.

    The pixels are converted BGR,hwc,[0,255] ===> RGB,chw,[0,1] float32.

    Returns:
        ``(im, arr)``: ``im.data`` points into the numpy array ``arr``, so the
        caller must keep ``arr`` alive for as long as ``im`` is used, and must
        not call free_image on ``im``.
    """
    # planar RGB: move the channel axis to the front (hwc ===> chw)
    chw = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    channels, height, width = chw.shape
    # flatten in chw order into a contiguous float32 buffer scaled to [0, 1]
    arr = np.ascontiguousarray(chw.flat, dtype=np.float32) / 255.0
    im = IMAGE(width, height, channels, arr.ctypes.data_as(POINTER(c_float)))
    return im, arr
# 20181120: default thresh lowered from 0.5 to 0.01
def detect_imagebuffer(net, meta, bgr, thresh=.01, hier_thresh=.5, nms=.45):
    """Detect objects in an in-memory OpenCV BGR image.

    ``thresh``, ``hier_thresh`` and ``nms`` are forwarded to detect_core,
    whose detection list is returned unchanged.
    """
    # The pixels live in the numpy array `buf`; `darknet_im` only points into
    # it, so detect_core must not free_image() it. Holding `buf` in this frame
    # keeps that memory alive for the whole detect_core call.
    darknet_im, buf = opencv_image_to_darknet_image(bgr)  # BGR,hwc,[0,255] === > RGB,chw,[0,1]
    return detect_core(net, meta, darknet_im, False, thresh, hier_thresh, nms)
def yolov3():
    """Run the stock YOLOv3 COCO model on ``data/dog.jpg`` and print the result.

    All paths are relative to the current working directory.
    """
    # Paths are bytes because load_net, load_meta and load_image_color take
    # c_char_p arguments, which do not accept str on Python 3.
    net = load_net(b"cfg/yolov3.cfg", b"yolov3.weights", 0)
    meta = load_meta(b"cfg/coco.data")
    r = detect_imagefile(net, meta, b"data/dog.jpg")
    print(r)
def _draw_detections(bgr, rs):
    """Print each detection and draw its box and class label onto ``bgr`` in place.

    ``rs`` is the list returned by detect_imagebuffer / detect_imagefile:
    ``(class_name, score, (cx, cy, w, h))`` tuples, with the box given by its
    centre, width and height.
    """
    bgr_color = (0, 0, 255)
    for r in rs:
        print(r)  # ('knife', 0.9975, (522.5, 482.7, 155.0, 342.8))
        class_name, score, (cx, cy, w, h) = r
        # centre form ===> corner form x1, y1, x2, y2
        x = cx - w / 2
        y = cy - h / 2
        box = [int(x), int(y), int(x + w), int(y + h)]
        print(box)
        print(score)
        print(class_name)
        # meta.names entries (c_char_p) are bytes on Python 3; cv2.putText needs text.
        label = class_name if isinstance(class_name, str) else class_name.decode("utf-8")
        cv2.rectangle(bgr, (box[0], box[1]), (box[2], box[3]), bgr_color, 2)
        cv2.putText(bgr, label, (int(x), int(y - 5)), cv2.FONT_HERSHEY_SIMPLEX, 1, bgr_color, 2)


def xray(image_dir="/home/kezunlin/git/sdklite/data/reid_data/frames/2/",
         output_dir="./result_yolo/2/",
         use_imagebuffer=True):
    """Detect objects in every image of ``image_dir`` and save annotated copies.

    Loads yolov3.cfg / yolov3.weights / coco.data from the current directory
    and runs on GPU 0.

    Args:
        image_dir: input directory. Files are processed in sorted name order;
            files OpenCV cannot read are skipped.
        output_dir: where ``<index:06d>_image_with_boxs.jpg`` files are
            written; created if missing.
        use_imagebuffer: True decodes with OpenCV and calls
            detect_imagebuffer; False lets darknet load each file via
            detect_imagefile.

    Prints per-image detection time and, at the end, the total and average
    time excluding the first processed image (presumably a warm-up run).
    """
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    image_filename_list = sorted(os.listdir(image_dir))

    set_gpu(0)
    # bytes paths: load_net/load_meta take c_char_p arguments
    net = load_net(b"yolov3.cfg", b"yolov3.weights", 0)
    meta = load_meta(b"coco.data")

    total_cost = 0
    processed = 0  # images successfully run through the detector
    timed = 0      # images included in the timing statistics
    for index, image_filename in enumerate(image_filename_list):
        image_filepath = os.path.join(image_dir, image_filename)
        print("image #{} from {}".format(index, image_filepath))
        # Needed in both modes: the annotated output is drawn on this array.
        bgr = cv2.imread(image_filepath)
        if bgr is None:
            print("skip: cannot read image {}".format(image_filepath))
            continue

        if use_imagebuffer:
            print("================detect_imagebuffer======================")
        else:
            print("================detect_imagefile======================")
        since = time.time()
        if use_imagebuffer:
            rs = detect_imagebuffer(net, meta, bgr, thresh=0.5)
        else:
            rs = detect_imagefile(net, meta, image_filepath.encode('ascii'), thresh=0.5)
        time_elapsed = time.time() - since
        print("%f real seconds" % time_elapsed)
        print("len = ", len(rs))

        _draw_detections(bgr, rs)
        pad_index = "{0:06d}".format(index)
        filepath = os.path.join(output_dir, pad_index + "_image_with_boxs.jpg")
        cv2.imwrite(filepath, bgr)
        print("saved ", filepath)

        # skip the first processed image when accumulating timing
        if processed > 0:
            total_cost += time_elapsed
            timed += 1
        processed += 1

    if timed > 0:
        avg_cost = total_cost / float(timed)
        msg = "image size ={}, total cost={} s, avg cost ={} ms".format(
            timed, int(total_cost), int(avg_cost * 1000))
        print(msg)


# ========================================================
# Benchmark, linux, 2018-12-03:
#   image size =782, total cost=64 s, avg cost =82 ms
if __name__ == "__main__":
    # Batch demo over an image directory; call yolov3() instead for the
    # single-image COCO demo.
    xray()