Last active
March 17, 2021 08:21
-
-
Save adhadse/46a381a8de35d158b75a0e79d63060a5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _first_index_where(mask, description):
    """Return the index of the first True entry of *mask*; raise ValueError if there is none."""
    # np.argmax returns 0 for an all-False mask, which would silently select
    # the first threshold, so the "no match" case is checked explicitly.
    if not mask.any():
        raise ValueError("No threshold found where {}".format(description))
    return int(np.argmax(mask))


def plot_precision_recall_vs_threshold(precisions, recalls, thresholds, metric_name=None, metric_perc=None,
                                       x_range=(-50000, 50000)):
    """Plot precision and recall against the decision threshold.

    When ``metric_name`` is ``'precision'`` or ``'recall'``, the threshold at
    which that metric meets ``metric_perc`` is also marked with red dotted
    guide lines and dots, the figure is titled and shown, and the threshold is
    returned.

    Args:
        precisions: Precision values; one element longer than ``thresholds``
            (the last element is dropped when plotting). Presumably the output
            of ``sklearn.metrics.precision_recall_curve`` -- confirm with caller.
        recalls: Recall values, same length as ``precisions``.
        thresholds: Decision thresholds used for the x axis.
        metric_name: ``'precision'``, ``'recall'`` or anything else. Any other
            value (including the default ``None``) only draws the two curves
            and returns ``None`` without calling ``plt.show()``.
        metric_perc: Target value of the chosen metric as a fraction
            (e.g. ``0.9``). Required when ``metric_name`` selects a metric.
        x_range: ``(x_min, x_max)`` limits of the threshold axis. The
            horizontal guide lines start at ``x_min``. Defaults to the
            previously hard-coded ``(-50000, 50000)``.

    Returns:
        The threshold at the selected operating point, or ``None`` when
        ``metric_name`` is neither ``'precision'`` nor ``'recall'``.

    Raises:
        ValueError: If no entry of ``thresholds`` satisfies the requested
            target. Raised before anything is drawn.
    """
    x_min, x_max = x_range

    # Locate the operating point first, so a failure does not leave a
    # half-drawn figure behind. The search is limited to ``[:-1]`` because the
    # final precision/recall element has no matching entry in ``thresholds``.
    point_index = None
    if metric_name == 'precision':
        tradedoff_metric = recalls
        tradedoff_metric_name = 'Recall'
        # First threshold at which precision reaches the target.
        point_index = _first_index_where(
            np.asarray(precisions[:-1]) >= metric_perc,
            "precision >= {}".format(metric_perc))
    elif metric_name == 'recall':
        tradedoff_metric = precisions
        tradedoff_metric_name = 'Precision'
        # NOTE(review): this keeps the original rule -- the first threshold at
        # which recall has dropped to (or below) the target, so the recall
        # actually achieved there may be slightly under ``metric_perc``.
        point_index = _first_index_where(
            np.asarray(recalls[:-1]) <= metric_perc,
            "recall <= {}".format(metric_perc))

    plt.figure(figsize=(15, 10))
    plt.plot(thresholds, precisions[:-1], 'b--', label='Precision')
    plt.plot(thresholds, recalls[:-1], 'g-', label='Recall')
    plt.xlabel('Threshold', fontsize=15)
    plt.axis([x_min, x_max, 0, 1])
    plt.legend(loc='best', fontsize=16)

    if point_index is None:
        # Plain-plot mode: no operating point requested.
        return

    # Traded-off metric and threshold at the requested metric level.
    tradedoff_atperc_metric = tradedoff_metric[point_index]
    threshold_atperc_metric = thresholds[point_index]

    # Draw the threshold, red dotted vertical line
    plt.plot([threshold_atperc_metric, threshold_atperc_metric], [0., metric_perc], "r:")
    # Draw the two horizontal dotted lines for precision and recall
    plt.plot([x_min, threshold_atperc_metric], [metric_perc, metric_perc], "r:")
    plt.plot([x_min, threshold_atperc_metric], [tradedoff_atperc_metric, tradedoff_atperc_metric], "r:")
    # Draw the two dots
    plt.plot([threshold_atperc_metric], [metric_perc], "ro")
    plt.plot([threshold_atperc_metric], [tradedoff_atperc_metric], "ro")
    plt.title("Precision/Recall vs Threshold plot with Threshold Set to {}% of {}\n trading off {} at {:.3f}%".format(
        metric_perc*100,
        metric_name.capitalize(),
        tradedoff_metric_name,
        tradedoff_atperc_metric*100), fontsize=20)
    plt.show()
    return threshold_atperc_metric
# Plot the curves, mark the point where precision first reaches 90%,
# and keep the corresponding decision threshold.
threshold = plot_precision_recall_vs_threshold(
    precisions,
    recalls,
    thresholds,
    metric_name='precision',
    metric_perc=0.9,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment