Created
October 1, 2020 11:28
-
-
Save parulnith/48649e0c82dbb59c6f36e7a507fa1eef to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%matplotlib inline | |
from sklearn.metrics import roc_curve, precision_recall_curve, auc | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def get_auc(labels, scores):
    """Return the ROC curve of ``scores`` against ``labels`` and its area.

    Args:
        labels: Ground-truth labels, forwarded unchanged to ``roc_curve``.
        scores: Predicted scores, forwarded unchanged to ``roc_curve``.

    Returns:
        Tuple ``(fpr, tpr, auc_score)``: false-positive rates,
        true-positive rates, and the area under that curve as computed by
        ``sklearn.metrics.auc``.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(labels, scores)
    return false_pos_rate, true_pos_rate, auc(false_pos_rate, true_pos_rate)
def get_aucpr(labels, scores):
    """Return the precision-recall curve of ``scores`` and the area under it.

    Args:
        labels: Ground-truth labels, forwarded to ``precision_recall_curve``.
        scores: Predicted scores, forwarded to ``precision_recall_curve``.

    Returns:
        Tuple ``(precision, recall, aucpr_score)`` where ``aucpr_score`` is
        the area under precision (y) plotted against recall (x).
    """
    precision, recall, _thresholds = precision_recall_curve(labels, scores)
    # Argument order matters: x = recall, y = precision. ``auc`` accepts x
    # in decreasing order (as precision_recall_curve returns recall) and
    # still yields a non-negative area, consistent with get_auc.
    aucpr_score = auc(recall, precision)
    return precision, recall, aucpr_score
def plot_metric(ax, x, y, x_label, y_label, plot_label, style="-"):
    """Draw one labelled curve on ``ax`` and title both axes.

    Args:
        ax: Matplotlib Axes to draw on.
        x: Horizontal coordinates of the curve.
        y: Vertical coordinates of the curve.
        x_label: Title for the horizontal axis.
        y_label: Title for the vertical axis.
        plot_label: Legend entry for this curve.
        style: Matplotlib format string for the line (default ``"-"``).
    """
    ax.plot(x, y, style, label=plot_label)
    ax.legend()
    # Each title goes on its own axis: callers pass x_label for the data
    # they plot horizontally and y_label for the vertical data.
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
def prediction_summary(labels, predicted_score, info, plot_baseline=True, axes=None):
    """Plot ROC and precision-recall curves for one set of predictions.

    Args:
        labels: Ground-truth labels.
        predicted_score: Model scores aligned with ``labels``.
        info: Text prefixed to each curve's legend entry (e.g. a model name).
        plot_baseline: If true, also draw a dashed red reference line on
            each panel.
        axes: Two Axes, ROC panel first and precision-recall panel second.
            When omitted, the two cells of a 1x2 subplot grid on the
            current figure are used.

    Returns:
        The Axes sequence that was drawn on.
    """
    if axes is None:
        axes = [plt.subplot(1, 2, position) for position in (1, 2)]

    fpr_title = "False positive rate"
    tpr_title = "True positive rate"

    # First panel: ROC curve, plus the diagonal reference line.
    fpr, tpr, auc_score = get_auc(labels, predicted_score)
    roc_legend = "{} AUC = {:.4f}".format(info, auc_score)
    plot_metric(axes[0], fpr, tpr, fpr_title, tpr_title, roc_legend)
    if plot_baseline:
        plot_metric(axes[0], [0, 1], [0, 1], fpr_title, tpr_title,
                    "baseline AUC = 0.5", "r--")

    # Second panel: precision against recall.
    precision, recall, aucpr_score = get_aucpr(labels, predicted_score)
    pr_legend = "{} AUCPR = {:.4f}".format(info, aucpr_score)
    plot_metric(axes[1], recall, precision, "Recall", "Precision", pr_legend)
    if plot_baseline:
        # Flat reference at the mean label value (the positive fraction,
        # assuming labels are 0/1).
        mean_label = sum(labels) / len(labels)
        plot_metric(axes[1], [0, 1], [mean_label, mean_label], "Recall",
                    "Precision", "baseline AUCPR = {:.4f}".format(mean_label),
                    "r--")

    plt.show()
    return axes
def figure(fig_size=4.5):
    """Create a new Matplotlib figure whose width is twice its height.

    Args:
        fig_size: Figure height in inches; the width is set to
            ``2 * fig_size``. Defaults to 4.5, so existing ``figure()``
            calls are unchanged.
    """
    f = plt.figure()
    f.set_figheight(fig_size)
    f.set_figwidth(fig_size * 2)
# Convert the prediction frame to pandas, open a wide figure, and draw the
# ROC / precision-recall summary for it.
# NOTE(review): `predictions` is not defined in this snippet; presumably it
# is an H2O frame built in an earlier cell that carries both a "class"
# (ground truth) and a "predict" (score) column — confirm against that cell.
h2o_predictions = predictions.as_data_frame()
figure()
axes = prediction_summary(
    labels=h2o_predictions["class"],
    predicted_score=h2o_predictions["predict"],
    info="h2o",
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment