@karlrohe
Created January 7, 2022 22:48
PCA on mnist handwritten 2's
# PCA on n=6990 images of handwritten 2's, each with d = 784 pixels.
# install.packages("remotes")
# remotes::install_github("jlmelville/snedata")
# thank you jlmelville for making this data so easy to access!
library(snedata)
library(magrittr)
library(Matrix)
library(rARPACK)
# get the data:
mnist <- download_mnist()
# code to plot an image of a handwritten digit:
show_digit <- function(arr784, col = gray(12:1/12), ...) {
  # I fiddled with this code:
  # source("https://gist.githubusercontent.com/brendano/39760/raw/22467aa8a5d104add5e861ce91ff5652c6b271b6/gistfile1.txt")
  # thank you brendano on github!
  # reshape the first 784 values into a 28 x 28 matrix and flip the columns
  # so that image() draws the digit upright.
  image(matrix(as.numeric(arr784)[1:784], nrow = 28)[, 28:1], col = col, ...)
}
# here is one example:
show_digit(mnist[1,])
# Here are 25 examples:
par(mfrow = c(5,5), mar = c(0,0,0,0),
    xaxt = 'n',
    yaxt = 'n',
    ann = FALSE)
for(i in 1:25) show_digit(mnist[i,])
images_of_selected_digit = mnist$Label %in% c("2") # select the two's
# images_of_selected_digit = mnist$Label %in% c("8") # or select the eights instead
x = mnist[images_of_selected_digit,1:784] %>% as.matrix
# this matrix has 6990 rows and 784 columns.
dim(x)
# note: even though we think of images as rectangle-shaped and we also think
# of matrices as rectangle-shaped, here we treat each image as a *vector*
# (one row of x). By doing this, we discard some information: which pixels are next to one another.
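# A minimal sketch of the point above, on made-up toy data (the toy_* names are
# illustrative and not part of the analysis): PCA on flattened images never sees
# the pixel layout. Scrambling the pixel order the same way in every image leaves
# the singular values and the PC scores unchanged (up to sign).
set.seed(1)
toy_x = matrix(rnorm(20 * 6), nrow = 20)   # 20 fake "images" with 6 "pixels" each
toy_perm = sample(6)                       # one fixed shuffle of the pixel order
toy_svd_orig = svd(scale(toy_x, scale = FALSE))
toy_svd_perm = svd(scale(toy_x[, toy_perm], scale = FALSE))
all.equal(toy_svd_orig$d, toy_svd_perm$d)                      # should be TRUE
all.equal(abs(toy_svd_orig$u[, 1]), abs(toy_svd_perm$u[, 1]))  # should be TRUE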
# compute PCs.
# I use rARPACK because it is much faster when we only need a couple of PCs.
# I do not scale the data because some pixels have zero variance
# and others have very, very small variance.
# scaling would divide by zero or by these small numbers.
# this would be a bad idea!
# Because we only center and do not scale,
# we study the "covariance" instead of the "correlation".
s = scale(x, scale = FALSE) %>% rARPACK::svds(k = 2)
# if you want the screeplot, you will want a larger k...
# s100 = scale(x, scale = FALSE) %>% rARPACK::svds(k = 100)
# plot(s100$d^2)
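# A small sketch of why centering without scaling means we study the covariance,
# using base svd() on made-up toy data (toy_* names are illustrative only):
# the squared singular values of the centered matrix, divided by n - 1,
# match the eigenvalues of the sample covariance matrix.
set.seed(2)
toy_cov_x = matrix(rnorm(50 * 4), nrow = 50)
toy_cov_d = svd(scale(toy_cov_x, scale = FALSE))$d
toy_cov_eig = eigen(cov(toy_cov_x), symmetric = TRUE)$values
all.equal(toy_cov_d^2 / (nrow(toy_cov_x) - 1), toy_cov_eig)   # should be TRUE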
# this is the first pc:
pc1 = s$u[,1]
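# A hedged aside on what pc1 holds, checked on toy data with base svd()
# (toy_* names are made up for illustration): s$u[,1] is the unit-length version
# of the PC1 scores. Multiplying by the first singular value gives the projection
# of each centered row onto the first principal direction v[,1].
set.seed(3)
toy_xc = scale(matrix(rnorm(30 * 5), nrow = 30), scale = FALSE)
toy_s = svd(toy_xc)
toy_scores = as.vector(toy_xc %*% toy_s$v[, 1])    # project rows onto direction 1
all.equal(toy_scores, toy_s$u[, 1] * toy_s$d[1])   # should be TRUE
all.equal(sum(toy_s$u[, 1]^2), 1)                  # u[,1] has unit length
# On the MNIST fit above, show_digit(s$v[,1]) would draw the first direction
# itself as a 28 x 28 image.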
dd = density(pc1) # kernel density estimate of the PC1 scores
# the first 784 columns are pixels. The last column is pc1
dat = cbind(x,s$u[,1])
vals = seq(from = min(pc1), to = max(pc1), length.out = 100)
dif = dd$bw # this is the bandwidth from the kernel density estimate
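# A toy sketch of the windowing idea used in the loop that follows (toy_* names
# and toy_center are made up): density() picks its bandwidth with bw.nrd0() by
# default, and a window of half-width bw around a center selects the points
# that get averaged.
set.seed(4)
toy_pc = rnorm(500)
toy_dd = density(toy_pc)
all.equal(toy_dd$bw, bw.nrd0(toy_pc))   # default bandwidth rule; should be TRUE
toy_center = 0                          # hypothetical window center
toy_in_window = abs(toy_pc - toy_center) < toy_dd$bw
sum(toy_in_window)                      # number of points inside the window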
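# make sure the output directory exists before writing frames:
dir.create("video", showWarnings = FALSE)
# frames 1..100 sweep up through vals; frames 101..199 sweep back down.
# (stopping at 199 avoids the index 200 - 200 = 0, which selects nothing.)
for(frame in 1:199){
  i = frame
  if(frame > 100){
    i = 200 - frame
  }
  # identify the images in a region around vals[i]
  lower_pc = vals[i] - dif
  upper_pc = vals[i] + dif
  which_images = (dat[,785] > lower_pc) & (dat[,785] < upper_pc)
  # if there are images to plot, then plot them:
  if(sum(which_images) > 0){
    png(file = paste("video/", frame, ".png", sep = ""), width = 1100/2, height = 850/2)
    par(mar = rep(0,4), mfrow = c(1,2))
    # left panel: the average of the selected images (or the single image)
    if(sum(which_images) > 1) show_digit(colMeans(dat[which_images,1:784]))
    if(sum(which_images) == 1) show_digit(dat[which_images,1:784])
    # right panel: the density of pc1, with grey lines marking the window
    plot(dd, main = "", ylab = "", yaxt = "n")
    abline(v = c(lower_pc, upper_pc), col = "grey")
    # close the png device only if we opened one
    dev.off()
  }
}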
# To make the video, I go to QuickTime -> File -> Open Image Sequence.
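# A hedged alternative sketch, assuming the optional av package is installed
# (it is not used elsewhere in this script, and the output name pca_twos.mp4 and
# the frame rate are arbitrary choices). The frame names are built in numeric
# order because alphabetical order would put 10.png before 2.png.
frame_files = paste("video/", 1:199, ".png", sep = "")
frame_files = frame_files[file.exists(frame_files)]   # skip frames that were never written
if (length(frame_files) > 0 && requireNamespace("av", quietly = TRUE)) {
  av::av_encode_video(frame_files, output = "pca_twos.mp4", framerate = 20)
}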