Skip to content

Instantly share code, notes, and snippets.

@padamson
Created October 23, 2016 07:19
Show Gist options
  • Save padamson/7d04c669f38b0bd6caa21fa0eeb2d035 to your computer and use it in GitHub Desktop.
Save padamson/7d04c669f38b0bd6caa21fa0eeb2d035 to your computer and use it in GitHub Desktop.
ROC curves for each class of the MNIST 10-class classifier
library(ROCR)
library(dplyr)
# Assemble the results table: the true label (`actual`), the fitted label
# (`fit`), and the `prob` component converted to data-frame columns.
# plotROCs() below reads the score for digit d from column d + 3, so exactly
# two columns must precede the probability columns.
# NOTE(review): `mnistTest` and `mnist.kknn` are defined outside this snippet.
# If `mnist.kknn` is a plain kknn() result, `$fit` presumably resolves to
# `fitted.values` through `$` partial matching -- confirm before changing the
# accessor.
mnistResultsDF <- data.frame(actual = mnistTest$label,
fit = mnist.kknn$fit,
as.data.frame(mnist.kknn$prob))
#' Plot one ROC curve per digit on a shared set of axes
#'
#' For each digit in `digitList`, keeps the rows of `df` where either the
#' actual or the fitted label equals that digit, marks a row as positive when
#' the fitted label equals the actual label, and ranks rows by the digit's
#' probability column. The curve (TPR vs. FPR) and its AUC are computed with
#' ROCR; the AUC is shown in the legend.
#'
#' @param df Data frame with columns named `actual` and `fit`, laid out so that
#'   the score for digit `d` sits in column `d + 3`.
#' @param digitList Non-empty numeric vector of digits to plot. Each digit `d`
#'   is drawn with `col = d + 1` and `lty = d + 1`.
#' @return The value of the final `legend()` call.
plotROCs <- function(df, digitList) {
  if (length(digitList) == 0) {
    stop("`digitList` must contain at least one digit.", call. = FALSE)
  }

  # Preallocated instead of grown with append() on every iteration.
  legendList <- character(length(digitList))

  for (i in seq_along(digitList)) {
    digit <- digitList[[i]]

    # NOTE(review): this evaluates only rows where the digit is involved
    # (as truth or as prediction) and labels them by overall correctness,
    # which differs from a conventional one-vs-rest ROC over all rows --
    # confirm this is the intended methodology.
    dfDigit <- df %>%
      filter(as.character(actual) == as.character(digit) |
               as.character(fit) == as.character(digit)) %>%
      mutate(prediction = (as.character(actual) == as.character(fit)))

    # `[[` always yields a plain vector, even when `df` is a tibble
    # (where `[, j]` would return a one-column tibble).
    pred <- prediction(dfDigit[[digit + 3]], dfDigit$prediction)
    perf <- performance(pred, "tpr", "fpr")
    auc <- performance(pred, "auc")

    legendList[[i]] <- paste0("Digit: ", digit, ", AUC: ",
                              round(auc@y.values[[1]], digits = 4))

    # The first curve opens the plot; later curves are overlaid on it.
    plot(perf, colorize = FALSE, add = i > 1L,
         lty = digit + 1, col = digit + 1)
  }

  # Legend styling is derived from the digits actually plotted so that the
  # entries match their curves for any `digitList`, not only 0:9.
  legend(x = 0.4, y = 0.6,
         legend = legendList,
         col = digitList + 1,
         lty = digitList + 1,
         bty = "n")
}
# Draw the ROC curves for all ten digits (0 through 9) on a single plot.
plotROCs(mnistResultsDF, 0:9)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment