Last active
March 8, 2023 10:57
-
-
Save SilasK/1f4c1810315c883d8d573a5b8646218d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def auc_analysis(classifier, X, y, n_splits=6, random_state=0, show_plot=True,
                 shuffle=False):
    """Estimate the ROC AUC of ``classifier`` with K-fold cross-validation.

    On every fold the classifier is refit on the training split and scored on
    the held-out split. Each fold's ROC curve is linearly interpolated onto a
    common grid of 100 false-positive rates so the curves can be averaged.

    Parameters
    ----------
    classifier : estimator
        Object with ``fit`` and either ``decision_function`` or
        ``predict_proba``. It is refit in place on every fold.
    X, y : array-like
        Features and binary labels, indexed as ``X[idx]`` / ``y[idx]`` with
        integer position arrays. NOTE(review): this assumes NumPy arrays;
        pandas objects index by label, not position -- confirm callers.
    n_splits : int
        Number of cross-validation folds.
    random_state : int or None
        Seed for fold shuffling; only used when ``shuffle`` is True.
    show_plot : bool
        If True, draw the mean ROC curve with a +/- 1 std band on a new
        matplotlib figure. ``plt.show()`` is not called.
    shuffle : bool
        Shuffle samples before splitting. The default (False) keeps the
        original contiguous folds.

    Returns
    -------
    (mean_auc, std_auc)
        AUC of the averaged ROC curve, and standard deviation of the
        per-fold AUCs.
    """
    import numpy as np
    from sklearn import metrics
    from sklearn.model_selection import KFold

    # KFold rejects a random_state when shuffle is False, so only forward the
    # seed when shuffling is actually requested.
    cv = KFold(n_splits=n_splits, shuffle=shuffle,
               random_state=random_state if shuffle else None)

    mean_fpr = np.linspace(0, 1, 100)
    tprs = []
    aucs = []
    for train, test in cv.split(X, y):
        classifier.fit(X[train], y[train])
        if hasattr(classifier, "decision_function"):
            scores = classifier.decision_function(X[test])
        else:
            # Column 1 of predict_proba is classes_[1], the larger label,
            # which roc_curve treats as positive for 0/1 or -1/1 labels.
            scores = classifier.predict_proba(X[test])[:, 1]
        fpr, tpr, _ = metrics.roc_curve(y[test], scores)
        aucs.append(metrics.auc(fpr, tpr))
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0  # every curve starts at the origin
        tprs.append(interp_tpr)

    # Average the interpolated curves; pin the end point at (1, 1).
    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0
    std_auc = np.std(aucs)
    mean_auc = metrics.auc(mean_fpr, mean_tpr)

    if show_plot:
        # Imported lazily so callers that only want the numbers do not need
        # matplotlib installed.
        import matplotlib.pyplot as plt

        _, ax = plt.subplots()
        ax.plot([0, 1], [0, 1], linestyle="--", lw=2, color="r",
                label="Chance", alpha=0.8)
        ax.plot(
            mean_fpr,
            mean_tpr,
            color="b",
            label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
            lw=2,
            alpha=0.8,
        )
        # Shade +/- 1 std of the per-fold TPRs, clipped to [0, 1].
        std_tpr = np.std(tprs, axis=0)
        tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
        ax.fill_between(
            mean_fpr,
            tprs_lower,
            tprs_upper,
            color="grey",
            alpha=0.2,
            label=r"$\pm$ 1 std. dev.",
        )
        ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05])
        ax.legend(loc="lower right")

    return mean_auc, std_auc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn import metrics | |
import pandas as pd | |
import numpy as np | |
def measure_prediction(y_true, pred, labels=None):
    """Summarise a binary prediction with confusion-matrix based metrics.

    Recall of the positive class is also known as "sensitivity"; recall of
    the negative class is "specificity".

    Parameters
    ----------
    y_true : array-like
        Ground-truth binary labels.
    pred : array-like
        Predicted labels (booleans work), same length as ``y_true``.
    labels : sequence of two labels, optional
        Forwarded to ``metrics.confusion_matrix`` in (negative, positive)
        order. Pass e.g. ``[0, 1]`` when an input may contain only one
        class; otherwise the confusion matrix collapses to 1x1.

    Returns
    -------
    dict
        Keys "True Positive", "True Negative", "F1", "Precision", "Recall",
        "Sensitivity" and "Specificity". A ratio whose denominator is zero
        is NaN, except F1, which is reported as 0 in that case.

    Raises
    ------
    ValueError
        If the inputs differ in length, or the confusion matrix is not 2x2
        (e.g. more than two classes, or a single class without ``labels``).
    """
    # Explicit raise rather than assert, so the check survives ``python -O``.
    if len(y_true) != len(pred):
        raise ValueError(
            "ytrue and pred must have the same shape, but have {} and {}".format(
                np.shape(y_true), np.shape(pred)))

    cm = metrics.confusion_matrix(y_true, pred, labels=labels)
    if cm.shape != (2, 2):
        raise ValueError(
            "expected a 2x2 confusion matrix for binary data, got shape {}; "
            "pass labels=[negative, positive] if an input contains a single "
            "class".format(cm.shape))
    # Rows/columns follow label order (negative first), so ravel() yields
    # tn, fp, fn, tp.
    tn, fp, fn, tp = cm.ravel()

    # Zero denominators are expected (e.g. nothing predicted positive). The
    # NumPy integer counts then give NaN; silence the RuntimeWarning.
    with np.errstate(divide="ignore", invalid="ignore"):
        sensitivity = recall = tp / (tp + fn)
        specificity = tn / (tn + fp)
        precision = tp / (tp + fp)
        f1 = 2 * (precision * recall) / (precision + recall)
    if np.isnan(f1):
        f1 = 0

    return {
        "True Positive": tp,
        "True Negative": tn,
        "F1": f1,
        "Precision": precision,
        "Recall": recall,
        "Sensitivity": sensitivity,
        "Specificity": specificity,
    }
def standard_analysis(classifier, X, y, n_splits=6, n_grid=200, random_state=0,
                      threshold_range=(-5, 10)):
    """Sweep decision thresholds over raw scores in ``X`` across K folds.

    No model is fit: each held-out fold's values ``X[test]`` are compared
    directly against every threshold (``X[test] > th``). Each resulting
    boolean prediction is summarised with ``measure_prediction``.

    Parameters
    ----------
    classifier : object
        Currently unused; kept for interface compatibility.
    X : array-like
        Scores to threshold. NOTE(review): presumably a 1-D score vector
        aligned with ``y``; a 2-D array would not match ``y[test]`` in shape.
    y : array-like
        Binary ground-truth labels.
    n_splits : int
        Number of contiguous (unshuffled) KFold folds.
    n_grid : int
        Number of evenly spaced thresholds.
    random_state : int
        Currently unused; kept for interface compatibility.
    threshold_range : (float, float)
        Inclusive (low, high) bounds of the threshold grid. Defaults to the
        original fixed range of -5 to 10.

    Returns
    -------
    pandas.DataFrame
        One row per (fold, threshold), indexed by a MultiIndex named
        ("Iteration", "Threshold"); the columns are the keys returned by
        ``measure_prediction``.
    """
    from sklearn.model_selection import KFold

    cv = KFold(n_splits=n_splits)
    # The grid is fold-independent, so build it once instead of per fold.
    low, high = threshold_range
    thresholds = np.linspace(low, high, n_grid)

    results = {}
    for fold, (_, test) in enumerate(cv.split(X, y)):
        fold_scores = X[test]
        fold_labels = y[test]
        for th in thresholds:
            results[(fold, th)] = measure_prediction(fold_labels, fold_scores > th)

    res = pd.DataFrame(results).T
    res.index.names = ("Iteration", "Threshold")
    return res
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment