Last active
July 16, 2018 23:32
-
-
Save gakuba/668b42d853ad9471e2e5675f5601a312 to your computer and use it in GitHub Desktop.
Module to display the ROC AUC of a multilabel classifier.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
from sklearn.metrics import roc_curve, auc | |
from scipy import interp | |
from itertools import cycle | |
def _compute_roc(y_test, y_score, n_classes):
    """Compute per-class, micro-average and macro-average ROC curves.

    Returns three dicts ``(fpr, tpr, roc_auc_scores)`` keyed by the class
    indices ``0 .. n_classes - 1`` plus the string keys ``"micro"`` and
    ``"macro"``.
    """
    fpr = {}
    tpr = {}
    roc_auc_scores = {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
        roc_auc_scores[i] = auc(fpr[i], tpr[i])

    # Micro-average: flatten all (sample, class) entries into one binary problem.
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
    roc_auc_scores["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: interpolate each class's TPR onto the union of all FPR
    # points, then average the curves with equal weight per class.
    # np.interp is used because scipy.interp was only a deprecated alias of
    # it and has been removed from recent SciPy releases.
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.mean(
        [np.interp(all_fpr, fpr[i], tpr[i]) for i in range(n_classes)], axis=0
    )
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc_scores["macro"] = auc(fpr["macro"], tpr["macro"])
    return fpr, tpr, roc_auc_scores


def _plot_roc(fpr, tpr, roc_auc_scores, n_classes):
    """Draw the micro/macro-average and per-class ROC curves and show them."""
    plt.figure()
    plt.plot(fpr["micro"], tpr["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc_scores["micro"]),
             color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr["macro"], tpr["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc_scores["macro"]),
             color='navy', linestyle=':', linewidth=4)

    # One curve per class; the three colours repeat when there are more
    # than three classes.
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc_scores[i]))

    plt.plot([0, 1], [0, 1], 'k--', lw=2)  # chance-level diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of ROC to multi-class')
    plt.legend(loc="lower right")
    plt.show()


def roc_auc(y_test, y_score, n_classes):
    """Plot the ROC curves of a multilabel classifier and return their AUCs.

    Draws the micro-average, macro-average and per-class ROC curves on a new
    matplotlib figure, then calls ``plt.show()``.

    Args:
        y_test: 2-D array-like of true labels, indexed as
            ``y_test[sample, class]``. Each column must be a binary label
            vector, as required by ``sklearn.metrics.roc_curve``.
        y_score: 2-D array-like of scores with the same shape as ``y_test``.
        n_classes: number of label columns to evaluate
            (columns ``0 .. n_classes - 1``).

    Returns:
        dict mapping each class index, ``"micro"`` and ``"macro"`` to the
        corresponding area under the ROC curve.

    Raises:
        ValueError: if ``n_classes`` is less than 1.
    """
    if n_classes < 1:
        raise ValueError("n_classes must be at least 1, got {}".format(n_classes))
    # Accept plain nested lists as well as arrays; 2-D slicing needs ndarrays.
    y_test = np.asarray(y_test)
    y_score = np.asarray(y_score)
    fpr, tpr, roc_auc_scores = _compute_roc(y_test, y_score, n_classes)
    _plot_roc(fpr, tpr, roc_auc_scores, n_classes)
    return roc_auc_scores
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment