Skip to content

Instantly share code, notes, and snippets.

@chelseaparlett
Last active April 15, 2020 16:34
Show Gist options
  • Save chelseaparlett/a3427ba2290c4e4ba84bdf160220ec26 to your computer and use it in GitHub Desktop.
Save chelseaparlett/a3427ba2290c4e4ba84bdf160220ec26 to your computer and use it in GitHub Desktop.
library(tidyverse)
# Gaussian density of every observation under one mixture component.
#
# dfData: data frame with a numeric column `x` holding the observations.
# row:    one-row slice of the component table. Its "means" entry is used as
#         the mean; its "var" entry is handed to dnorm() as the standard
#         deviation (despite the column name).
# Returns a numeric vector with one density per element of dfData$x.
calculateWeights <- function(dfData, row) {
  center <- row[, "means"]
  spread <- row[, "var"]
  dnorm(dfData$x, mean = center, sd = spread)
}
# Convert raw component densities into normalized responsibilities and pick
# the most likely component for every observation.
#
# weights: K x N numeric matrix (row = component, column = observation).
# dfDist:  component table; its "ak" column holds one mixing proportion per
#          component (row).
# Returns a list with
#   chosen:  for each observation, the row index of the largest responsibility
#   weights: K x N matrix of responsibilities; each column sums to 1
normalize.choose <- function(weights, dfDist) {
  # Keep the K x N layout even if a caller hands over a bare vector
  # (e.g. a single-component fit).
  if (!is.matrix(weights)) {
    weights <- matrix(weights, nrow = nrow(dfDist))
  }
  mixing <- dfDist[, "ak"]
  # `mixing` has one entry per row, so recycling down the columns scales
  # row k by ak[k].
  joint <- weights * mixing
  # Normalize column-wise; sweep() keeps the matrix shape for any K, unlike
  # sapply(), which collapses to a vector when K == 1.
  wis <- sweep(joint, 2, colSums(joint), "/")
  chosen <- unname(apply(wis, 2, which.max))
  list(chosen = chosen, weights = wis)
}
# E-step: score every observation under every component, normalize the scores
# into responsibilities, and record the most likely component per observation.
#
# dfData: data frame with the observations in column `x`.
# dfDist: component table, one row per component (columns "means", "var",
#         "ak" are read by the helpers called here).
# Returns a list with
#   dfData:  the input data with an added/overwritten `assign` column
#   weights: the responsibilities returned by normalize.choose()
E <- function(dfData, dfDist) {
  perComponent <- lapply(
    seq_len(nrow(dfDist)),
    function(k) calculateWeights(dfData, dfDist[k, ])
  )
  # rbind() always yields one row per component. t(sapply(...)) flips the
  # orientation when there is only a single observation.
  weights <- do.call(rbind, perComponent)
  chose <- normalize.choose(weights, dfDist)
  dfData$assign <- chose$chosen
  list(dfData = dfData, weights = chose$weights)
}
# M-step: re-estimate every component from the current responsibilities.
#
# dfData:  data frame with the observations in column `x`.
# dfDist:  component table, one row per component.
# weights: K x N responsibility matrix (row = component, column = observation).
# Returns dfDist with these columns set:
#   nk:    effective number of observations per component (sum of weights)
#   means: responsibility-weighted mean
#   var:   responsibility-weighted standard deviation (the column name is
#          historical; the value is a square root and is consumed as `sd`)
#   ak:    mixing proportion, nk divided by the number of observations
M <- function(dfData, dfDist, weights) {
  resp <- t(weights)  # N x K: one column per component
  obs <- dfData$x
  nk <- colSums(resp)
  # `obs` has one entry per row, so recycling multiplies row i by obs[i].
  mus <- colSums(resp * obs) / nk
  # N x K matrix of deviations of every observation from every component mean.
  deviations <- outer(obs, mus, "-")
  sds <- sqrt(colSums(resp * deviations^2) / nk)
  dfDist[, "nk"] <- nk
  dfDist[, "means"] <- mus
  dfDist[, "var"] <- sds
  dfDist[, "ak"] <- nk / nrow(resp)
  dfDist
}
# Simulate observations from several normal distributions and pool them.
#
# mu: group means.
# sd: group standard deviations.
# n:  number of draws per group (groups may differ in size).
# The three vectors are matched element-wise by mapply().
# Returns a data frame with a single numeric column `x`, group after group.
generatePoints <- function(mu = c(0, 2), sd = c(1, 0.5),
                           n = c(7, 7)) {
  # SIMPLIFY = FALSE keeps one vector per group. With the default, unequal
  # group sizes produce a list that data.frame() cannot turn into one column.
  draws <- mapply(FUN = rnorm, n, mu, sd, SIMPLIFY = FALSE)
  data.frame(x = unlist(draws, use.names = FALSE))
}
# Pick distinct observations to serve as the initial component means.
#
# seed: value passed to set.seed() so the choice is reproducible (this resets
#       the global random number stream).
# n:    number of starting points to draw.
# d:    data frame with the observations in column `x`.
# Returns a vector of `n` values sampled without replacement from d$x.
chooseStartingPoints <- function(seed = 234, n = 2, d) {
  set.seed(seed)
  # Index with sample.int() so a length-one d$x is not reinterpreted by
  # sample() as "draw from 1:x". For longer vectors this is exactly what
  # sample(d$x, n) does internally.
  picked <- base::sample.int(length(d$x), n, replace = FALSE)
  d$x[picked]
}
# Log-likelihood of the observations under the current mixture.
#
# dfData: data frame with the observations in column `x`.
# dfDist: component table with columns "ak" (mixing proportion), "means", and
#         "var" (passed to dnorm() as the standard deviation).
# Returns a single number: the sum over observations of the log of the
# mixture density. For zero observations the result is 0.
convergence <- function(dfData, dfDist) {
  mixing <- dfDist[, "ak"]
  centers <- dfDist[, "means"]
  spreads <- dfDist[, "var"]
  # vapply() always yields a numeric vector. sapply() returns list() for empty
  # input, which sum() rejects.
  pointLL <- vapply(
    dfData$x,
    function(obs) log(sum(mixing * dnorm(obs, mean = centers, sd = spreads))),
    numeric(1)
  )
  sum(pointLL)
}
# Draw every component's density curve plus the observations on the x-axis.
#
# dfData: data frame with the observations in column `x` and the current
#         component assignment in column `assign` (used to colour the points).
# dfDist: component table with columns "means" and "var" ("var" is passed to
#         dnorm() as the standard deviation).
# title:  main title of the plot.
# Draws on the active graphics device and returns NULL invisibly. Only four
# colours are defined, so components beyond the fourth get an NA colour.
plotDists <- function(dfData, dfDist, title = "") {
  colors <- c("lightblue", "darkgreen", "red", "blue")
  # Extend the x-range by one standard deviation of the data on each side.
  margin <- sd(dfData$x)
  s <- seq(min(dfData$x) - margin, max(dfData$x) + margin, length.out = 1000)
  plot(
    x = s, y = dnorm(s, dfDist[1, "means"], dfDist[1, "var"]), type = "l",
    ylim = c(0, 2), col = colors[1], ylab = "density", xlab = "x",
    main = title, lwd = 3
  )
  # Remaining components. seq_len(...)[-1] is empty for a single component,
  # whereas 2:nrow(dfDist) would iterate over c(2, 1).
  for (k in seq_len(nrow(dfDist))[-1]) {
    lines(
      x = s, y = dnorm(s, dfDist[k, "means"], dfDist[k, "var"]),
      lty = k, lwd = 3, col = colors[k]
    )
  }
  points(
    x = dfData$x, y = rep(0, length(dfData$x)),
    bg = colors[dfData$assign], pch = 21, cex = 1.25
  )
  invisible(NULL)
}
# Run expectation-maximization for a one-dimensional Gaussian mixture on
# simulated data, saving a PNG frame after every E-step and M-step.
#
# n:       number of mixture components to fit.
# tol:     stop once the absolute change in log-likelihood between two
#          consecutive iterations is no larger than this.
# outDir:  directory that receives the PNG frames. Frames are named
#          <10 * iteration>.png after the E-step and <10 * iteration + 5>.png
#          after the M-step. The directory is created if it does not exist.
# maxIter: upper bound on the number of iterations (default: unlimited).
#
# Side effects: draws data with generatePoints() at its default settings,
# prints the log-likelihood and the component table on every iteration, writes
# the PNG frames, and draws a final plot on the active graphics device.
# Returns whatever the final plotDists() call returns.
EM <- function(n, tol = 0.00001, outDir = "/Users/cparlett/Desktop/EM/",
               maxIter = Inf) {
  # Strip trailing separators so file.path() reproduces the file names the
  # former hard-coded prefix generated.
  outDir <- sub("/+$", "", outDir)
  if (!dir.exists(outDir)) {
    dir.create(outDir, recursive = TRUE, showWarnings = FALSE)
  }

  # Render one frame to a PNG file; on.exit() closes the device even when
  # plotting fails, so no graphics device is leaked.
  saveFrame <- function(index, data, dist, title) {
    png(
      filename = file.path(outDir, paste0(index, ".png")),
      width = 800, height = 480
    )
    on.exit(dev.off(), add = TRUE)
    plotDists(data, dist, title)
  }

  d <- generatePoints()
  st <- chooseStartingPoints(n = n, d = d)
  # Every component starts with unit spread and an equal mixing proportion.
  dfDist <- data.frame(means = st, var = rep(1, n), ak = rep(1 / n, n))
  dfData <- data.frame(x = d$x)

  # Infinite sentinels guarantee the first iteration never looks converged;
  # a finite placeholder could coincide with the first log-likelihood.
  ll.diff <- Inf
  curr.ll <- -Inf
  it <- 0
  while (ll.diff > tol && it < maxIter) {
    it <- it + 1

    # E
    estep <- E(dfData, dfDist)
    dfData <- estep$dfData
    weights <- estep$weights
    saveFrame(it * 10, dfData, dfDist, paste0("E", it))

    # M
    dfDist <- M(dfData, dfDist, weights)

    # Converge
    new.ll <- convergence(dfData, dfDist)
    ll.diff <- abs(curr.ll - new.ll)
    curr.ll <- new.ll
    print(paste0("On iter ", it, " LL was ", curr.ll))
    print(dfDist)
    saveFrame(it * 10 + 5, dfData, dfDist, paste0("M", it))

    # A non-finite log-likelihood would otherwise surface as an opaque
    # "missing value where TRUE/FALSE needed" error in the loop condition.
    if (!is.finite(new.ll)) {
      stop("log-likelihood became non-finite at iteration ", it,
           call. = FALSE)
    }
  }
  plotDists(dfData, dfDist, "FINAL")
}
# Run the demo when the script is sourced: fit a two-component mixture.
# EM() prints per-iteration diagnostics and writes PNG frames to its output
# directory.
EM(n = 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment