Skip to content

Instantly share code, notes, and snippets.

@adhadse
Last active March 17, 2021 08:21
Show Gist options
  • Save adhadse/46a381a8de35d158b75a0e79d63060a5 to your computer and use it in GitHub Desktop.
Save adhadse/46a381a8de35d158b75a0e79d63060a5 to your computer and use it in GitHub Desktop.
def _operating_point(precisions, recalls, thresholds, metric_name, metric_perc):
    """Find the threshold at which the chosen metric crosses ``metric_perc``.

    Returns a ``(threshold, tradedoff_value, tradedoff_metric_name)`` tuple.

    Raises:
        TypeError: if ``metric_perc`` is None.
        ValueError: if no entry of ``thresholds`` satisfies the target.
    """
    if metric_perc is None:
        raise TypeError("metric_perc is required when metric_name is 'precision' or 'recall'")

    # Only entries that have a matching threshold are candidates. The metric
    # arrays are presumably one element longer than ``thresholds`` (as returned
    # by sklearn's precision_recall_curve -- TODO confirm against the caller);
    # searching that extra trailing element would yield an index that is out
    # of range for ``thresholds``.
    n_thresholds = len(thresholds)
    if metric_name == 'precision':
        # First threshold whose precision is at least the target.
        reached = precisions[:n_thresholds] >= metric_perc
        tradedoff_metric, tradedoff_metric_name = recalls, 'Recall'
    else:
        # First threshold whose recall has dropped to (or below) the target,
        # so the recall actually achieved there may be slightly under it.
        reached = recalls[:n_thresholds] <= metric_perc
        tradedoff_metric, tradedoff_metric_name = precisions, 'Precision'

    # np.argmax returns 0 for an all-False mask, which would silently pick the
    # first threshold; fail loudly instead.
    if not reached.any():
        raise ValueError("no threshold reaches {} = {}".format(metric_name, metric_perc))

    idx = int(np.argmax(reached))
    return thresholds[idx], tradedoff_metric[idx], tradedoff_metric_name


def plot_precision_recall_vs_threshold(precisions, recalls, thresholds, metric_name=None, metric_perc=None,
                                       x_limits=(-50000, 50000)):
    """Plot precision and recall against the decision threshold.

    When ``metric_name`` is ``'precision'`` or ``'recall'``, the plot is also
    annotated (red dotted guide lines and dots) at the threshold where that
    metric crosses ``metric_perc``, the figure is shown, and the threshold is
    returned.

    Args:
        precisions: precision values; all but the last element are plotted
            against ``thresholds``.
        recalls: recall values; all but the last element are plotted against
            ``thresholds``.
        thresholds: decision thresholds for the x-axis.
        metric_name: ``'precision'``, ``'recall'``, or anything else (default
            None) to draw only the two curves.
        metric_perc: target value of the chosen metric as a fraction
            (e.g. ``0.9``); required when ``metric_name`` selects a metric.
        x_limits: ``(min, max)`` range of the x-axis; the horizontal guide
            lines start at ``min``. The default keeps the previously
            hard-coded range.

    Returns:
        The threshold at the annotated point, or None when ``metric_name`` is
        neither ``'precision'`` nor ``'recall'`` (in that case the figure is
        built but ``plt.show()`` is not called, so the caller can keep
        drawing on it).

    Raises:
        TypeError: if a metric is selected but ``metric_perc`` is None.
        ValueError: if no threshold satisfies the requested target.
    """
    precisions = np.asarray(precisions)
    recalls = np.asarray(recalls)
    thresholds = np.asarray(thresholds)
    x_min, x_max = x_limits

    # Resolve the annotated point before touching matplotlib so that invalid
    # input does not leave a half-built figure behind.
    point = None
    if metric_name in ('precision', 'recall'):
        point = _operating_point(precisions, recalls, thresholds, metric_name, metric_perc)

    plt.figure(figsize=(15, 10))
    plt.plot(thresholds, precisions[:-1], 'b--', label='Precision')
    plt.plot(thresholds, recalls[:-1], 'g-', label='Recall')
    plt.xlabel('Threshold', fontsize=15)
    plt.axis([x_min, x_max, 0, 1])
    plt.legend(loc='best', fontsize=16)

    if point is None:
        return None

    threshold_atperc_metric, tradedoff_atperc_metric, tradedoff_metric_name = point

    # Draw the threshold: red dotted vertical line
    plt.plot([threshold_atperc_metric, threshold_atperc_metric], [0., metric_perc], "r:")
    # Draw the two horizontal dotted lines for precision and recall
    plt.plot([x_min, threshold_atperc_metric], [metric_perc, metric_perc], "r:")
    plt.plot([x_min, threshold_atperc_metric], [tradedoff_atperc_metric, tradedoff_atperc_metric], "r:")
    # Draw the two dots
    plt.plot([threshold_atperc_metric], [metric_perc], "ro")
    plt.plot([threshold_atperc_metric], [tradedoff_atperc_metric], "ro")
    plt.title("Precision/Recall vs Threshold plot with Threshold Set to {}% of {}\n trading off {} at {:.3f}%".format(
        metric_perc*100,
        metric_name.capitalize(),
        tradedoff_metric_name,
        tradedoff_atperc_metric*100), fontsize=20)
    plt.show()
    return threshold_atperc_metric
# Annotate the curves at the 90% precision point and keep the matching threshold.
threshold = plot_precision_recall_vs_threshold(
    precisions,
    recalls,
    thresholds,
    metric_name='precision',
    metric_perc=0.9,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment