Created
February 28, 2020 21:19
-
-
Save jolyon129/6d2b0fa14e2c0d0df5294411009702d8 to your computer and use it in GitHub Desktop.
Draw k-fold ROC curves for multiple classifiers and plot them in the same figure.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
from numpy import interp | |
from sklearn import metrics | |
from sklearn.metrics import precision_recall_fscore_support | |
from sklearn.model_selection import KFold | |
def _mean_and_se(scores):
    """Return (mean, standard error of the mean) for per-fold scores."""
    scores = np.asarray(scores, dtype=float)
    # np.std defaults to the population std (ddof=0); dividing it by
    # sqrt(n - 1) is the same as sample std / sqrt(n).
    return scores.mean(), scores.std() / np.sqrt(len(scores) - 1)


def k_fold_roc_plot(cls, Xs, y, verbose=False, nfold=10):
    """Cross-validate a binary classifier and plot its mean ROC curve.

    The curve is drawn on the current matplotlib figure; no figure is
    created and ``plt.show()`` is not called, so several classifiers can
    be plotted into the same axes by calling this function repeatedly.
    The mean and standard error of AUC, precision, recall, F1, accuracy
    and R^2 over the folds are printed.

    Args:
        cls: A classifier with ``.fit``, ``.predict`` and
            ``.predict_proba`` methods.
        Xs: The input data; must support ``Xs[row_indices, :]`` indexing
            (e.g. a 2-D NumPy array).
        y: The binary labels, one per row of ``Xs``.
        verbose: If true, also draw the ROC curve of every single fold.
        nfold: Number of cross-validation folds (at least 2).

    Returns:
        ``cls`` itself. It is refit on every fold, so on return it holds
        the model fitted on the last training split.
    """
    # A pandas Series would be indexed by *label* in ``y[train]``; the
    # splitter yields positions, so work on a plain array instead.
    y = np.asarray(y)

    kf = KFold(n_splits=nfold, shuffle=True, random_state=39)
    base_fpr = np.linspace(0, 1, 101)

    prec, recs, f1s, accs, aucs, r_2s = [], [], [], [], [], []
    tprs = []

    for fold, (train, test) in enumerate(kf.split(Xs, y), start=1):
        # Get training and test data
        Xtr, Xts = Xs[train, :], Xs[test, :]
        ytr, yts = y[train], y[test]

        # Fit a model
        cls.fit(Xtr, ytr)
        yhat = cls.predict(Xts)
        # Second column of predict_proba is used as the positive-class score.
        pos_score = cls.predict_proba(Xts)[:, 1]

        # Measure performance
        preci, reci, f1i, _ = precision_recall_fscore_support(
            yts, yhat, average='binary')
        fpr, tpr, _ = metrics.roc_curve(yts, pos_score)
        auc = metrics.roc_auc_score(yts, pos_score)
        r_2 = metrics.r2_score(yts, pos_score)

        # If verbose, draw the intermediate ROC
        if verbose:
            plt.plot(fpr, tpr, lw=2, alpha=0.3,
                     label='ROC fold %d (AUC = %0.2f)' % (fold, auc))

        # Resample every fold's curve onto a common FPR grid so the
        # curves can be averaged point-wise.
        fold_tpr = interp(base_fpr, fpr, tpr)
        # Several thresholds can share FPR == 0, which makes the
        # interpolated value there ambiguous; a ROC curve starts at (0, 0).
        fold_tpr[0] = 0.0
        tprs.append(fold_tpr)

        aucs.append(auc)
        prec.append(preci)
        recs.append(reci)
        f1s.append(f1i)
        accs.append(np.mean(yhat == yts))
        r_2s.append(r_2)

    mean_tpr = np.mean(tprs, axis=0)
    aucm, _ = _mean_and_se(aucs)

    # Report the mean and standard error of every metric.
    model_name = cls.__class__.__name__
    report = [
        ('AUC', aucs),
        ('Precision', prec),
        ('Recall', recs),
        ('f1', f1s),
        ('Accuracy', accs),
        ('R^2', r_2s),
    ]
    for label, scores in report:
        mean, se = _mean_and_se(scores)
        print('%s, %s = %0.4f, SE=%0.4f' % (model_name, label, mean, se))

    plt.plot(base_fpr, mean_tpr,
             label="%s, Mean AUC = %0.4f" % (model_name, aucm), lw=2, alpha=1)
    # Add the orange chance-level diagonal
    plt.plot([0, 1], [0, 1], color='orange', linestyle='--')
    plt.title('ROC Curve Analysis', fontweight='bold')
    plt.xticks(np.arange(0.0, 1.1, step=0.1))
    plt.xlabel('False Positive Rate')
    plt.yticks(np.arange(0.0, 1.1, step=0.1))
    plt.ylabel('True Positive Rate')
    # Single legend call: a second plt.legend() would replace the first
    # and silently drop its font-size setting.
    plt.legend(prop={'size': 10}, loc='lower right')
    return cls
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Models whose mean ROC curves are compared.
classifiers = [
    LogisticRegression(random_state=500, solver="lbfgs"),
    GaussianNB(),
    KNeighborsClassifier(),
]

# Open one figure up front: k_fold_roc_plot draws on the current figure
# instead of creating its own, so every model lands in the same axes.
plt.figure(figsize=(6, 4), dpi=100, facecolor='white')

features = X_patient.to_numpy()
for model in classifiers:
    k_fold_roc_plot(model, features, y, False)

plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment