Last active
February 10, 2018 11:12
-
-
Save lucasb-eyer/d11583922c429a756c6542e33a5f816c to your computer and use it in GitHub Desktop.
Python code for computing 2D precision-recall curves of point detections. Audited by @Pandoro :)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
import numpy as np | |
from scipy.optimize import linear_sum_assignment | |
from scipy.spatial.distance import cdist | |
from sklearn.metrics import auc | |
def prec_rec_2d(det_scores, det_coords, det_frames, gt_coords, gt_frames, gt_radii):
    """ Computes full precision-recall curves at all possible thresholds.

    Detections are accepted one at a time in order of decreasing score. After each
    acceptance, all detections accepted so far in that detection's frame are
    assigned to the frame's labels with `linear_sum_assignment`, where a
    (label, detection) pair only counts as a match if their distance is at most
    the label's radius. Detections in frames without any label are false-positives.

    Arguments:
    - `det_scores` (D,) array containing the scores of the D detections.
    - `det_coords` (D,2) array containing the (x,y) coordinates of the D detections.
    - `det_frames` (D,) array containing the frame number of each of the D detections.
    - `gt_coords` (L,2) array containing the (x,y) coordinates of the L labels (ground-truth detections).
    - `gt_frames` (L,) array containing the frame number of each of the L labels.
    - `gt_radii` (L,) array containing the radius at which each of the L labels should consider detection associations.
        This will typically just be an np.full_like(gt_frames, 0.5) or similar,
        but could vary when mixing classes, for example.

    Returns: (recs, precs, threshs)
    - `recs`: (D,) float array of recall scores corresponding to the thresholds.
              All NaN when there are no labels at all.
    - `precs`: (D,) float array of precision scores corresponding to the thresholds.
    - `threshs`: (D,) array of sorted thresholds (scores), from higher to lower,
                 with the same dtype as `det_scores`.
    """
    det_scores = np.asarray(det_scores)
    det_coords = np.asarray(det_coords)
    det_frames = np.asarray(det_frames)
    gt_coords = np.asarray(gt_coords)
    gt_frames = np.asarray(gt_frames)
    gt_radii = np.asarray(gt_radii)

    # Group the labels by frame once, rather than re-scanning all labels for every
    # accepted detection. Detections reported in frames that have no labels will be
    # counted as false-positives.
    # TODO: do some sanity-checks in the "linearization" functions before calling `prec_rec_2d`.
    gts_per_frame = {}
    for frame in np.unique(gt_frames):
        mask = gt_frames == frame
        gts_per_frame[frame] = (gt_coords[mask], gt_radii[mask])

    det_accepted_idxs = defaultdict(list)  # frame -> indices of the detections accepted so far.
    frame_tps = defaultdict(int)           # frame -> current number of matched detections.

    # Running totals over all frames. Before any detection is accepted, every label is missed.
    tp, fp = 0, 0
    fn = sum(len(gts) for gts, _ in gts_per_frame.values())

    # Always float: `np.full_like(det_scores, np.nan)` would create an integer array
    # for integer scores, silently truncating every precision/recall to 0 or 1.
    precs = np.full(len(det_scores), np.nan)
    recs = np.full(len(det_scores), np.nan)

    order = np.argsort(det_scores)[::-1]  # Highest score first.
    threshs = det_scores[order]

    for i, idx in enumerate(order):
        frame = det_frames[idx]

        # Accept this detection
        dets_idxs = det_accepted_idxs[frame]
        dets_idxs.append(idx)

        if frame not in gts_per_frame:  # No GT, but there is a detection.
            fp += 1
        else:  # There is GT and detection in this frame.
            gts, radii = gts_per_frame[frame]
            dets = det_coords[dets_idxs]  # NB: only the so-far accepted detections of this frame.
            not_in_radius = radii[:,None] < cdist(gts, dets)  # -> ngts x ndets, True (=1) if too far, False (=0) if may match.
            igt, idet = linear_sum_assignment(not_in_radius)
            new_tp = int(np.sum(np.logical_not(not_in_radius[igt, idet])))  # Could match within radius

            # Only this frame's counts changed. Per frame, fp = ndets - tp and
            # fn = ngts - tp, and ndets just grew by one, so update the totals by delta.
            gained = new_tp - frame_tps[frame]
            frame_tps[frame] = new_tp
            tp += gained
            fp += 1 - gained
            fn -= gained

        precs[i] = tp/(fp+tp) if fp+tp > 0 else np.nan
        recs[i] = tp/(fn+tp) if fn+tp > 0 else np.nan

    return recs, precs, threshs
def peakf1(recs, precs):
    """Return the highest F1 score along a precision-recall curve.

    `recs` and `precs` are equally-shaped numpy arrays. The F1 score
    2*P*R / (P+R) is computed element-wise and its maximum is returned.
    Any NaN entry makes the result NaN (plain `np.max`).
    """
    # The lower clip bound avoids 0/0 where precision and recall are both zero.
    # For values in [0, 1] the upper bound never triggers.
    safe_sum = np.clip(precs + recs, 1e-16, 2 + 1e-16)
    f1_scores = 2 * precs * recs / safe_sum
    return np.max(f1_scores)
def eer(recs, precs):
    """Return the point of the curve where precision and recall are (closest to) equal.

    Picks the first index minimising |precision - recall| and returns the mean
    of precision and recall there.
    """
    gap = np.abs(precs - recs)
    best = np.argmin(gap)
    # The two values are often exactly equal there; averaging handles the case they are not.
    return (precs[best] + recs[best]) / 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
import lbtoolbox.plotting as lbplt | |
import lb_prec_rec as lbpr | |
def plot_prec_rec(result, figsize=(15,10), title=None):
    """Plot a precision-recall curve in a new figure.

    `result` is what's returned by `lbpr.prec_rec_2d`, i.e. (recs, precs, threshs);
    only the first two entries are used. The legend label reports the curve's
    AUC, peak F1 and EER as computed by `lbpr`. `title`, when given, becomes the
    figure's suptitle.

    Returns the created `(fig, ax)` pair.
    """
    fig, ax = plt.subplots(figsize=figsize)
    recs, precs = result[:2]
    area = lbpr.auc(recs, precs)
    best_f1 = lbpr.peakf1(recs, precs)
    break_even = lbpr.eer(recs, precs)
    summary = 'detections (AUC: {:.1%}, F1: {:.1%}, EER: {:.1%})'.format(area, best_f1, break_even)
    ax.plot(recs, precs, label=summary, c='#E24A33')
    if title is not None:
        fig.suptitle(title, fontsize=16, y=0.91)
    prettify_pr_curve(ax)
    lbplt.fatlegend(ax, loc='upper right')
    return fig, ax
def prettify_pr_curve(ax):
    """Style `ax` for a precision-recall plot and return it.

    Draws a dashed grey diagonal (where precision equals recall), pads both
    axis limits slightly beyond [0, 1], labels the axes, and renders tick
    values as whole percentages (0.25 -> "25").
    """
    ax.plot([0,1], [0,1], ls="--", c=".6")
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits(-0.02,1.02)
    ax.set_xlabel("Recall [%]")
    ax.set_ylabel("Precision [%]")
    # A separate formatter instance for each axis, as Matplotlib formatters are per-axis objects.
    for axis in (ax.axes.xaxis, ax.axes.yaxis):
        axis.set_major_formatter(mpl.ticker.FuncFormatter(lambda x, pos: '{:.0f}'.format(x*100)))
    return ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment