Created
January 5, 2023 03:01
-
-
Save realphongha/f5406d4cf059d0ce3c2f69ab81166420 to your computer and use it in GitHub Desktop.
Multiclass NMS (non-maximum suppression)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def iou_calc(boxes1, boxes2):
    """Return the intersection-over-union (IoU) of two boxes.

    Each box is a sequence or 1-D array laid out as ``[x1, y1, x2, y2, ...]``.
    Only the first four entries are read, so trailing fields (for example a
    score and a class id) are ignored.

    Args:
        boxes1: first box.
        boxes2: second box.

    Returns:
        The intersection area divided by the union area, or ``0.0`` when the
        union area is not positive (degenerate boxes).
    """
    # Function-scope import: elsewhere in this file numpy is only imported
    # under the ``__main__`` guard, so the global name ``np`` is not bound
    # when the file is imported as a module.
    import numpy as np

    boxes1_area = (boxes1[2] - boxes1[0]) * (boxes1[3] - boxes1[1])
    boxes2_area = (boxes2[2] - boxes2[0]) * (boxes2[3] - boxes2[1])

    left_up = np.maximum(boxes1[:2], boxes2[:2])
    # Slice [2:4] explicitly; the former [2:-2] only selected (x2, y2) when
    # the box had exactly six entries.
    right_down = np.minimum(boxes1[2:4], boxes2[2:4])

    # Negative extents mean the boxes do not overlap on that axis.
    inter_section = np.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[0] * inter_section[1]

    union_area = boxes1_area + boxes2_area - inter_area
    if union_area <= 0:
        # Avoid 0/0 (NaN plus a runtime warning) for zero-area boxes.
        return 0.0
    return 1.0 * inter_area / union_area
def multiclass_nms(boxes, iou_threshold=0.75, conf_threshold=None,
                   max_nms=1000):
    """Run non-maximum suppression independently for every class.

    Args:
        boxes: 2-D array with one row per detection, laid out as
            ``[x1, y1, x2, y2, score, class_id]``.
        iou_threshold: a box is discarded when its IoU with an already kept
            box of the same class is strictly greater than this value.
        conf_threshold: optional per-class minimum scores, indexed by the
            rounded class id. ``None`` or an empty collection disables
            score filtering.
        max_nms: only this many highest-scoring boxes (after score
            filtering) take part in suppression.

    Returns:
        A list of the kept rows. Rows are grouped by class, classes appear
        in the order of their best-scoring box, and rows within a class are
        in descending score order. An empty list is returned when there is
        nothing to keep.
    """
    if boxes.shape[0] == 0:
        return []

    # Mirror plain truthiness (None / empty / falsy means "no thresholds")
    # but ask for len() first, because bool() of a multi-element numpy
    # array raises ValueError.
    try:
        use_conf = conf_threshold is not None and len(conf_threshold) > 0
    except TypeError:  # scalar-like value without a length
        use_conf = bool(conf_threshold)

    if use_conf:
        # int(...) guarantees a valid index even when round() hands back a
        # numpy floating scalar instead of a Python int.
        keep_idx = [
            i for i, box in enumerate(boxes)
            if conf_threshold[int(round(box[5]))] <= box[4]
        ]
        boxes = boxes[keep_idx]
        if boxes.shape[0] == 0:
            return []

    # Sort by confidence (descending) and cap the number of candidates.
    sorted_i = boxes[:, 4].argsort()[::-1][:max_nms]
    boxes = boxes[sorted_i]

    # Bucket rows per class; each bucket inherits the descending score order.
    boxes_dict = {}
    for box in boxes:
        boxes_dict.setdefault(box[5], []).append(box)

    return_box = []
    for candidates in boxes_dict.values():
        while candidates:
            best_box = candidates[0]
            return_box.append(best_box)
            # Rebuild the list instead of popping while indexing. The
            # ``not (... > ...)`` form keeps a box whose IoU is NaN, exactly
            # as a failed ``>`` comparison would.
            candidates = [
                other for other in candidates[1:]
                if not (iou_calc(best_box, other) > iou_threshold)
            ]
    return return_box
if __name__ == "__main__":
    # Keep this import at module scope inside the guard: iou_calc refers to
    # the global name ``np``, which is bound here when run as a script.
    import numpy as np

    # Demo detections, one row per box: x1, y1, x2, y2, score, class id.
    demo_rows = (
        (0, 0, 100, 100, 0.6, 0),
        (10, 10, 100, 100, 0.65, 0),
        (50, 50, 100, 100, 0.6, 1),
        (0, 0, 80, 80, 0.8, 1),
        (110, 110, 200, 200, 0.6, 0),
    )
    detections = np.array(demo_rows, dtype=np.float32)

    per_class_score = [0.6, 0.7]  # score threshold for each class
    nms_iou = 0.2  # iou threshold

    kept = multiclass_nms(detections, nms_iou, per_class_score)
    print(kept)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment