Skip to content

Instantly share code, notes, and snippets.

@sigma23
Created November 5, 2018 20:08
Show Gist options
  • Save sigma23/425a45930a4ca25088d7c7ffcf623936 to your computer and use it in GitHub Desktop.
Save sigma23/425a45930a4ca25088d7c7ffcf623936 to your computer and use it in GitHub Desktop.
Multiclass ROC curves
from sklearn.metrics import roc_curve, auc
# Compute the ROC curve and the area under it (AUC) for each class.
# NOTE(review): `y`, `preds` and `n_classes` are defined elsewhere (not visible
# here). The `[:, i]` indexing requires both arrays to be at least 2-D;
# presumably `y` is a one-hot / label-binarized target matrix and `preds`
# holds per-class scores of the same shape -- confirm against the caller.
fpr = {}
tpr = {}
roc_auc = {}
for i in range(n_classes):
    # roc_curve also returns the decision thresholds; they are not needed here.
    fpr[i], tpr[i], _ = roc_curve(y[:, i], preds[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute the micro-average ROC curve and AUC: flattening both arrays pools
# every (sample, class) decision into a single binary problem.
fpr["micro"], tpr["micro"], _ = roc_curve(y.ravel(), preds.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# Plot one ROC figure per class, with the chance-level diagonal for reference.
# NOTE(review): `plt` (presumably matplotlib.pyplot) and `le` (presumably the
# fitted LabelEncoder that produced the integer class ids) are not defined in
# the visible source -- confirm they are set up before this point.
# NOTE(review): the micro-average curve (fpr["micro"], tpr["micro"]) is
# computed earlier in this script but is not plotted here.
lw = 2  # line width shared by every curve; loop-invariant, so set once
for class_num in range(n_classes):
    plt.figure()
    plt.plot(fpr[class_num], tpr[class_num], color='darkorange',
             lw=lw, label=f'ROC curve (area = {roc_auc[class_num]:0.2f})')
    # Dashed diagonal from (0, 0) to (1, 1): the chance-level reference line.
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])  # small headroom so a curve at TPR = 1 stays visible
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    # Map the integer class index back to its original label for the title.
    plt.title(f'ROC - {le.inverse_transform([class_num])[0]}')
    plt.legend(loc="lower right")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment