Skip to content

Instantly share code, notes, and snippets.

@VincentTatan
Created May 9, 2020 12:51
Show Gist options
  • Save VincentTatan/b725ae77994a5ea2a9b4bacf474e821d to your computer and use it in GitHub Desktop.
Save VincentTatan/b725ae77994a5ea2a9b4bacf474e821d to your computer and use it in GitHub Desktop.
def take_roc_curve(X_test, model, y_true=None):
    """Plot the ROC curve and the precision-recall curve of a binary classifier.

    Draws a 1x2 figure. The left panel is the ROC curve, with its AUC in the
    legend and a dashed chance diagonal. The right panel is the
    precision-recall curve as a step plot. The figure is displayed with
    ``plt.show()``; nothing is returned.

    Args:
        X_test: Features passed to the model's scoring method.
        model: Fitted classifier. ``predict_proba`` is used when the model has
            it (column 1 taken as the positive-class score); otherwise
            ``decision_function`` scores are used.
        y_true: Ground-truth labels for ``X_test``. When omitted, the
            module-level ``y_test`` is used, which keeps existing
            two-argument calls working.

    NOTE(review): ``metrics`` and ``plt`` are module-level names defined
    outside this snippet -- presumably ``sklearn.metrics`` and
    ``matplotlib.pyplot``; confirm at the import site.
    """
    # Evaluated lazily: the global ``y_test`` is only looked up when the
    # caller did not pass labels explicitly.
    labels = y_test if y_true is None else y_true

    if hasattr(model, 'predict_proba'):
        # Assumes a binary sklearn-style classifier: column 1 holds the
        # probability of ``model.classes_[1]``, treated as the positive class.
        scores = model.predict_proba(X_test)[:, 1]
    else:
        # roc_curve / precision_recall_curve also accept unthresholded
        # decision scores, so models without predict_proba still work.
        scores = model.decision_function(X_test)

    fpr, tpr, _ = metrics.roc_curve(labels, scores)
    precision, recall, _ = metrics.precision_recall_curve(labels, scores)
    auc_score = metrics.auc(fpr, tpr)

    plt.figure(figsize=(10, 5))

    # Left panel: ROC curve against the random-guess diagonal.
    plt.subplot(1, 2, 1)
    plt.title('ROC Curve ' + type(model).__name__)
    plt.plot(fpr, tpr, label='AUC = {:.2f}'.format(auc_score))
    plt.plot([0, 1], [0, 1], 'r--')
    # Small margin so points on the 0/1 boundaries are not clipped.
    plt.xlim([-0.1, 1.1])
    plt.ylim([-0.1, 1.1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.legend(loc='lower right')

    # Right panel: precision-recall curve drawn as a step function.
    plt.subplot(1, 2, 2)
    plt.step(recall, precision, color='orange', where='post')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title('Precision Recall Curve')
    plt.grid(True)

    plt.tight_layout()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment