@missflash
Created March 28, 2019 12:35
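from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve
from matplotlib.pyplot import plot, xlabel, ylabel, annotate

def performance(y_true, pred, color="g", ann=True):
    # pred is an (n_samples, 2) array of class probabilities (e.g. from predict_proba);
    # column 1 holds the probability of the positive class.
    acc = accuracy_score(y_true, pred[:, 1] > 0.5)  # accuracy at a 0.5 threshold
    auc = roc_auc_score(y_true, pred[:, 1])
    fpr, tpr, thr = roc_curve(y_true, pred[:, 1])
    # Draw the ROC curve; color doubles as a matplotlib format string ("g" = green line)
    plot(fpr, tpr, color, linewidth=3)
    xlabel("False positive rate")
    ylabel("True positive rate")
    if ann:
        # Print both scores onto the plot, in data coordinates
        annotate("Acc: %0.2f" % acc, (0.1, 0.8), size=14)
        annotate("AUC: %0.2f" % auc, (0.1, 0.7), size=14)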
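The `performance` function expects `pred` to be a two-column probability array, indexed as `pred[:, 1]`. The sketch below shows one way to produce such an array and to compute the same scores without plotting. The synthetic dataset (`make_classification`) and the `LogisticRegression` model are illustrative assumptions; the gist does not say which data or classifier it was written for. The sketch also checks that the AUC reported by `roc_auc_score` matches the trapezoidal area under the `(fpr, tpr)` points returned by `roc_curve`.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, auc, roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# Assumed synthetic binary problem, for illustration only
X, y = make_classification(n_samples=500, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

# Assumed model; any classifier with predict_proba gives the same array layout
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
pred = clf.predict_proba(X_test)  # shape (n_test, 2); column 1 = P(y == 1)

# Same quantities that performance() computes
acc = accuracy_score(y_test, pred[:, 1] > 0.5)
roc_auc = roc_auc_score(y_test, pred[:, 1])

# roc_curve returns one (FPR, TPR) point per score threshold;
# sklearn.metrics.auc integrates those points with the trapezoidal rule.
fpr, tpr, thr = roc_curve(y_test, pred[:, 1])
area = auc(fpr, tpr)

print("Acc: %0.2f  AUC: %0.2f  trapezoid area: %0.2f" % (acc, roc_auc, area))
```

Calling `performance(y_test, pred)` in a matplotlib session would then draw this ROC curve with the accuracy and AUC annotations.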