Skip to content

Instantly share code, notes, and snippets.

@bearpelican
Created July 31, 2020 02:34
Show Gist options
  • Save bearpelican/7f3fce5d354369b512f23893af9476b3 to your computer and use it in GitHub Desktop.
Save bearpelican/7f3fce5d354369b512f23893af9476b3 to your computer and use it in GitHub Desktop.
def _greedy_match(pred_pts, lbl_pts, max_dist):
    """Greedily match predicted points to ground-truth points for one image/class.

    Each prediction, in the order given, is paired with the nearest
    ground-truth point (per ``euclidean_dist``) that has not already been
    matched. The pair is a true positive only if that distance is strictly
    below ``max_dist``, and only then is the ground-truth point consumed.

    Returns:
        list[bool]: one flag per prediction, True for a true positive.
    """
    # Track matched ground truths by index, not value, so duplicate
    # coordinates are consumed independently and array points never hit
    # ambiguous ``==`` comparisons.
    matched = set()
    hits = []
    for pred_pt in pred_pts:
        best_idx, best_dist = None, None
        for gt_idx, gt_pt in enumerate(lbl_pts):
            if gt_idx in matched:
                continue
            dist = euclidean_dist(gt_pt, pred_pt)
            if best_dist is None or dist < best_dist:
                best_idx, best_dist = gt_idx, dist
        is_tp = bool(best_dist is not None and best_dist < max_dist)
        if is_tp:
            # Only consume a ground truth on a successful match; a far-away
            # prediction must not steal it from a closer prediction later on.
            matched.add(best_idx)
        hits.append(is_tp)
    return hits


def calc_class_ap(lbls, preds, max_dist=30, class_names=None):
    """Compute per-class average precision for point detections.

    Args:
        lbls: mapping image id -> row of ground-truth points, where
            ``row[cls]`` is a sequence of points for class ``cls``. A predicted
            image with no entry in ``lbls`` is treated as having no ground truth.
        preds: mapping image id -> row of predicted points, same layout as
            ``lbls``. Only images present in ``preds`` are evaluated; ground
            truth on images missing from ``preds`` is not counted.
        max_dist: a prediction is a true positive when its matched
            ground-truth point is strictly closer than this distance.
        class_names: classes to evaluate; defaults to the module-level
            ``classes``.

    Returns:
        list[dict]: one dict per class with keys ``'class'``, ``'ap'``
        (``compute_ap`` result rounded to 4 decimals), ``'tp'``, ``'fp'``
        and ``'gt'`` (total ground-truth count).

    NOTE(review): predictions carry no confidence score here, so the
    precision/recall curve is accumulated in ``preds`` iteration order
    rather than sorted by score. Confirm this is intended.
    """
    if class_names is None:
        class_names = classes

    # Per class: (tp flags, fp flags, ground-truth count per image).
    metrics = {c: ([], [], []) for c in class_names}
    for img_id, pred_row in preds.items():
        lbl_row = lbls[img_id] if img_id in lbls else {}
        for cls in class_names:
            tps, fps, n_gts = metrics[cls]
            pred_pts = pred_row[cls] if cls in pred_row else []
            lbl_pts = lbl_row[cls] if cls in lbl_row else []
            hits = _greedy_match(pred_pts, lbl_pts, max_dist)
            tps.extend(int(hit) for hit in hits)
            fps.extend(int(not hit) for hit in hits)
            n_gts.append(len(lbl_pts))

    # Calculate average precision per class.
    aps = []
    for c in class_names:
        tps, fps, n_gts = metrics[c]
        n_gt = sum(n_gts)
        tp = np.array(tps).cumsum(0) if tps else np.array([0])
        fp = np.array(fps).cumsum(0) if fps else np.array([0])
        precision = tp / (tp + fp + 1e-8)
        # Avoid 0/0 -> nan (with a RuntimeWarning) for classes with no ground truth.
        recall = tp / n_gt if n_gt else np.zeros(len(tp))
        aps.append({
            'class': c,
            'ap': round(compute_ap(precision, recall), 4),
            'tp': int(tp[-1]),
            'fp': int(fp[-1]),
            'gt': n_gt,
        })
    return aps
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment