Created
September 26, 2020 15:35
-
-
Save coreyjs/0742ef12b1f557aa070c268142d9f629 to your computer and use it in GitHub Desktop.
Function for calculating the ROC curve and AUC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Function for calculating auc and roc
def build_roc_auc(model, X_train, X_test, y_train, y_test, *, round_predictions=True):
    '''
    Fit a binary classifier, plot its ROC curve on the test set, and return an AUC score.

    INPUT:
    model - an sklearn instantiated model implementing fit() and predict_proba()
    X_train - the training data
    y_train - the training response values (must be categorical; binary labels
              are assumed, since only predict_proba column 1 is used)
    X_test - the test data
    y_test - the test response values (must be categorical, binary)
    round_predictions - keyword-only. When True (the default, matching the
              original behaviour), the returned score is computed from the
              positive-class probabilities rounded to 0/1. When False, it is
              the ROC AUC of the raw probabilities, i.e. the area shown on
              the plot.
    OUTPUT:
    auc - returns auc as a float
    shows the roc curve via matplotlib
    '''
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_curve, auc, roc_auc_score
    # NOTE: the former `from scipy import interp` (unused) was removed; that
    # alias no longer exists in recent SciPy and made this function crash on import.

    y_preds = model.fit(X_train, y_train).predict_proba(X_test)
    # Column 1 holds the scores for model.classes_[1]; for 0/1 labels that is
    # the class roc_curve treats as positive.
    pos_scores = y_preds[:, 1]

    # A binary problem has exactly one ROC curve, so compute it once. (The old
    # per-sample loop recomputed the same curve len(y_test) times, and indexing
    # entry 2 failed for test sets with fewer than 3 samples.)
    fpr, tpr, _ = roc_curve(y_test, pos_scores)
    roc_auc = auc(fpr, tpr)

    plt.plot(fpr, tpr, color='darkorange',
             lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    # Without a legend, the 'label' set above is never displayed.
    plt.legend(loc='lower right')
    plt.show()

    if round_predictions:
        # Scores hard 0/1 predictions; this usually differs from the plotted area.
        return roc_auc_score(y_test, np.round(pos_scores))
    return roc_auc_score(y_test, pos_scores)
# Example: ROC curve and AUC for the random forest model.
# rf_mod, training_data, testing_data, y_train and y_test are expected to be
# defined earlier (e.g. in a previous notebook cell).
build_roc_auc(
    model=rf_mod,
    X_train=training_data,
    X_test=testing_data,
    y_train=y_train,
    y_test=y_test,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment