This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ctypes import * | |
import math | |
import random | |
import time | |
import os | |
def sample(probs):
    """Draw a random index from *probs*, treated as unnormalised weights.

    The weights are divided by their sum, a uniform value ``r`` in [0, 1] is
    drawn with ``random.uniform``, and each normalised weight is subtracted
    from ``r`` in order; the first index at which the remainder drops to
    <= 0 is returned. If floating-point rounding leaves a small positive
    remainder after the last weight, the last index is returned.

    Edge cases (unchanged from the original implementation):
        * an empty sequence performs no division and returns ``-1``;
        * a non-empty sequence whose weights sum to zero raises
          ``ZeroDivisionError``.
    """
    # float() prevents int/int floor division on Python 2, which would turn
    # every normalised integer weight into 0 and always pick the last index.
    total = float(sum(probs))
    weights = [w / total for w in probs]
    remainder = random.uniform(0, 1)
    for index, weight in enumerate(weights):
        remainder -= weight
        if remainder <= 0:
            return index
    return len(weights) - 1
def c_array(ctype, values):
    """Copy *values* into a newly allocated ctypes array of *ctype*.

    The array length is ``len(values)``; each element is converted by
    ctypes slice assignment, so every value must be acceptable to *ctype*.
    """
    array_type = ctype * len(values)
    buffer = array_type()
    buffer[:] = values
    return buffer
class BOX(Structure):
    """ctypes layout of the box records returned by ``make_boxes``.

    Four consecutive C floats: ``x``, ``y``, ``w``, ``h``.
    """

    # Every field shares the same C type, so pair the names with c_float.
    _fields_ = list(zip(("x", "y", "w", "h"), (c_float,) * 4))
class IMAGE(Structure):
    """ctypes layout of the image struct passed to and returned by libdarknet.

    Holds integer ``w``, ``h`` and ``c`` plus a ``data`` pointer to floats.
    """

    # NOTE(review): data presumably points at w * h * c floats — confirm
    # against darknet's image definition before indexing it from Python.
    _fields_ = [
        ("w", c_int),
        ("h", c_int),
        ("c", c_int),
        ("data", POINTER(c_float)),
    ]
class METADATA(Structure):
    """ctypes layout of the metadata returned by ``get_metadata``.

    ``classes`` is the class count and ``names`` a C array of C strings
    indexed by class id.
    """

    _fields_ = [
        ("classes", c_int),
        ("names", POINTER(c_char_p)),
    ]
# Shared library location. Set DARKNET_LIB to an explicit path (for example
# /home/pjreddie/documents/darknet/libdarknet.so) to bypass the dynamic
# loader's search path; the default keeps the plain "libdarknet.so" lookup.
lib = CDLL(os.environ.get("DARKNET_LIB", "libdarknet.so"), RTLD_GLOBAL)

# Network input geometry; the network handle is passed as an opaque pointer.
lib.network_width.argtypes = [c_void_p]
lib.network_width.restype = c_int
lib.network_height.argtypes = [c_void_p]
lib.network_height.restype = c_int

# Forward pass over a raw float buffer.
predict = lib.network_predict
predict.argtypes = [c_void_p, POINTER(c_float)]
predict.restype = POINTER(c_float)

set_gpu = lib.cuda_set_device
set_gpu.argtypes = [c_int]

make_image = lib.make_image
make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE

# Detection output buffers, filled in by network_detect.
make_boxes = lib.make_boxes
make_boxes.argtypes = [c_void_p]
make_boxes.restype = POINTER(BOX)

# NOTE(review): presumably frees an array of `n` heap pointers (used for the
# make_probs rows) — confirm against darknet's definition.
free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]

num_boxes = lib.num_boxes
num_boxes.argtypes = [c_void_p]
num_boxes.restype = c_int

make_probs = lib.make_probs
make_probs.argtypes = [c_void_p]
make_probs.restype = POINTER(POINTER(c_float))

# Do NOT rebind lib.network_predict with the detection signature here: CDLL
# caches attribute lookups, so it is the very object bound to `predict`
# above, and overwriting its argtypes would make predict(net, data) fail.
# Detection goes through network_detect, bound further below.

reset_rnn = lib.reset_rnn
reset_rnn.argtypes = [c_void_p]

# Paths are c_char_p, so callers must pass bytes on Python 3.
load_net = lib.load_network
load_net.argtypes = [c_char_p, c_char_p, c_int]
load_net.restype = c_void_p

free_image = lib.free_image
free_image.argtypes = [IMAGE]

letterbox_image = lib.letterbox_image
letterbox_image.argtypes = [IMAGE, c_int, c_int]
letterbox_image.restype = IMAGE

load_meta = lib.get_metadata
load_meta.argtypes = [c_char_p]
load_meta.restype = METADATA

load_image = lib.load_image_color
load_image.argtypes = [c_char_p, c_int, c_int]
load_image.restype = IMAGE

rgbgr_image = lib.rgbgr_image
rgbgr_image.argtypes = [IMAGE]

predict_image = lib.network_predict_image
predict_image.argtypes = [c_void_p, IMAGE]
predict_image.restype = POINTER(c_float)

network_detect = lib.network_detect
network_detect.argtypes = [c_void_p, IMAGE, c_float, c_float, c_float, POINTER(BOX), POINTER(POINTER(c_float))]
def classify(net, meta, im):
    """Run ``predict_image`` on *im* and rank every class by its score.

    Returns a list of ``(name, score)`` pairs, one per class in
    ``meta.classes``, highest score first; equal scores keep class-index
    order because the sort is stable.
    """
    scores = predict_image(net, im)
    ranked = [(meta.names[k], scores[k]) for k in range(meta.classes)]
    ranked.sort(key=lambda pair: -pair[1])
    return ranked
def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
    """Load the image file at *image* and run darknet detection on it.

    Args:
        net: network handle returned by ``load_net``.
        meta: ``METADATA`` from ``load_meta``; supplies class count and names.
        image: image file path as ``bytes``; a Python 3 ``str`` is accepted
            and encoded with ``os.fsencode``.
        thresh, hier_thresh, nms: forwarded unchanged to ``network_detect``.

    Returns:
        A list of ``(name, probability, (x, y, w, h))`` tuples for every
        box/class pair whose probability is > 0, sorted by descending
        probability.
    """
    if isinstance(image, str) and not isinstance(image, bytes):
        # Only reachable on Python 3, where c_char_p arguments reject str.
        image = os.fsencode(image)
    im = load_image(image, 0, 0)
    boxes = make_boxes(net)
    probs = make_probs(net)
    num = num_boxes(net)
    try:
        network_detect(net, im, thresh, hier_thresh, nms, boxes, probs)
        res = []
        for j in range(num):
            row = probs[j]
            box = boxes[j]
            for i in range(meta.classes):
                if row[i] > 0:
                    res.append((meta.names[i], row[i], (box.x, box.y, box.w, box.h)))
    finally:
        # Release native buffers even if the loop above is interrupted; the
        # collected tuples hold Python copies, so they stay valid afterwards.
        free_image(im)
        free_ptrs(cast(probs, POINTER(c_void_p)), num)
        # NOTE(review): the BOX array from make_boxes is never released and no
        # matching free function is bound in this file, so each call presumably
        # leaks it — confirm and add a binding to free it.
    return sorted(res, key=lambda x: -x[1])
if __name__ == "__main__":
    # Frame file that is re-read and re-analysed once per second.
    target = b"/tmp/ramdisk/frame.jpg"
    # load_network / get_metadata take c_char_p arguments, so the paths must
    # be bytes (mandatory on Python 3, a no-op difference on Python 2).
    net = load_net(b"cfg/yolo.cfg", b"yolo.weights", 0)
    meta = load_meta(b"cfg/coco.data")
    while True:
        # Detect objects
        output = detect(net, meta, target, 0.4)
        # clear screen for better readability
        os.system("clear")
        # loop every object you see and print it
        for name, prob, _box in output:
            if not isinstance(name, str):
                # Python 3 returns class names as bytes; decode for display.
                name = name.decode("utf-8", "replace")
            percent = round(prob * 100)
            print("I'm {}% sure I see a {}".format(percent, name))
        # sleep for a second
        time.sleep(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment