Skip to content

Instantly share code, notes, and snippets.

@lucasb-eyer
Last active February 10, 2018 11:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lucasb-eyer/d11583922c429a756c6542e33a5f816c to your computer and use it in GitHub Desktop.
Save lucasb-eyer/d11583922c429a756c6542e33a5f816c to your computer and use it in GitHub Desktop.
Python code for 2D precision-recall computation of point-detections. Audited by @Pandoro :)
from collections import defaultdict
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.metrics import auc
def prec_rec_2d(det_scores, det_coords, det_frames, gt_coords, gt_frames, gt_radii):
    """ Computes full precision-recall curves at all possible thresholds.

    Detections are accepted one at a time in order of decreasing score. After
    each acceptance, the frame of that detection re-runs an optimal one-to-one
    assignment (`scipy.optimize.linear_sum_assignment`) between its labels and
    its accepted detections; a label/detection pair may match when their
    distance is at most (inclusive) the label's radius. Detections in frames
    without any label are always false positives; labels in frames without any
    accepted detection are false negatives.

    Arguments:
    - `det_scores` (D,) array containing the scores of the D detections.
      Any numeric dtype works, including integer scores.
    - `det_coords` (D,2) array containing the (x,y) coordinates of the D detections.
    - `det_frames` (D,) array containing the frame number of each of the D detections.
    - `gt_coords` (L,2) array containing the (x,y) coordinates of the L labels (ground-truth detections).
    - `gt_frames` (L,) array containing the frame number of each of the L labels.
    - `gt_radii` (L,) array containing the radius at which each of the L labels should consider detection associations.
      This will typically just be `np.full(len(gt_frames), 0.5)` or similar,
      but could vary when mixing classes, for example.
      (Beware `np.full_like(gt_frames, 0.5)`: with integer frame numbers it yields radii of 0.)

    Returns: (recs, precs, threshs)
    - `threshs`: (D,) array of sorted thresholds (scores), from higher to lower,
      with the dtype of `det_scores`. Tied scores each get their own entry.
    - `recs`: (D,) float array of recall scores corresponding to the thresholds.
      All NaN when there are no labels at all.
    - `precs`: (D,) float array of precision scores corresponding to the thresholds.
    """
    det_scores = np.asarray(det_scores)
    det_coords = np.asarray(det_coords)
    det_frames = np.asarray(det_frames)
    gt_coords = np.asarray(gt_coords)
    gt_frames = np.asarray(gt_frames)
    gt_radii = np.asarray(gt_radii)

    # All frames that contain a detection or a label (sorted, unique). Detection
    # frames that have no labels will only ever produce false positives.
    # TODO: do some sanity-checks in the "linearization" functions before calling `prec_rec_2d`.
    frames = np.unique(np.r_[det_frames, gt_frames])

    # Group the labels by frame index once, instead of re-masking all L labels
    # for every single detection.
    gt_iframes = np.searchsorted(frames, gt_frames)
    gt_of_frame = {}
    for iframe in np.unique(gt_iframes):
        mask = gt_iframes == iframe
        gt_of_frame[int(iframe)] = (gt_coords[mask], gt_radii[mask])

    accepted = defaultdict(list)  # frame index -> indices of accepted detections
    tps = np.zeros(len(frames), dtype=np.int64)
    fps = np.zeros(len(frames), dtype=np.int64)
    # Before any detection is accepted, every label is a false negative.
    fns = np.bincount(gt_iframes, minlength=len(frames)).astype(np.int64)

    # Highest score first; thresholds keep the dtype of the scores.
    order = np.argsort(det_scores)[::-1]
    threshs = det_scores[order]
    # Always float, regardless of the scores' dtype, so NaN and fractions survive.
    precs = np.full(len(order), np.nan)
    recs = np.full(len(order), np.nan)

    for i, idx in enumerate(order):
        iframe = int(np.searchsorted(frames, det_frames[idx]))
        accepted[iframe].append(idx)

        if iframe not in gt_of_frame:  # No GT, but there is a detection.
            fps[iframe] += 1
        else:  # There is GT and detection in this frame: redo the matching.
            gts, radii = gt_of_frame[iframe]
            dets = det_coords[accepted[iframe]]  # NB: only the so-far accepted ones.
            # ngts x ndets; True (cost 1) if too far, False (cost 0) if they may match.
            too_far = radii[:, None] < cdist(gts, dets)
            igt, idet = linear_sum_assignment(too_far)
            tps[iframe] = np.sum(~too_far[igt, idet])  # Assigned pairs within radius.
            fps[iframe] = len(dets) - tps[iframe]
            fns[iframe] = len(gts) - tps[iframe]

        tp, fp, fn = tps.sum(), fps.sum(), fns.sum()
        precs[i] = tp / (tp + fp) if tp + fp > 0 else np.nan
        recs[i] = tp / (tp + fn) if tp + fn > 0 else np.nan

    return recs, precs, threshs
def peakf1(recs, precs):
    """Return the highest F1 score along a precision-recall curve.

    F1 is 2*p*r / (p + r). The denominator is clipped to [1e-16, 2 + 1e-16],
    so points where precision and recall are both zero score 0 instead of
    dividing by zero. NaN entries propagate through `np.max`.
    """
    denominator = np.clip(precs + recs, 1e-16, 2 + 1e-16)
    f1_per_point = 2 * precs * recs / denominator
    return np.max(f1_per_point)
def eer(recs, precs):
    """Return the curve value at the point where precision and recall are closest.

    This is the (approximate) precision-recall break-even point. At the
    selected point the two values are usually identical; when they are not,
    their average is returned.
    """
    gap = np.abs(precs - recs)
    closest = np.argmin(gap)
    return (precs[closest] + recs[closest]) / 2
import matplotlib as mpl
import matplotlib.pyplot as plt
import lbtoolbox.plotting as lbplt
import lb_prec_rec as lbpr
def plot_prec_rec(result, figsize=(15,10), title=None):
    """Plot a precision-recall curve, with its summary metrics in the legend.

    `result` is what's returned by `lbpr.prec_rec_2d`; only its first two
    entries (recalls, precisions) are used, with recall on the x-axis and
    precision on the y-axis. The legend label reports AUC, peak F1 and EER
    as percentages. Returns the created `(fig, ax)` pair.
    """
    fig, ax = plt.subplots(figsize=figsize)

    curve = result[:2]
    area = lbpr.auc(*curve)
    best_f1 = lbpr.peakf1(*curve)
    break_even = lbpr.eer(*curve)
    legend_text = 'detections (AUC: {:.1%}, F1: {:.1%}, EER: {:.1%})'.format(area, best_f1, break_even)
    ax.plot(*curve, label=legend_text, c='#E24A33')

    if title is not None:
        fig.suptitle(title, fontsize=16, y=0.91)

    prettify_pr_curve(ax)
    lbplt.fatlegend(ax, loc='upper right')
    return fig, ax
def prettify_pr_curve(ax):
    """Apply the shared precision-recall plot styling to `ax` and return it.

    Draws a dashed grey diagonal (the line where precision equals recall),
    pads both axis limits slightly beyond [0, 1], labels the axes, and shows
    tick values (fractions) as whole percentages without a '%' sign.
    """
    def as_percent(value, pos):
        return '{:.0f}'.format(value*100)

    ax.plot([0,1], [0,1], ls="--", c=".6")

    lower, upper = -0.02, 1.02
    ax.set_xlim(lower, upper)
    ax.set_ylim(lower, upper)

    ax.set_xlabel("Recall [%]")
    ax.set_ylabel("Precision [%]")

    # Each axis gets its own FuncFormatter instance, as in the original.
    for axis in (ax.axes.xaxis, ax.axes.yaxis):
        axis.set_major_formatter(mpl.ticker.FuncFormatter(as_percent))
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment