Skip to content

Instantly share code, notes, and snippets.

Created September 4, 2023 13:12
Show Gist options
  • Save MaverickMeerkat/838af93586a130347acc59e22a568e5b to your computer and use it in GitHub Desktop.
Save MaverickMeerkat/838af93586a130347acc59e22a568e5b to your computer and use it in GitHub Desktop.
# Expectation Maximization - Gaussian Mixture Model
library(ks) # for kde
library(mvtnorm) # for MVNormal
library(extraDistr) # for Categorical & Dirichlet distribution
library(ggplot2) # for plotting
library(pracma) # for sqrtm
# 1D, 2-components #
## sigma's, p are known; only estimate mu's
p = 0.3   # P(z = 1): mixing weight of component 1
n = 2000  # number of observations
# Latent: z[i] = 1 -> observation i comes from component 1, else component 0
z = rbinom(n, 1, p)
# Observed (generated)
mu1 = 2
sig1 = sqrt(2)
mu0 = -2
sig0 = sqrt(1)
# Draw from both components for every i and keep the one selected by z[i]
x = rnorm(n, mu1, sig1) * z + rnorm(n, mu0, sig0) * (1 - z)
# Plot data histogram
hist(x, breaks = 50)
# Plot data KDE (kernel density estimate from the ks package)
plot(kde(x))
# E step
# Posterior probability P(z = 1 | x) for every element of the global data
# vector x, given theta = c(mu0, mu1). The mixing weight p and the standard
# deviations sig0, sig1 are read from the global environment (known values).
# Returns a numeric vector of responsibilities, one per element of x.
doE = function (theta) {
  mu0 = theta[1]
  mu1 = theta[2]
  # Weighted component densities, computed for all observations at once
  a = dnorm(x, mean = mu1, sd = sig1) * p
  b = dnorm(x, mean = mu0, sd = sig0) * (1 - p)
  a / (a + b)
}
# M step
# Given responsibilities E (P(z = 1) per observation of the global vector x),
# return the updated means c(mu0, mu1) as responsibility-weighted averages.
doM = function(E) {
  theta = rep(NA, 2)
  theta[1] = sum((1 - E) * x) / sum(1 - E)  # mu0: weights 1 - E
  theta[2] = sum(E * x) / sum(E)            # mu1: weights E
  theta
}
# Run EM from the starting means theta = c(mu0, mu1), alternating doE/doM.
# Stops when the Euclidean length of one update step drops below tol, or
# after maxIter iterations. Returns the final estimate c(mu0, mu1).
EM = function(theta, maxIter=50, tol=1e-5) {
  theta.t = theta
  for (i in seq_len(maxIter)) {
    theta.prev = theta.t
    E = doE(theta.t)
    theta.t = doM(E)
    cat("Iteration: ", i, "pt: ", theta.t, "\n")
    # Convergence is judged on the size of this step, i.e. against the
    # previous iterate rather than the starting point
    if (sqrt(sum((theta.t - theta.prev)^2)) < tol) break
  }
  theta.t
}
# Initial values: random starting means c(mu0, mu1) as a plain vector
theta0 = drop(rmvnorm(1, mean=c(0,0)))
# Fit; the surrounding parentheses print the estimate
(theta.hat = EM(theta0))
## estimate p, mu's & sigma's
# E step
# theta = c(mu0, sigma0^2, mu1, sigma1^2, p): the two means, the two
# variances and the mixing weight of component 1. Returns P(z = 1 | x) for
# every element of the global data vector x.
doE = function (theta) {
  mu0 = theta[1]
  sig0 = sqrt(theta[2])
  mu1 = theta[3]
  sig1 = sqrt(theta[4])
  p = theta[5]
  # Weighted component densities, computed for all observations at once
  a = dnorm(x, mean = mu1, sd = sig1) * p
  b = dnorm(x, mean = mu0, sd = sig0) * (1 - p)
  a / (a + b)
}
# M step
# Given responsibilities E (P(z = 1) per observation of the global vector x),
# return c(mu0, sigma0^2, mu1, sigma1^2, p): weighted means, weighted
# maximum-likelihood variances, and the average responsibility as p.
doM = function(E) {
  theta = rep(NA, 5)
  theta[1] = sum((1 - E) * x) / sum(1 - E)
  theta[2] = sum((1 - E) * (x - theta[1])^2) / sum(1 - E)
  theta[3] = sum(E * x) / sum(E)
  theta[4] = sum(E * (x - theta[3])^2) / sum(E)
  theta[5] = mean(E)
  theta
}
# Run EM from the starting vector theta = c(mu0, sigma0^2, mu1, sigma1^2, p),
# alternating doE/doM. Stops when the Euclidean length of one update step
# drops below tol, or after maxIter iterations. Returns the final estimate.
EM = function(theta, maxIter=200, tol=1e-5) {
  theta.t = theta
  for (i in seq_len(maxIter)) {
    theta.prev = theta.t
    E = doE(theta.t)
    theta.t = doM(E)
    cat("Iteration: ", i, "pt: ", theta.t, "\n")
    # Convergence is judged on the size of this step, i.e. against the
    # previous iterate rather than the starting point
    if (sqrt(sum((theta.t - theta.prev)^2)) < tol) break
  }
  theta.t
}
# Initial values
theta0 = c(0,0.5,0,3,0.7) # mu0, sig0.2, mu1, sig1.2, p
# Fit; the surrounding parentheses print the estimate
(theta.hat = EM(theta0))
# General Case #
# K x 2D Gaussians
# Params
K = 3
phis = rdirichlet(1, rep(1, K))  # random mixing weights (1 x K)
j = 2 # how far apart the centers are
# One center per row; alternative: rmvnorm(K, mean=c(0,0))
mus = matrix(c(j,-j,0,j,j,-j),ncol=2)
# Random covariances: sample covariance of 50 standard-normal 2D points
Sigmas = vector("list", K)
for (k in seq_len(K)) {
  mat = matrix(rnorm(100), ncol=2)
  Sigmas[[k]] = cov(mat)
}
# Latent: component label of each observation
z = rcat(n, phis)
# Observed (generated): row i is drawn from component z[i]
x = matrix(nrow=n, ncol=2)
for (i in seq_len(n)) {
  k = z[i]
  x[i,] = rmvnorm(1, mean=mus[k,], sigma=Sigmas[[k]])
}
# plot the data, colored by true component
df = data.frame(x=x, mu=as.factor(z))  # columns x.1, x.2, mu
(plt = ggplot(df, aes(x=x.1, y=x.2, color=mu, fill=mu)) + geom_point())
# E step
# theta is a list with mus (list of center vectors), Sigmas (list of
# covariance matrices) and phis (mixing weights). Returns the n x K matrix of
# responsibilities P(z_i = k | x_i) for the rows of the global data matrix x;
# each row sums to 1.
doE = function (theta) {
  # log(phi_k) + log N(x_i | mu_k, Sigma_k), one column per component
  logDens = with(theta, do.call(cbind, lapply(seq_along(mus), function(k)
    log(phis[[k]]) + dmvnorm(x, mus[[k]], Sigmas[[k]], log = TRUE))))
  # Normalize each row in log space (log-sum-exp) so points far from every
  # component do not underflow to 0/0
  rowMax = apply(logDens, 1, max)
  resp = exp(logDens - rowMax)
  resp / rowSums(resp)
}
# M step
# E is the n x K responsibility matrix. Returns the updated parameter list:
# phis (mean responsibility per component), mus (responsibility-weighted
# means) and Sigmas (weighted maximum-likelihood covariances) of the rows of
# the global data matrix x.
doM = function(E) {
  phis = colMeans(E)
  # cov.wt normalizes each weight column internally
  covs = lapply(seq_len(ncol(E)), function(k) cov.wt(x, E[,k], method="ML"))
  mus = lapply(covs, "[[", "center")
  sig = lapply(covs, "[[", "cov")
  return(list(mus=mus, Sigmas=sig, phis=phis))
}
# Observed-data log-likelihood of the rows of the global data matrix x under
# the mixture theta (list with mus, Sigmas, phis):
#   sum_i log sum_k phi_k * N(x_i | mu_k, Sigma_k)
logLikelihood = function(theta) {
  logDens = with(theta, do.call(cbind, lapply(seq_along(mus), function(k)
    log(phis[[k]]) + dmvnorm(x, mus[[k]], Sigmas[[k]], log = TRUE))))
  # Row-wise log-sum-exp for numerical stability
  rowMax = apply(logDens, 1, max)
  sum(rowMax + log(rowSums(exp(logDens - rowMax))))
}
# Run EM from the parameter list theta (mus, Sigmas, phis), alternating
# doE/doM. Stops when the log-likelihood changes by less than tol between
# consecutive iterates, or after maxIter iterations. Returns the final
# parameter list.
EM = function(theta, maxIter=30, tol=1e-1) {
  theta.t = theta
  ll.prev = logLikelihood(theta.t)
  for (i in seq_len(maxIter)) {
    E = doE(theta.t)
    theta.t = doM(E)
    ll.t = logLikelihood(theta.t)
    # Improvement of this step, not the total change since the start
    ll.diff = ll.t - ll.prev
    cat("Iteration: ", i, " ll difference: ", ll.diff, "\n")
    if (abs(ll.diff) < tol) break
    ll.prev = ll.t
  }
  theta.t
}
# Initial values: random mixing weights, all centers drawn near (3, -3), and
# random covariances generated the same way as the true ones
phis0 = rdirichlet(1, rep(1, K))
mus0 = vector("list", K)
Sigmas0 = vector("list", K)
for (k in seq_len(K)) {
  mat = matrix(rnorm(100), ncol=2)
  mus0[[k]] = drop(rmvnorm(1, mean=c(3,-3)))
  Sigmas0[[k]] = cov(mat)
}
theta0 = list(mus=mus0, Sigmas=Sigmas0, phis=phis0)
# Fit; the surrounding parentheses print the fitted parameter list
(theta.hat = EM(theta0))
# Plot the fitted components
# Returns a data.frame with columns x, y: npoints points evenly spaced in
# angle (first and last coincide) on the circle of the given diameter
# around center.
circleFun = function(center=c(0,0), diameter=1, npoints=100){
  r = diameter / 2
  tt = seq(0,2*pi,length.out = npoints)
  xx = center[1] + r * cos(tt)
  yy = center[2] + r * sin(tt)
  return(data.frame(x = xx, y = yy))
}
# Overlay one Gaussian component on the ggplot plt: a black point at center
# and a translucent filled ellipse, the image of a circle of radius nsd under
# Sigma^(1/2) (the nsd-standard-deviation contour; default nsd = 2).
# Returns the updated ggplot object.
plotCircle = function(plt, center, Sigma, col="#000000", nsd = 2) {
  center = as.vector(center)
  dat = circleFun(c(0,0), 2 * nsd, npoints = 100)
  # Map unit-circle rows u to center + Sigma^(1/2) u (sqrtm(...)$B is symmetric)
  ell = sweep(as.matrix(dat) %*% sqrtm(Sigma)$B, MARGIN=2, center, "+")
  ell = data.frame(V1 = ell[, 1], V2 = ell[, 2])
  plt + theme_light() + theme(legend.position="none") +
    ylim(-6,6) + xlim(-6,6) + xlab("x") + ylab("y") +
    # annotate draws a single point instead of one per row of the plot data
    annotate("point", x = center[1], y = center[2], color = "#000000") +
    # inherit.aes = FALSE: the ellipse data has no columns for the plot's
    # default aesthetics
    geom_polygon(data = ell, mapping = aes(x = V1, y = V2), color = col,
                 fill = col, alpha = 0.3, inherit.aes = FALSE)
}
# Data colored by true component, with every fitted component overlaid
plt = ggplot(df, aes(x=x.1, y=x.2, color=mu, fill=mu)) + geom_point()
for (k in seq_along(theta.hat$mus)) {
  plt = plotCircle(plt, theta.hat$mus[[k]], theta.hat$Sigmas[[k]])
}
print(plt)
# D. Refaeli
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment