Skip to content

Instantly share code, notes, and snippets.

@parulnith
Created October 1, 2020 11:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save parulnith/48649e0c82dbb59c6f36e7a507fa1eef to your computer and use it in GitHub Desktop.
Save parulnith/48649e0c82dbb59c6f36e7a507fa1eef to your computer and use it in GitHub Desktop.
%matplotlib inline
from sklearn.metrics import roc_curve, precision_recall_curve, auc
import matplotlib.pyplot as plt
import numpy as np
def get_auc(labels, scores):
    """Compute the ROC curve and the area under it.

    Args:
        labels: Ground-truth class labels, passed straight to ``roc_curve``.
        scores: Predicted scores for the positive class, passed straight to
            ``roc_curve``.

    Returns:
        A ``(fpr, tpr, auc_score)`` tuple: the false-positive rates, the
        true-positive rates, and the area under the resulting curve.
    """
    # The thresholds returned by roc_curve are not needed by any caller.
    false_pos_rate, true_pos_rate, _ = roc_curve(labels, scores)
    return false_pos_rate, true_pos_rate, auc(false_pos_rate, true_pos_rate)
def get_aucpr(labels, scores):
    """Compute the precision-recall curve and the area under it.

    Args:
        labels: Ground-truth class labels, passed straight to
            ``precision_recall_curve``.
        scores: Predicted scores for the positive class, passed straight to
            ``precision_recall_curve``.

    Returns:
        A ``(precision, recall, aucpr_score)`` tuple, where ``aucpr_score``
        is the trapezoidal area under precision plotted against recall.
    """
    precision, recall, _ = precision_recall_curve(labels, scores)
    # Integrate precision (y) over recall (x).  The previous
    # ``np.trapz(recall, precision)`` had the arguments the wrong way round
    # and integrated recall over precision, which under-reports the area by
    # ``recall[0] * precision[0]``.  ``auc`` also copes with recall being
    # returned in decreasing order, so the result is positive.
    aucpr_score = auc(recall, precision)
    return precision, recall, aucpr_score
def plot_metric(ax, x, y, x_label, y_label, plot_label, style="-"):
    """Draw one labelled curve on ``ax`` and refresh its legend and axis labels.

    Args:
        ax: Axes-like object providing ``plot``, ``legend``, ``set_xlabel``
            and ``set_ylabel``.
        x: Values for the horizontal axis.
        y: Values for the vertical axis.
        x_label: Caption for the horizontal axis.
        y_label: Caption for the vertical axis.
        plot_label: Legend entry for this curve.
        style: Format string forwarded to ``ax.plot`` (default: solid line).
    """
    ax.plot(x, y, style, label=plot_label)
    ax.legend()
    # The labels were previously swapped (x_label went to the y axis and
    # vice versa), which mislabelled every plot produced by the callers.
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
def prediction_summary(labels, predicted_score, info, plot_baseline=True, axes=None):
    """Plot ROC and precision-recall curves side by side for one model.

    Args:
        labels: Ground-truth class labels.
        predicted_score: Predicted scores, forwarded to ``get_auc`` and
            ``get_aucpr``.
        info: Name of the model; used as the prefix of the legend entries.
        plot_baseline: When true, also draw a dashed red reference line on
            each plot.
        axes: Optional pair of axes (ROC first, precision-recall second).
            When omitted, a 1x2 subplot grid is created on the current
            figure.

    Returns:
        The axes that were drawn on.
    """
    if axes is None:
        axes = [plt.subplot(1, 2, position) for position in (1, 2)]

    # --- ROC panel ---------------------------------------------------------
    roc_ax = axes[0]
    roc_x_caption, roc_y_caption = "False positive rate", "True positive rate"
    fpr, tpr, roc_area = get_auc(labels, predicted_score)
    plot_metric(roc_ax, fpr, tpr, roc_x_caption, roc_y_caption,
                f"{info} AUC = {roc_area:.4f}")
    if plot_baseline:
        # Diagonal reference line.
        plot_metric(roc_ax, [0, 1], [0, 1], roc_x_caption, roc_y_caption,
                    "baseline AUC = 0.5", "r--")

    # --- Precision-recall panel --------------------------------------------
    pr_ax = axes[1]
    pr_x_caption, pr_y_caption = "Recall", "Precision"
    precision, recall, pr_area = get_aucpr(labels, predicted_score)
    plot_metric(pr_ax, recall, precision, pr_x_caption, pr_y_caption,
                f"{info} AUCPR = {pr_area:.4f}")
    if plot_baseline:
        # Horizontal reference line at the mean of the labels
        # (presumably the positive-class rate, assuming 0/1 labels).
        positive_rate = sum(labels) / len(labels)
        plot_metric(pr_ax, [0, 1], [positive_rate, positive_rate],
                    pr_x_caption, pr_y_caption,
                    f"baseline AUCPR = {positive_rate:.4f}", "r--")

    plt.show()
    return axes
def figure():
    """Create a new matplotlib figure twice as wide as it is tall (9 x 4.5 in)."""
    height_inches = 4.5
    canvas = plt.figure()
    # Width is double the height so two square-ish panels fit side by side.
    canvas.set_size_inches(height_inches * 2, height_inches)
# `predictions` is defined outside this snippet
# (presumably an H2O frame of model predictions; verify against the notebook).
h2o_predictions = predictions.as_data_frame()

# Open a wide figure, then draw the ROC and precision-recall summary on it.
figure()
axes = prediction_summary(
    labels=h2o_predictions["class"],
    predicted_score=h2o_predictions["predict"],
    info="h2o",
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment