Skip to content

Instantly share code, notes, and snippets.

@miki998
Created May 2, 2020 14:58
Show Gist options
Save miki998/79585196a3c4acccd6971323778ceb78 to your computer and use it in GitHub Desktop.
def box(image, boxes, class_names=None, font_face=None):
    """Return a copy of ``image`` with detection boxes and optional labels drawn.

    Args:
        image: image array. ``image.shape[0]`` is used as the pixel height
            (rows) and ``image.shape[1]`` as the pixel width (columns). The
            input is not modified; drawing happens on ``image.copy()``.
        boxes: iterable of detections. The first four entries of each are
            read row-first as normalized ``(center_y, center_x, extent_y,
            extent_x)`` (see the note in the loop). When a detection has at
            least 7 entries and ``class_names`` is truthy, entry 6 is used as
            the class index for its label.
        class_names: optional indexable of label strings, indexed by class id.
        font_face: OpenCV font face for labels. When None (the default),
            ``cv2.FONT_HERSHEY_SIMPLEX`` is used.

    Returns:
        The annotated copy. Rectangles are drawn in colour (255, 0, 255) and
        labels in (0, 255, 255), both with thickness 2.
    """
    img = image.copy()
    img_h, img_w = img.shape[0], img.shape[1]
    for det in boxes:
        # Row-first layout: det[0]/det[2] are scaled by the image height and
        # det[1]/det[3] by the image width. This is the net mapping the earlier
        # "mislabel width, then swap x/y" code produced, kept unchanged.
        # NOTE(review): confirm this matches the layout returned by do_detect.
        cy, cx, ext_y, ext_x = det[0], det[1], det[2], det[3]
        x1 = (cx - ext_x / 2.0) * img_w
        y1 = (cy - ext_y / 2.0) * img_h
        x2 = (cx + ext_x / 2.0) * img_w
        y2 = (cy + ext_y / 2.0) * img_h
        top_left = (int(x1), int(y1))
        if len(det) >= 7 and class_names:
            face = cv2.FONT_HERSHEY_SIMPLEX if font_face is None else font_face
            # int() lets a float or 0-d tensor class id index a Python list.
            label = str(class_names[int(det[6])])
            img = cv2.putText(img, label, top_left, face, 1,
                              (0, 255, 255), 2, cv2.LINE_AA)
        img = cv2.rectangle(img, (int(x2), int(y2)), top_left, (255, 0, 255), 2)
    return img
def detect(cfgfile, weightfile, img, verbose=1, num_classes=80, use_cuda=0):
    """Build a Darknet model, run detection on ``img`` and draw the results.

    The model is constructed from ``cfgfile``/``weightfile`` on every call.
    Detection runs twice on the resized image. Only the second pass is timed
    and reported, so the first pass acts as a warm-up.

    Args:
        cfgfile: path passed to ``Darknet``.
        weightfile: path passed to ``load_weights``.
        img: input image array. It is resized to the network input size for
            detection and is itself left unmodified.
        verbose: if truthy, print the network summary, a load message, and
            the timing of the second detection pass.
        num_classes: selects the class-names file. 20 maps to
            'data/voc.names', 80 to 'data/coco.names', and anything else to
            'data/names'. Defaults to 80, the previously hard-coded value.
        use_cuda: if truthy, the model is moved with ``cuda()``. The value is
            also forwarded to ``do_detect``. Defaults to 0, as before.

    Returns:
        A tuple ``(img, boxes, class_names, boxed_img)``: the unmodified
        input, the detections from the second pass, the loaded class names,
        and a copy of ``img`` annotated by ``box``.
    """
    model = Darknet(cfgfile)
    if verbose:
        model.print_network()
    model.load_weights(weightfile)
    if verbose:
        print('Loading weights from %s... Done!' % (weightfile))

    namesfile = {20: 'data/voc.names', 80: 'data/coco.names'}.get(num_classes, 'data/names')

    if use_cuda:
        model.cuda()

    # cv2.resize takes the target size as (width, height).
    sized = cv2.resize(img, (model.width, model.height), interpolation=cv2.INTER_AREA)

    # Two passes: the first warms up, only the second is timed and reported.
    # NOTE(review): 0.5 / 0.4 are presumably the confidence and NMS thresholds
    # expected by do_detect; confirm before exposing them as parameters.
    for attempt in range(2):
        start = time.perf_counter()  # monotonic clock, suited to intervals
        boxes = do_detect(model, sized, 0.5, 0.4, use_cuda)
        elapsed = time.perf_counter() - start
        if attempt == 1 and verbose:
            print('Predicted in {} seconds.'.format(elapsed))

    class_names = load_class_names(namesfile)
    boxed_img = box(img, boxes, class_names=class_names)
    return img, boxes, class_names, boxed_img
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment