Skip to content

Instantly share code, notes, and snippets.

@kezunlin
Last active December 12, 2018 01:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kezunlin/01adb3c752072f36954ad1bb4f935c14 to your computer and use it in GitHub Desktop.
Save kezunlin/01adb3c752072f36954ad1bb4f935c14 to your computer and use it in GitHub Desktop.
YOLOv3 inference for Linux and Windows
"""
yolo.py for yolov3
support platforms:
- windows
- linux
"""
from ctypes import *
import random
import os
import time
import cv2
import numpy as np
def sample(probs):
    """Pick a random index, weighting each position by its entry in ``probs``.

    The weights are normalised by their sum before drawing, so they need not
    add up to 1. If rounding leaves the draw unconsumed, the last index is
    returned.
    """
    total = sum(probs)
    weights = [weight / total for weight in probs]
    remaining = random.uniform(0, 1)
    for idx, weight in enumerate(weights):
        remaining -= weight
        if remaining <= 0:
            return idx
    return len(weights) - 1
def c_array(ctype, values):
    """Return a new ctypes array of ``ctype`` that holds a copy of ``values``."""
    array_type = ctype * len(values)
    buf = array_type()
    buf[:] = values
    return buf
class BOX(Structure):
    """ctypes mirror of darknet's box: four floats named x, y, w, h.

    detect_core copies them out as ``(x, y, w, h)``; xray() treats x/y as the
    box centre.
    """
    _fields_ = [(name, c_float) for name in ("x", "y", "w", "h")]
class DETECTION(Structure):
    """ctypes mirror of one darknet detection record.

    Field order and types must match the C struct in the loaded library;
    ``prob`` is indexed by class id.
    """
    _fields_ = [
        ("bbox", BOX),
        ("classes", c_int),
        ("prob", POINTER(c_float)),
        ("mask", POINTER(c_float)),
        ("objectness", c_float),
        ("sort_class", c_int),
    ]
class IMAGE(Structure):
    """ctypes mirror of darknet's image: three int dimensions plus a float buffer.

    ``data`` may point into memory owned by darknet (load_image_color) or by
    numpy (opencv_image_to_darknet_image).
    """
    _fields_ = [(dim, c_int) for dim in ("w", "h", "c")] + [("data", POINTER(c_float))]
class METADATA(Structure):
    """ctypes mirror of darknet's metadata: class count and class-name array."""
    _fields_ = [
        ("classes", c_int),            # number of classes
        ("names", POINTER(c_char_p)),  # C strings, indexed by class id
    ]
DEBUG = False  # True makes detect_core print network/image sizes and box counts
# ==================================================
# Pick the darknet shared library for this platform.
# os.name is "posix" on Linux and "nt" on Windows.
if os.name != "nt":
    lib = CDLL("./darknet.so", RTLD_GLOBAL)
else:
    # Windows only: class names and per-class probability thresholds are
    # taken from these lists instead of from the metadata.
    CLASS_NAMES = ["knife", "KK", "water"]
    CLASS_THRESHOLD = [0.0, 0.0, 0.05]
    # Other builds that have been used here:
    #   ./cuda80_yolo_cpp_dll.dll
    #   ./sdk/yolo/cuda90_yolo_cpp_dll.dll
    lib = CDLL("./sdk/yolo/cuda80_yolo_cpp_dll.dll", RTLD_GLOBAL)
# ==================================================
def _declare(name, argtypes, restype=None):
    """Fetch ``lib.<name>``, set its argtypes (and restype if given), return it.

    ``restype=None`` means "leave ctypes' default (c_int) untouched", matching
    functions whose restype is never assigned.
    """
    func = getattr(lib, name)
    func.argtypes = argtypes
    if restype is not None:
        func.restype = restype
    return func


_declare("network_width", [c_void_p], c_int)
_declare("network_height", [c_void_p], c_int)
set_gpu = _declare("cuda_set_device", [c_int])
make_image = _declare("make_image", [c_int, c_int, c_int], IMAGE)
make_network_boxes = _declare("make_network_boxes", [c_void_p], POINTER(DETECTION))
free_detections = _declare("free_detections", [POINTER(DETECTION), c_int])
free_ptrs = _declare("free_ptrs", [POINTER(c_void_p), c_int])
reset_rnn = _declare("reset_rnn", [c_void_p])
load_net = _declare("load_network", [c_char_p, c_char_p, c_int], c_void_p)
load_meta = _declare("get_metadata", [c_char_p], METADATA)
do_nms_obj = _declare("do_nms_obj", [POINTER(DETECTION), c_int, c_int, c_float])
do_nms_sort = _declare("do_nms_sort", [POINTER(DETECTION), c_int, c_int, c_float])
free_image = _declare("free_image", [IMAGE])
letterbox_image = _declare("letterbox_image", [IMAGE, c_int, c_int], IMAGE)
load_image_color = _declare("load_image_color", [c_char_p, c_int, c_int], IMAGE)
resize_image = _declare("resize_image", [IMAGE, c_int, c_int], IMAGE)
rgbgr_image = _declare("rgbgr_image", [IMAGE])

# Prediction entry points: by default `network_predict_image` is used on
# linux. On windows `network_predict` is used instead, because
# `network_predict_image` produces wrong results there.
network_predict_image = _declare(                      # linux
    "network_predict_image", [c_void_p, IMAGE], POINTER(c_float))
network_predict = _declare(                            # windows
    "network_predict", [c_void_p, POINTER(c_float)], POINTER(c_float))

get_network_boxes = _declare(
    "get_network_boxes",
    [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int), c_int],
    POINTER(DETECTION),
)

# The helper is only needed while the bindings are set up.
del _declare
def detect_core(net, meta, im, is_free_im, thresh, hier_thresh, nms):
    """Run the network on a darknet IMAGE and collect scored boxes.

    im:
        case 1: im = load_image_color(imagefile, 0, 0)
        case 2: im, arr = opencv_image_to_darknet_image(bgr)

    Args:
        net: network handle from ``load_net``.
        meta: METADATA from ``load_meta``; ``meta.classes`` probabilities are
            read for every box.
        im: input image (see the two cases above).
        is_free_im: when True, ``free_image(im)`` is called before returning,
            also if an error is raised. Pass True only when darknet allocated
            ``im`` (case 1); in case 2 the pixel buffer belongs to numpy.
        thresh, hier_thresh: forwarded to ``get_network_boxes``.
        nms: threshold forwarded to ``do_nms_obj``; a falsy value skips NMS.

    Returns:
        List of ``(class_name, prob, (x, y, w, h))`` tuples sorted by
        ``prob``, highest first (ties keep their original order). A box
        contributes one entry for every class whose probability exceeds that
        class's threshold: 0 on posix, ``CLASS_THRESHOLD[i]`` otherwise.
        ``class_name`` comes from ``meta.names`` on posix and from
        ``CLASS_NAMES`` otherwise.
    """
    is_posix = os.name == "posix"
    sized = None
    dets = None
    nboxes = 0
    try:
        net_w, net_h = lib.network_width(net), lib.network_height(net)
        if not is_posix:
            # Only the non-posix path reads a network-sized copy (sized.data
            # is passed to network_predict below); the posix path passes `im`
            # itself to network_predict_image, so no copy is made there.
            sized = resize_image(im, net_w, net_h)
        if DEBUG:
            print(" [kzl]: net width,height,channel =( {0}, {1}, {2})".format(net_w, net_h, 3))
            print(" [kzl]: im width,height,channel =( {0}, {1}, {2})".format(im.w, im.h, im.c))
            if sized is not None:
                print(" [kzl]: sized width,height,channel =( {0}, {1}, {2})".format(sized.w, sized.h, sized.c))
            # Example output:
            #  [kzl]: net width,height,channel =( 416, 416, 3)
            #  [kzl]: im width,height,channel =( 1280, 859, 3)
            #  [kzl]: sized width,height,channel =( 416, 416, 3)

        ptr_map = None
        relative = 0   # 1 for relative pos
        letterbox = 0  # 1 for letterbox
        nboxes_c = c_int(0)
        p_nboxes = pointer(nboxes_c)  # out-parameter: number of boxes returned
        if is_posix:
            network_predict_image(net, im)  # OK for linux, ERROR for windows
        else:
            network_predict(net, sized.data)  # on windows, we use network_predict
        dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh,
                                 ptr_map, relative, p_nboxes, letterbox)
        nboxes = p_nboxes[0]
        if DEBUG:
            print(" [kzl]: nboxes =( {0})".format(nboxes))
        if nms:
            do_nms_obj(dets, nboxes, meta.classes, nms)

        res = []
        for j in range(nboxes):
            det = dets[j]
            b = det.bbox
            box = (b.x, b.y, b.w, b.h)
            for i in range(meta.classes):
                if is_posix:
                    class_name = meta.names[i]
                    class_threshold = 0
                else:
                    # use `CLASS_NAMES[i]` instead of `meta.names[i]` on windows
                    class_name = CLASS_NAMES[i]
                    class_threshold = CLASS_THRESHOLD[i]
                prob = det.prob[i]
                if prob > class_threshold:
                    res.append((class_name, prob, box))
        # Stable descending sort: same order as sorting by -prob.
        res.sort(key=lambda r: r[1], reverse=True)
        return res
    finally:
        # Release darknet-allocated memory even when something above raised.
        if dets is not None:
            free_detections(dets, nboxes)
        if sized is not None:
            free_image(sized)
        if is_free_im:
            free_image(im)
def detect_imagefile(net, meta, imagefile, thresh=.5, hier_thresh=.5, nms=.45):
    """Detect objects in an image file that darknet loads from disk.

    Args:
        net, meta: handles from ``load_net`` / ``load_meta``.
        imagefile: path as bytes or text. Text is encoded to UTF-8 because
            ``load_image_color`` is declared with a ``c_char_p`` parameter,
            which rejects ``str`` on Python 3.
        thresh, hier_thresh, nms: forwarded to ``detect_core``.

    Returns:
        The detection list produced by ``detect_core``.
    """
    if not isinstance(imagefile, bytes):
        imagefile = imagefile.encode("utf-8")
    # BGR,hwc,[0,255] ===> RGB,chw,[0,1]
    im = load_image_color(imagefile, 0, 0)
    # The pixel buffer is allocated in C, so detect_core must free it.
    return detect_core(net, meta, im, True, thresh, hier_thresh, nms)
def opencv_image_to_darknet_image(bgr):
    """Wrap an OpenCV BGR image as a darknet IMAGE.

    The pixels are converted BGR,hwc,[0,255] ===> RGB,chw,[0,1] into a new
    float32 numpy array, and the returned IMAGE's ``data`` points into that
    array. Callers must keep the returned array alive for as long as the
    IMAGE is used, and must not ``free_image`` it.

    Returns:
        ``(im, arr)``: the IMAGE view and the numpy array that owns its pixels.
    """
    rgb_chw = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    c, h, w = rgb_chw.shape
    arr = np.ascontiguousarray(rgb_chw.flat, dtype=np.float32) / 255.0
    im = IMAGE(w, h, c, arr.ctypes.data_as(POINTER(c_float)))
    return im, arr
# 20181120: default thresh changed 0.5 ===> 0.01.
def detect_imagebuffer(net, meta, bgr, thresh=.01, hier_thresh=.5, nms=.45):
    """Detect objects in an in-memory OpenCV BGR image.

    The darknet IMAGE only points at a numpy array created by
    opencv_image_to_darknet_image. Binding that array to a local keeps it
    alive during detection, and Python frees it afterwards, so detect_core
    is told not to free the image.
    """
    darknet_im, pixel_buffer = opencv_image_to_darknet_image(bgr)
    return detect_core(net, meta, darknet_im, is_free_im=False,
                       thresh=thresh, hier_thresh=hier_thresh, nms=nms)
def yolov3():
    """Run the stock COCO YOLOv3 model on data/dog.jpg and print the detections."""
    # load_net / load_meta / load_image_color are declared with c_char_p
    # parameters, which accept bytes (not str) on Python 3; b"" literals are
    # plain str on Python 2, so these calls work on both.
    net = load_net(b"cfg/yolov3.cfg", b"yolov3.weights", 0)
    meta = load_meta(b"cfg/coco.data")
    r = detect_imagefile(net, meta, b"data/dog.jpg")
    print(r)
def _draw_detections(bgr, detections, bgr_color=(0, 0, 255)):
    """Print each detection and draw its box and label onto ``bgr`` in place.

    ``detections`` holds ``(class_name, score, (cx, cy, w, h))`` tuples; the
    centre/size box is converted to corner coordinates before drawing.
    """
    for r in detections:
        print(r)  # e.g. ('knife', 0.9975, (522.5, 482.7, 155.0, 342.8))
        class_name, score, (cx, cy, w, h) = r
        x = cx - w / 2
        y = cy - h / 2
        box = [int(x), int(y), int(x + w), int(y + h)]  # x1, y1, x2, y2
        print(box)
        print(score)
        print(class_name)
        # meta.names entries are bytes on Python 3; cv2.putText needs text.
        if isinstance(class_name, str):
            label = class_name
        else:
            label = class_name.decode("utf-8", "replace")
        cv2.rectangle(bgr, (box[0], box[1]), (box[2], box[3]), bgr_color, 2)
        cv2.putText(bgr, label, (int(x), int(y - 5)), cv2.FONT_HERSHEY_SIMPLEX, 1, bgr_color, 2)


def xray():
    """Run YOLOv3 over every file in a frame directory and save annotated copies.

    Each readable image is detected, and a copy with red boxes and class
    names is written to ``output_dir``. Per-image timings are printed. At
    the end, the total and average time over all but the first processed
    image are printed.
    """
    use_imagebuffer = True  # True: detect_imagebuffer (OpenCV input); False: detect_imagefile
    # image_dir = "./test/"
    image_dir = "/home/kezunlin/git/sdklite/data/reid_data/frames/2/"
    output_dir = "./result_yolo/2/"
    # Only create the folder when missing; os.makedirs raises if it exists,
    # which used to make every second run fail.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    image_filename_list = os.listdir(image_dir)

    set_gpu(0)
    # load_net / load_meta take c_char_p arguments, which must be bytes on
    # Python 3; b"" literals are ordinary str on Python 2.
    net = load_net(b"yolov3.cfg", b"yolov3.weights", 0)
    meta = load_meta(b"coco.data")

    total_cost = 0
    processed = 0  # images detected so far
    for index, image_filename in enumerate(image_filename_list):
        image_filepath = os.path.join(image_dir, image_filename)
        print("image #{} from {}".format(index, image_filepath))
        # Read in both modes: the buffer mode detects on it, and both modes
        # draw their results on it. Non-image files return None.
        bgr = cv2.imread(image_filepath)
        if bgr is None:
            print("skip unreadable file {}".format(image_filepath))
            continue

        if use_imagebuffer:
            print("================detect_imagebuffer======================")
            since = time.time()
            rs = detect_imagebuffer(net, meta, bgr, thresh=0.5)
        else:
            print("================detect_imagefile======================")
            since = time.time()
            rs = detect_imagefile(net, meta, image_filepath.encode('ascii'), thresh=0.5)
        time_elapsed = time.time() - since
        print("%f real seconds" % time_elapsed)
        print("len = ", len(rs))

        _draw_detections(bgr, rs)
        filepath = os.path.join(output_dir, "{0:06d}_image_with_boxs.jpg".format(index))
        cv2.imwrite(filepath, bgr)
        print("saved ", filepath)

        # The first processed image is left out of the average (treated as
        # a warm-up run).
        if processed > 0:
            total_cost += time_elapsed
        processed += 1

    timed_count = processed - 1
    if timed_count > 0:
        avg_cost = total_cost / float(timed_count)
        msg = "image size ={}, total cost={} s, avg cost ={} ms".format(
            timed_count, int(total_cost), int(avg_cost * 1000))
        print(msg)
# ========================================================
# Reference benchmark, linux, 20181203:
#   image size =782, total cost=64 s, avg cost =82 ms
if __name__ == "__main__":
    # Call yolov3() here instead to run the stock COCO demo.
    xray()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment