

@padamson
Created October 23, 2016 07:04
Plot the confusion matrix for a 10-class MNIST handwritten digit classification problem
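library(caret)         # createDataPartition(), confusionMatrix()
library(kknn)          # weighted k-nearest-neighbour classifier
library(RColorBrewer)  # brewer.pal()
library(ggplot2)       # ggplot(); caret also attaches it

# Read the data; the label column must be a factor for classification
mnist <- read.csv("data/mnist_small.csv",
                  colClasses = c(label = "factor"))

# Stratified 80/20 train/test split on the digit label
# (the seed value is arbitrary; it only makes the split reproducible)
set.seed(1)
trainIndex <- createDataPartition(mnist$label, p = .8,
                                  list = FALSE,
                                  times = 1)
mnistTrain <- mnist[ trainIndex, ]
mnistTest  <- mnist[-trainIndex, ]

# k-NN with Manhattan distance (distance = 1) and a triangular kernel
mnist.kknn <- kknn(label ~ ., mnistTrain, mnistTest, distance = 1,
                   kernel = "triangular")

# Confusion matrix as a long data frame with columns Prediction, Reference, Freq
confusionDF <- data.frame(confusionMatrix(fitted(mnist.kknn),
                                          mnistTest$label)$table)

# Reverse the Reference levels so digit 0 is drawn at the top of the y axis
confusionDF$Reference <- with(confusionDF,
                              factor(Reference, levels = rev(levels(Reference))))

# Interpolate the 9-colour BuPu palette to 256 colours
jBuPuFun <- colorRampPalette(brewer.pal(n = 9, "BuPu"))
paletteSize <- 256
jBuPuPalette <- jBuPuFun(paletteSize)

confusionPlot <- ggplot(
  confusionDF, aes(x = Prediction, y = Reference, fill = Freq)) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  geom_tile() +
  labs(x = "Predicted digit", y = "Actual digit") +
  scale_fill_gradient2(
    low = jBuPuPalette[1],
    mid = jBuPuPalette[paletteSize/2],
    high = jBuPuPalette[paletteSize],
    midpoint = (max(confusionDF$Freq) + min(confusionDF$Freq)) / 2,
    name = "") +
  theme(legend.key.height = unit(2, "cm")) +
  # Put the x axis at the top, matrix style. cowplot's switch_axis_position()
  # is deprecated, so use ggplot2's own position argument instead.
  scale_x_discrete(position = "top")

print(confusionPlot)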
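The script reads data/mnist_small.csv, which is not included in the gist. The base-R sketch below reproduces only the reshaping step, so the structure of confusionDF can be inspected without the CSV, caret, or kknn. The simulated labels (actual, predicted, and the roughly 10% error rate) are assumptions standing in for mnistTest$label and fitted(mnist.kknn); table() is given the same dimnames (Prediction, Reference) that caret's confusionMatrix()$table uses.

```r
# Assumption: simulated digit labels stand in for mnistTest$label and
# fitted(mnist.kknn), so no data file, caret, or kknn is needed.
set.seed(42)
n <- 500
actual <- factor(sample(0:9, n, replace = TRUE), levels = 0:9)

# Simulated classifier: copy the true label, then replace about 10% at random
predicted <- actual
wrong <- runif(n) > 0.9
predicted[wrong] <- sample(levels(actual), sum(wrong), replace = TRUE)

# Same dimnames names (Prediction, Reference) as confusionMatrix()$table
confusionTable <- table(Prediction = predicted, Reference = actual)

# Long format: one row per (Prediction, Reference) cell, counts in Freq
confusionDF <- as.data.frame(confusionTable)

# Reverse the Reference levels so digit 0 ends up at the top of the y axis
confusionDF$Reference <- factor(confusionDF$Reference,
                                levels = rev(levels(confusionDF$Reference)))

str(confusionDF)
```

str(confusionDF) should show 100 rows (10 predicted × 10 actual digits, zero counts included) with columns Prediction, Reference and Freq, and the Reference levels running from 9 down to 0.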

uMklami commented Nov 24, 2018

And secondly, how can I print the values rather than using colors in the confusion matrix?

Thanks
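One possible answer to the question above, not taken from the gist author: draw each cell with geom_text() showing its Freq count and give the tiles a plain white fill instead of a color scale. The simulated labels in this sketch are assumptions standing in for the kknn predictions, so it runs without the CSV file.

```r
library(ggplot2)

# Assumption: simulated labels stand in for the gist's kknn predictions
set.seed(42)
n <- 500
actual <- factor(sample(0:9, n, replace = TRUE), levels = 0:9)
predicted <- actual
wrong <- runif(n) > 0.9
predicted[wrong] <- sample(levels(actual), sum(wrong), replace = TRUE)

# Long-format confusion matrix, digit 0 at the top of the y axis
confusionDF <- as.data.frame(table(Prediction = predicted, Reference = actual))
confusionDF$Reference <- factor(confusionDF$Reference,
                                levels = rev(levels(confusionDF$Reference)))

# White cells with grey borders; the count is printed in each cell
countPlot <- ggplot(confusionDF, aes(x = Prediction, y = Reference)) +
  geom_tile(fill = "white", colour = "grey80") +
  geom_text(aes(label = Freq)) +
  scale_x_discrete(position = "top") +
  labs(x = "Predicted digit", y = "Actual digit") +
  theme_minimal()

print(countPlot)
```

To keep the BuPu color scale and show the counts as well, adding geom_text(aes(label = Freq)) to the gist's confusionPlot should be enough.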
