Created
March 28, 2018 18:18
-
-
Save geek-at/9ba20c4a9431164e8262d77d338b0748 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ctypes import * | |
import math | |
import random | |
import time | |
import os | |
def sample(probs):
    """Draw a random index, weighting each position by ``probs``.

    The weights need not sum to one: they are divided by their total first.
    A single ``random.uniform(0, 1)`` draw is then walked down the cumulative
    distribution, and the first index at which the remaining mass reaches zero
    or below is returned. If floating-point rounding leaves a sliver of mass,
    the last index is returned.

    Note: an all-zero, non-empty ``probs`` raises ``ZeroDivisionError``, and
    an empty ``probs`` yields ``-1``.
    """
    total = sum(probs)
    weights = [weight / total for weight in probs]
    remaining = random.uniform(0, 1)
    for idx, weight in enumerate(weights):
        remaining -= weight
        if remaining <= 0:
            return idx
    # Rounding fallback: attribute any leftover mass to the final index.
    return len(weights) - 1
def c_array(ctype, values):
    """Return a new ctypes array of ``ctype`` holding a copy of ``values``.

    The array length is ``len(values)``; elements are copied in order via
    slice assignment, so ``values`` must be a sized sequence.
    """
    array_type = ctype * len(values)
    result = array_type()
    result[:] = values
    return result
class BOX(Structure):
    """ctypes mirror of a box: four C floats named x, y, w and h.

    ctypes lays the fields out in list order, so the order below must match
    the C-side declaration and must not be changed.
    """

    _fields_ = [
        ("x", c_float),
        ("y", c_float),
        ("w", c_float),
        ("h", c_float),
    ]
class IMAGE(Structure):
    """ctypes mirror of an image record passed by value to/from the C library.

    Layout: three C ints (``w``, ``h``, ``c``) followed by ``data``, a pointer
    to C floats. ctypes lays the fields out in list order, so the order below
    must match the C-side declaration.
    """

    _fields_ = [
        ("w", c_int),
        ("h", c_int),
        ("c", c_int),
        ("data", POINTER(c_float)),
    ]
class METADATA(Structure):
    """ctypes mirror of the record returned by the library's ``get_metadata``.

    ``classes`` is a C int count; ``names`` is a ``char **`` that callers in
    this module index from 0 up to ``classes - 1`` to get each class name.
    """

    _fields_ = [
        ("classes", c_int),
        ("names", POINTER(c_char_p)),
    ]
# Load the darknet shared library. By default "libdarknet.so" is resolved via
# the dynamic loader's normal search path; set the DARKNET_LIB environment
# variable to the full path of a specific build to override it.
# RTLD_GLOBAL makes the library's symbols available to libraries loaded later.
lib = CDLL(os.environ.get("DARKNET_LIB", "libdarknet.so"), RTLD_GLOBAL)

# ctypes prototypes for the darknet C API used below. argtypes/restype tell
# ctypes how to convert arguments and results; a function left without an
# explicit restype is treated as returning a C int (ctypes default).
#
# NOTE: CDLL caches one function object per symbol name, so every Python name
# bound to the same symbol shares a single argtypes/restype. Declare each
# symbol exactly once -- a second declaration silently overwrites the first.
lib.network_width.argtypes = [c_void_p]
lib.network_width.restype = c_int
lib.network_height.argtypes = [c_void_p]
lib.network_height.restype = c_int

predict = lib.network_predict
predict.argtypes = [c_void_p, POINTER(c_float)]
predict.restype = POINTER(c_float)

set_gpu = lib.cuda_set_device
set_gpu.argtypes = [c_int]

make_image = lib.make_image
make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE

make_boxes = lib.make_boxes
make_boxes.argtypes = [c_void_p]
make_boxes.restype = POINTER(BOX)

free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]

num_boxes = lib.num_boxes
num_boxes.argtypes = [c_void_p]
num_boxes.restype = c_int

make_probs = lib.make_probs
make_probs.argtypes = [c_void_p]
make_probs.restype = POINTER(POINTER(c_float))

# network_predict is intentionally bound only once (as `predict` above);
# detection goes through `network_detect` below.

reset_rnn = lib.reset_rnn
reset_rnn.argtypes = [c_void_p]

load_net = lib.load_network
load_net.argtypes = [c_char_p, c_char_p, c_int]
load_net.restype = c_void_p

free_image = lib.free_image
free_image.argtypes = [IMAGE]

letterbox_image = lib.letterbox_image
letterbox_image.argtypes = [IMAGE, c_int, c_int]
letterbox_image.restype = IMAGE

load_meta = lib.get_metadata
load_meta.argtypes = [c_char_p]
load_meta.restype = METADATA

load_image = lib.load_image_color
load_image.argtypes = [c_char_p, c_int, c_int]
load_image.restype = IMAGE

rgbgr_image = lib.rgbgr_image
rgbgr_image.argtypes = [IMAGE]

predict_image = lib.network_predict_image
predict_image.argtypes = [c_void_p, IMAGE]
predict_image.restype = POINTER(c_float)

network_detect = lib.network_detect
network_detect.argtypes = [c_void_p, IMAGE, c_float, c_float, c_float, POINTER(BOX), POINTER(POINTER(c_float))]
def classify(net, meta, im):
    """Score every class for an already-loaded image.

    Runs ``predict_image`` on ``im`` and pairs each of the first
    ``meta.classes`` output scores with the matching entry of ``meta.names``.

    Returns:
        A list of ``(name, score)`` tuples ordered from highest to lowest
        score; ties keep class-index order.
    """
    scores = predict_image(net, im)
    ranked = [(meta.names[k], scores[k]) for k in range(meta.classes)]
    return sorted(ranked, key=lambda pair: -pair[1])
def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
    """Run the detector on an image file and return its detections.

    Args:
        net: network handle returned by ``load_net``.
        meta: ``METADATA`` returned by ``load_meta``; supplies the class count
            and the class names.
        image: path to the image file. ``bytes`` is passed through unchanged;
            any other string is UTF-8 encoded first, because ``load_image``
            is declared with a ``c_char_p`` argument (bytes only on Python 3).
        thresh, hier_thresh, nms: forwarded unchanged to ``network_detect``.

    Returns:
        A list of ``(name, probability, (x, y, w, h))`` tuples, one per
        (box, class) pair whose probability is greater than zero, ordered from
        highest to lowest probability.
    """
    if not isinstance(image, bytes):
        image = image.encode("utf-8")
    im = load_image(image, 0, 0)
    boxes = make_boxes(net)
    probs = make_probs(net)
    num = num_boxes(net)
    try:
        network_detect(net, im, thresh, hier_thresh, nms, boxes, probs)
        res = []
        for j in range(num):
            box = boxes[j]
            row = probs[j]
            for i in range(meta.classes):
                p = row[i]
                if p > 0:
                    res.append((meta.names[i], p, (box.x, box.y, box.w, box.h)))
        return sorted(res, key=lambda det: -det[1])
    finally:
        # Release every native buffer allocated above, even if detection or
        # result collection raised. All values in `res` are plain Python
        # copies, so freeing here is safe.
        free_image(im)
        free_ptrs(cast(probs, POINTER(c_void_p)), num)
        # The boxes buffer used to be leaked on every call. There is no
        # dedicated release helper bound for it, so it is handed to the C
        # library's free(), looked up through the darknet handle.
        # NOTE(review): assumes make_boxes returns a calloc/malloc'd buffer
        # (as in darknet's make_boxes) -- confirm for the linked libdarknet.
        lib.free(boxes)
if __name__ == "__main__":
    # Demo loop: repeatedly run detection on one image file and print what was
    # seen. The file is presumably overwritten by another process between
    # iterations (TODO confirm).
    # Paths are bytes literals because load_net, load_meta and load_image are
    # declared with c_char_p arguments, which reject str on Python 3.
    target = b"/tmp/ramdisk/frame.jpg"
    net = load_net(b"cfg/yolo.cfg", b"yolo.weights", 0)
    meta = load_meta(b"cfg/coco.data")
    while True:
        # Detect objects using a detection threshold of 0.4.
        output = detect(net, meta, target, 0.4)
        # Clear the screen for better readability.
        os.system("clear")
        # Report every detected object.
        for name, prob, _box in output:
            if not isinstance(name, str):
                # Class names come back from ctypes as bytes on Python 3.
                name = name.decode("utf-8", "replace")
            percent = round(prob * 100)
            print("I'm {}% sure I see a {}".format(percent, name))
        # Sleep for a second before the next pass.
        time.sleep(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment