Skip to content

Instantly share code, notes, and snippets.

@MaverickMeerkat
Last active September 4, 2023 13:27
Show Gist options
  • Save MaverickMeerkat/49c2a966b873f6f72a0c97990463870c to your computer and use it in GitHub Desktop.
Save MaverickMeerkat/49c2a966b873f6f72a0c97990463870c to your computer and use it in GitHub Desktop.
CAVI_Bayesian_GMM.R
# Using the CAVI algorithm on a (Bayesian) GMM example
library(mvtnorm) # for multivariate normal density
library(extraDistr) # for Categorical distribution
library(pracma) # for 2d-integral, sqrtm
library(matlib) # for solving linear equations
library(ggplot2) # for plotting
# Hyperparameters ----
# Prior standard deviation of the component means: it is passed as `sd = sigma`
# when the means are drawn and enters the updates as `sigma^2`.
sigma <- 3
# Number of mixture components (clusters).
K <- 5
# Number of simulated observations.
n <- 1000
###############
# 1D Gaussian #
###############
set.seed(247)
# Simulate from the generative model. The order of the random draws is kept
# unchanged so that the seed reproduces the same data set.
mu <- rnorm(K, mean = 0, sd = sigma)    # true component means
cs <- rcat(n, rep(1 / K, K))            # true cluster assignment of each point
x <- rnorm(n, mean = mu[cs], sd = 1)    # observations, unit variance around mu[cs]

# plot the data, coloured by true cluster
df <- data.frame(x = x, mu = as.factor(cs))
ggplot(df, aes(x = x, color = mu, fill = mu)) + geom_histogram(alpha = 0.5)

# CAVI algorithm: random starting values for the variational parameters
mk <- rnorm(K)                      # variational means
sk2 <- rgamma(K, 5)                 # variational variances (must be positive)
# One probability row per observation. The concentration vector is built from
# K so the initialisation stays valid when the number of clusters is changed.
phis <- rdirichlet(n, rep(1, K))
# rowSums(phis) # sanity check
# Evidence lower bound (up to additive constants) of the 1D Bayesian GMM for
# the mean-field family q(mu_k) = N(mk[k], sk2[k]), q(c_i) = Categorical(phis[i, ]).
#
# ELBO = E_q[log p(mu)] + E_q[log p(x | c, mu)] - E_q[log q(c)] - E_q[log q(mu)]
#
# The observations `x` and the prior standard deviation `sigma` are read from
# the enclosing (global) environment.
#
# Arguments:
#   mk    numeric vector, variational means (one per component)
#   sk2   numeric vector, variational variances (one per component)
#   phis  matrix of assignment probabilities; rows follow `x`, columns follow `mk`
# Returns: a single numeric value.
ELBO <- function(mk, sk2, phis) {
  # E_q[mu_k^2] under q(mu_k)
  second_moment <- sk2 + mk^2

  # E_q[log p(mu)] for the N(0, sigma^2) prior
  log_prior <- -sum(second_moment) / (2 * sigma^2)

  # E_q[log p(x | c, mu)]: -(x - mu)^2 / 2 expands to x * mu - mu^2 / 2 (+ const),
  # so the cross term carries coefficient 1.
  log_lik <- sum((phis %*% mk) * x) - 0.5 * sum(phis %*% second_moment)

  # Entropy of q(mu), i.e. -E_q[log q(mu)], enters with a positive sign.
  entropy_mu <- 0.5 * sum(log(2 * pi * sk2))

  # Entropy of q(c), i.e. -E_q[log q(c)]. Entries that are exactly zero
  # contribute 0 (limit of p * log(p)) instead of NaN.
  positive <- phis > 0
  entropy_c <- -sum(phis[positive] * log(phis[positive]))

  log_prior + log_lik + entropy_mu + entropy_c
}
# Run CAVI for at most `iter` sweeps. elbos[1] holds the ELBO of the starting
# values and elbos[i + 1] the ELBO after sweep i; the loop stops early once the
# ELBO changes by less than 0.1 between consecutive sweeps.
iter <- 30
elbos <- rep(NA_real_, iter + 1)
elbos[1] <- ELBO(mk, sk2, phis)
for (i in seq_len(iter)) {
  # Update q(c): phis[j, k] is proportional to
  # exp(x[j] * mk[k] - 0.5 * (sk2[k] + mk[k]^2)).
  # Computed for all observations at once and normalised a single time per
  # sweep; the row maximum is subtracted before exp() so the weights cannot
  # overflow or all underflow to zero (it cancels in the normalisation).
  log_phis <- outer(x, mk) -
    matrix(0.5 * (sk2 + mk^2), nrow = length(x), ncol = length(mk), byrow = TRUE)
  log_phis <- log_phis - apply(log_phis, 1, max)
  phis <- exp(log_phis)
  phis <- phis / rowSums(phis)

  # Update q(mu_k): precision = prior precision + expected cluster size,
  # mean = variance * responsibility-weighted sum of the observations.
  sk2 <- 1 / (1 / sigma^2 + colSums(phis))
  mk <- sk2 * as.vector(crossprod(phis, x))

  elbos[i + 1] <- ELBO(mk, sk2, phis)
  cat("Iteration: ", i, "ELBO-diff: ", abs(elbos[i + 1] - elbos[i]), "\n")
  if (abs(elbos[i + 1] - elbos[i]) < 0.1) break
}
# Plot the ELBO trace. After the loop `i` is the last completed sweep, so
# `elbos` holds i + 1 values: the starting value (shown as iteration 0)
# followed by one value per sweep. All of them are plotted.
elbo_trace <- data.frame(Iter = seq_len(i + 1) - 1, ELBO = elbos[seq_len(i + 1)])
ggplot(data = elbo_trace, aes(x = Iter, y = ELBO)) +
  geom_line(color = "#2E9FDF")
# CAVI recovers the cluster centers almost perfectly
mk
mu
# There is very little doubt over the centers
sk2

# Colour each fitted center like the true cluster whose mean is closest to it.
# The fitted components come out in arbitrary order, so the matching is derived
# from the data instead of being hard-coded for one particular seed.
mk_label <- vapply(mk, function(m) which.min(abs(mu - m)), integer(1))
mk_lines <- data.frame(x = mk, mu = factor(mk_label, levels = seq_along(mu)))

ggplot(df, aes(x = x, color = mu, fill = mu)) +
  geom_histogram(alpha = 0.5) +
  geom_vline(data = mk_lines, aes(xintercept = x, color = mu),
             linetype = "dashed", size = 1) # +
# Optional: add +/- 2 posterior standard deviations around each center
#  geom_vline(data = mk_lines, aes(xintercept = x + 2 * sqrt(sk2), color = mu),
#             size = 1, linetype = "dashed") +
#  geom_vline(data = mk_lines, aes(xintercept = x - 2 * sqrt(sk2), color = mu),
#             size = 1, linetype = "dashed")
###############
# 2D Gaussian #
###############
# No new seed is set here: the draws continue the random stream of the 1D
# section, and their order is kept unchanged.
I <- diag(2)                                           # 2 x 2 identity covariance
mu <- rmvnorm(K, mean = c(0, 0), sigma = sigma^2 * I)  # true means, one row each
cs <- rcat(n, rep(1 / K, K))                           # true cluster assignments
x <- mu[cs, ] + rmvnorm(n, mean = c(0, 0), sigma = I)  # observations, one row each

# plot the data, coloured by true cluster
df <- data.frame(x = x, mu = as.factor(cs))
ggplot(df, aes(x = x[, 1], y = x[, 2], color = mu, fill = mu)) + geom_point()

# CAVI algorithm: random starting values for the variational parameters
mk <- rmvnorm(K, mean = c(0, 0), sigma = I)  # variational means, one row each
sk2 <- rgamma(K, 5)                          # variational variances (positive)
# One probability row per observation. The concentration vector is built from
# K so the initialisation stays valid when the number of clusters is changed.
phis <- rdirichlet(n, rep(1, K))
# rowSums(phis) # sanity check
# Scatter plot of the 2D observations coloured by true cluster, with the
# current variational means drawn on top as black-rimmed dots.
#
# Reads `df`, `x`, `mk` and `mu` from the enclosing (global) environment and
# returns the ggplot object.
plotClusters <- function() {
  # Colour each fitted center like the true cluster whose mean is closest to
  # it (squared Euclidean distance). The fitted components come out in
  # arbitrary order, so the matching is derived from the data instead of being
  # hard-coded for one particular seed.
  nearest <- apply(mk, 1, function(m) which.min(colSums((t(mu) - m)^2)))
  centers <- data.frame(
    x1 = mk[, 1],
    x2 = mk[, 2],
    mu = factor(nearest, levels = seq_len(nrow(mu)))
  )

  ggplot(df, aes(x = x[, 1], y = x[, 2], color = mu)) +
    geom_point(alpha = 0.3) +
    # larger black dot underneath gives each center a visible rim
    geom_point(data = centers, aes(x = x1, y = x2), size = 3, colour = "black") +
    geom_point(data = centers, aes(x = x1, y = x2, color = mu), size = 2)
}
# Show the random starting centers
plotClusters()

# Run a fixed number of CAVI sweeps (no ELBO tracking in the 2D example).
iter <- 30
for (i in seq_len(iter)) {
  # Update q(c): phis[j, k] is proportional to
  # exp(x[j, ] . mk[k, ] - 0.5 * (d * sk2[k] + mk[k, ] . mk[k, ])),
  # where d = ncol(x) is the data dimension. Computed for all observations at
  # once and normalised a single time per sweep; the row maximum is subtracted
  # before exp() so the weights cannot overflow or all underflow to zero.
  expected_sq_norm <- ncol(x) * sk2 + rowSums(mk^2)
  log_phis <- tcrossprod(x, mk) -
    matrix(0.5 * expected_sq_norm, nrow = nrow(x), ncol = nrow(mk), byrow = TRUE)
  log_phis <- log_phis - apply(log_phis, 1, max)
  phis <- exp(log_phis)
  phis <- phis / rowSums(phis)

  # Update q(mu_k): precision = prior precision + expected cluster size,
  # mean = variance * responsibility-weighted sum of the observations.
  # crossprod(phis, x) has one row per component; multiplying by the vector
  # sk2 scales row k by sk2[k].
  sk2 <- 1 / (1 / sigma^2 + colSums(phis))
  mk <- sk2 * crossprod(phis, x)
}
# CAVI recovers the cluster centers almost perfectly
mk
mu
plotClusters()
# D. Refaeli ©
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment