Skip to content

Instantly share code, notes, and snippets.

@jsks
Last active June 20, 2023 13:54
Show Gist options
  • Save jsks/a1ea25b1b962e7ccbf756a5c655c7299 to your computer and use it in GitHub Desktop.
Save jsks/a1ea25b1b962e7ccbf756a5c655c7299 to your computer and use it in GitHub Desktop.
Expectation maximization in R
#!/usr/bin/env Rscript
#
# Gaussian mixture model fit with the expectation-maximization algorithm
###
library(mclust)
library(mvtnorm)
# Fix the RNG seed so the simulated data and the random EM starts below are
# reproducible from run to run.
set.seed(4323)
N <- 1000  # number of simulated observations (rows of X)
D <- 2     # dimension of each observation (columns of X)
K <- 3     # number of mixture components
###
# Sampling functions
# Split a 3-d array into a list holding one slice per index of its third
# dimension (e.g. the D x D x K array from rWishart() becomes K matrices).
# `simplify = FALSE` is spelled out because `F` is an ordinary variable that
# can be reassigned, which would silently turn the result back into a matrix.
array2list <- \(a) apply(a, 3, identity, simplify = FALSE)
# Draw a random point on the probability simplex: `n` independent uniform
# variates rescaled so that they are non-negative and sum to one.
rsimplex <- function(n) {
  proportions(runif(n))
}
###
# Simulate dataset
# Draw the true parameters of the K-component mixture. Wrapping an assignment
# in parentheses also prints the assigned value.
# The mean vectors have length D (previously hard-coded as 2) so that they
# always match the D x D covariance matrices.
(mu <- replicate(K, sample(-10:10, D, replace = TRUE), simplify = FALSE))
(sigma <- rWishart(K, D, diag(1, D)) |> array2list())
# NOTE: this masks the built-in constant `pi` for the rest of the script.
(pi <- rsimplex(K))

# Sample a latent component label for every observation, then draw each
# observation from the Gaussian of its component. The draws stay in a
# row-by-row loop so the random number stream is consumed in the same order
# as before.
X <- matrix(0, nrow = N, ncol = D)
Z <- sample(seq_len(K), size = N, prob = pi, replace = TRUE)
for (n in seq_len(N)) {
  X[n, ] <- rmvnorm(1, mean = mu[[Z[n]]], sigma = sigma[[Z[n]]])
}
###
# EM functions
# Log-likelihood of the data under a Gaussian mixture model.
#
# Arguments:
#   X     - N x D numeric matrix, one observation per row.
#   mu    - list of K mean vectors.
#   sigma - list of K covariance matrices.
#   pi    - numeric vector of K mixing weights.
#
# Returns a single number,
#   sum_n log( sum_k pi[k] * dmvnorm(x_n, mu[[k]], sigma[[k]]) ).
log_lik <- function(X, mu, sigma, pi) {
  K <- length(mu)
  N <- nrow(X)

  # Log of every weighted component density. vapply() pins the length of each
  # per-component result to N, and matrix() guarantees an N x K shape even
  # when N == 1 (sapply() would return a plain vector there and rowSums()
  # would fail).
  log_dens <- vapply(
    seq_len(K),
    \(k) log(pi[k]) + dmvnorm(X, mu[[k]], sigma[[k]], log = TRUE),
    numeric(N)
  )
  log_dens <- matrix(log_dens, nrow = N, ncol = K)

  # Log-sum-exp over components: subtracting the row maximum before
  # exponentiating keeps far-away observations from underflowing to log(0).
  row_max <- apply(log_dens, 1, max)
  ok <- is.finite(row_max)
  # Rows whose maximum is not finite (-Inf, Inf, NaN) keep that value, which
  # is what summing the raw densities would have produced as well.
  row_ll <- row_max
  shifted <- exp(log_dens[ok, , drop = FALSE] - row_max[ok])
  row_ll[ok] <- row_max[ok] + log(rowSums(shifted))

  sum(row_ll)
}
# Expectation step: posterior probability ("responsibility") of each mixture
# component for each observation, given the current parameters.
#
# Arguments:
#   X     - N x D numeric matrix, one observation per row.
#   mu    - list of K mean vectors.
#   sigma - list of K covariance matrices.
#   pi    - numeric vector of K mixing weights.
#
# Returns an N x K matrix whose rows sum to one. A row for which every
# component has zero density is returned as all zeros.
E_step <- function(X, mu, sigma, pi) {
  N <- nrow(X)
  K <- length(mu)
  stopifnot(length(sigma) == K, length(pi) == K)

  # Work with log densities so that observations far from every component do
  # not underflow to zero.
  log_w <- matrix(-Inf, nrow = N, ncol = K)
  for (k in seq_len(K)) {
    log_w[, k] <- log(pi[k]) + dmvnorm(X, mu[[k]], sigma[[k]], log = TRUE)
  }

  # Normalise each row exactly. Shifting by the row maximum makes the largest
  # term exp(0) = 1, so the denominator is never zero and no additive fudge
  # constant (which would make the rows sum to less than one) is needed.
  row_max <- apply(log_w, 1, max)
  gamma <- exp(log_w - row_max)
  gamma <- gamma / rowSums(gamma)

  # Observations with no density under any component: keep zero
  # responsibilities rather than NaN. which() skips NA/NaN maxima.
  gamma[which(row_max == -Inf), ] <- 0
  gamma
}
# Maximisation step: re-estimate the mixture parameters from the
# responsibilities.
#
# Arguments:
#   X     - N x D numeric matrix, one observation per row.
#   gamma - N x K matrix of responsibilities (one column per component).
#
# Returns a list with
#   mu    - list of K responsibility-weighted mean vectors,
#   sigma - list of K responsibility-weighted covariance matrices,
#   pi    - numeric vector of K mixing weights that sums to one.
M_step <- function(X, gamma) {
  K <- ncol(gamma)
  # Total responsibility mass assigned to each component.
  N_k <- colSums(gamma)

  mu <- vector("list", K)
  sigma <- vector("list", K)
  for (k in seq_len(K)) {
    w <- gamma[, k]
    # `w * X` recycles w down the columns, i.e. scales row n of X by w[n].
    mu[[k]] <- colSums(w * X) / N_k[k]
    X_centered <- sweep(X, 2, mu[[k]])
    sigma[[k]] <- crossprod(X_centered, X_centered * w) / N_k[k]
  }

  # Normalise by the total mass rather than by nrow(X). The two agree when
  # every row of gamma sums to one, and the weights stay on the simplex when
  # it does not.
  pi <- unname(N_k / sum(N_k))

  list(mu = mu, sigma = sigma, pi = pi)
}
# Fit a K-component Gaussian mixture model to X with the
# expectation-maximisation algorithm, starting from a random initialisation.
#
# Arguments:
#   X        - N x D numeric matrix, one observation per row.
#   K        - number of mixture components.
#   max_iter - maximum number of EM iterations.
#   tol      - stop once the log-likelihood changes by less than this
#              (absolute difference) between two iterations.
#
# Returns a list with the fitted `mu`, `sigma` and `pi`, the responsibilities
# `gamma` from the last E-step, and the final log-likelihood `log_lik`.
# Signals an error if the log-likelihood stops being a number or if the
# tolerance is not reached within `max_iter` iterations.
EM <- function(X, K, max_iter = 1e4, tol = 1e-6) {
  stopifnot(K >= 1, max_iter >= 1)
  n_obs <- nrow(X)

  # Start the means at K distinct rows of X. Drawing each start independently
  # can pick the same row twice; two components that start with identical
  # mean, covariance and weight receive identical updates forever.
  start_rows <- sample.int(n_obs, K, replace = K > n_obs)
  mu <- lapply(start_rows, \(r) X[r, ])
  sigma <- replicate(K, cov(X), simplify = FALSE)
  pi <- rep(1 / K, K)

  ll_old <- -Inf
  converged <- FALSE
  for (i in seq_len(max_iter)) {
    if (i %% 10 == 0) {
      sprintf("Iteration: %d", i) |> print()
    }

    gamma <- E_step(X, mu, sigma, pi)
    result <- M_step(X, gamma)
    mu <- result$mu
    sigma <- result$sigma
    pi <- result$pi

    ll_new <- log_lik(X, mu, sigma, pi)
    delta <- abs(ll_new - ll_old)
    if (is.na(delta)) {
      # Fail with a readable message instead of the cryptic
      # "missing value where TRUE/FALSE needed" raised by `if (NA)`.
      stop("Log-likelihood is not a number; the fit has degenerated.")
    }
    if (delta < tol) {
      converged <- TRUE
      break
    }
    ll_old <- ll_new
  }
  if (!converged) {
    stop("Model failed to converge.")
  }

  list(mu = mu, sigma = sigma, pi = pi, gamma = gamma, log_lik = ll_new)
}
###
# Fit with Expectation-Maximization
# Run EM from several random starts and keep the fit with the highest
# log-likelihood, since a single run may end in a poor local optimum.
n_starts <- 10
fits <- lapply(seq_len(n_starts), function(i) {
  print(paste("Model Run:", i))
  EM(X, K)
})
best <- which.max(vapply(fits, \(fit) fit[["log_lik"]], numeric(1)))
result <- fits[[best]]
print(result[c("mu", "sigma", "pi")])

# Hard cluster assignment: the component with the largest responsibility.
z_hat <- apply(result$gamma, 1, which.max)

# One colour per component; the original three colours are kept whenever
# they are enough, otherwise K distinct colours are generated.
base_cols <- c("blue", "red", "green")
cluster_cols <- if (K <= length(base_cols)) base_cols else rainbow(K)

plot(X[, 1], X[, 2],
     xlim = range(X[, 1]) + c(-1, 1),
     ylim = range(X[, 2]) + c(-1, 1))
for (k in seq_len(K)) {
  in_k <- z_hat == k
  points(X[in_k, 1], X[in_k, 2], col = cluster_cols[k])
}

# Fit with mclust package, using the same number of components
mcl.model <- Mclust(X, K)
print(mcl.model$parameters)
plot(mcl.model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment