Skip to content

Instantly share code, notes, and snippets.

@Multihuntr
Created July 28, 2021 11:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save Multihuntr/5a898e1794808ff7c6d30efca2ff52b7 to your computer and use it in GitHub Desktop.
Save Multihuntr/5a898e1794808ff7c6d30efca2ff52b7 to your computer and use it in GitHub Desktop.
Mean average precision for object detection
# Reimplementation of: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/object_detection/metrics/mean_avg_precision.py
# Now with more vectorisation!
def precision_recall_curve_th(is_tp, confs, n_true, eps=1e-8):
    """
    Calculates the precision/recall curve obtained by sweeping the confidence cut-off

    Parameters:
        is_tp (tensor): 1-D tensor, 1 where the prediction is a true positive, else 0
        confs (tensor): 1-D tensor of confidences, aligned element-wise with is_tp
        n_true (int): number of ground-truth items, used as the recall denominator
        eps (float): added to both denominators to avoid dividing by zero

    Returns:
        (tensor, tensor): precisions and recalls; entry k is computed over the
        k+1 most confident predictions
    """
    # Sort by confs, most confident first
    order = (-confs).argsort()
    is_tp = is_tp[order]

    # Cumulative sum true positives and number of predictions
    TP = is_tp.cumsum(dim=0)
    # Created on is_tp's device so the division below does not mix devices
    n_pred = torch.arange(1, len(is_tp) + 1, device=is_tp.device)

    # Divide by different subsets to find recall/precision
    precisions = TP / (n_pred + eps)
    recalls = TP / (n_true + eps)
    return precisions, recalls
def mean_average_precision(pred_boxes, true_boxes, iou_thresh=0.5, box_format="corners"):
    """
    Calculates mean average precision at a single IoU threshold

    Parameters:
        pred_boxes (tensor): 2-D tensor with one row per predicted box, laid out as
            [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
        true_boxes (tensor): 2-D tensor with one row per ground-truth box, same
            column layout as pred_boxes (its prob_score column is not read)
        iou_thresh (float): minimum IoU for a predicted box to match a ground truth
        box_format (str): passed through unchanged to intersection_over_union

    Returns:
        tensor: 0-dim tensor holding the mAP over the classes present in
        true_boxes; zero when true_boxes contains no boxes
    """
    average_precisions = []
    classes = set(true_boxes[:, 1].tolist())
    for c in classes:
        # Get only the boxes for class with index c
        detections = pred_boxes[pred_boxes[:, 1] == c]
        ground_truths = true_boxes[true_boxes[:, 1] == c]
        total_true_bboxes = len(ground_truths)

        # Most confident first, so that "first match" below means "most confident
        # match" and the result does not depend on the caller's row order
        detections = detections[detections[:, 2].argsort(descending=True)]

        # One flag per detection of this class. Anything never matched below stays 0
        # and counts as a false positive, including detections made in images that
        # hold no ground truth of this class
        is_tps = torch.zeros(len(detections), device=detections.device)

        for i in set(ground_truths[:, 0].tolist()):
            # Get only the boxes for image with index i
            det_idx = torch.nonzero(detections[:, 0] == i).squeeze(1)
            if det_idx.numel() == 0:
                # No predictions here: the ground truths are missed, and reducing
                # over an empty detection axis would raise
                continue
            det_i = detections[det_idx]
            gt_i = ground_truths[ground_truths[:, 0] == i]

            # Calculate IoUs for all pairs of det/gt.
            # Coordinates start at column 3; column 2 is prob_score
            ious = intersection_over_union(det_i[:, None, 3:], gt_i[None, :, 3:], box_format=box_format)
            ious = ious.squeeze(-1)

            # Remove all gt boxes which don't have any detections close enough
            hit = ious >= iou_thresh
            hit = hit[:, hit.any(dim=0)]

            # Select the first (most confident) det box above iou_thresh for each remaining gt
            _, det_max_idx = hit.max(dim=0)
            is_tps[det_idx[det_max_idx]] = 1

        confs = detections[:, 2]

        # Find average_precision for this class
        precision, recall = precision_recall_curve_th(is_tps, confs, total_true_bboxes)
        # Anchor the curve at (recall=0, precision=1), matching dtype and device
        precision = torch.cat((precision.new_ones(1), precision))
        recall = torch.cat((recall.new_zeros(1), recall))
        average_precisions.append(torch.trapz(precision, recall))

    if not average_precisions:
        # No ground-truth boxes at all: report zero instead of dividing by zero
        return torch.zeros((), device=true_boxes.device)
    return sum(average_precisions) / len(average_precisions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment