EM algorithm for a binomial mixture model (arbitrary number of mixture components, counts etc).
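For orientation, this is the model the script fits (notation introduced here; the symbols mirror the variable names A, Q, n, Y and mu used below). Each observed count Y_i comes from one of M coins whose identity Z_i is hidden:

$$Z_i \sim \mathrm{Categorical}(A_1,\dots,A_M), \qquad Y_i \mid Z_i = m \sim \mathrm{Binomial}(n_i, Q_m).$$

The E-step computes the responsibilities

$$\mu_{im} = P(Z_i = m \mid Y_i) = \frac{A_m \binom{n_i}{Y_i} Q_m^{Y_i} (1 - Q_m)^{n_i - Y_i}}{\sum_{k=1}^{M} A_k \binom{n_i}{Y_i} Q_k^{Y_i} (1 - Q_k)^{n_i - Y_i}},$$

and the M-step re-estimates A and Q from them.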
##################################################################
#
# APPLYING EM TO A BINOMIAL MIXTURE MODEL
# matthew@refute.me.uk
# Licence: MIT Licence (http://opensource.org/licenses/MIT)
#
##################################################################
###### EM ALGORITHM FUNCTIONS
# E-step
# probability components for indicator variables
probs <- function(i,m,A,Q,n,Y){choose(n[i],Y[i]) * Q[m]^Y[i] * (1 - Q[m])^(n[i]-Y[i]) * A[m]}
# chaining across columns
doCol <- function(m,A,Q,n,Y){sapply(1:length(Y),function(i){probs(i,m,A,Q,n,Y)})}
# generate the expected value of the 'hidden' coin used in each experiment
mu.update <- function(A,Q,n,Y){
  unnorm <- sapply(1:length(A),function(i){doCol(i,A,Q,n,Y)})
  norms <- apply(unnorm,1,sum)
  mu <- unnorm / norms # Expected value of the indicator variables
  mu
}
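# A vectorised sketch of the same E-step (an aside, not part of the original gist):
# dbinom() replaces the explicit choose()/power expression and rowSums() replaces
# the apply() call; the result should match mu.update() above.
mu.update.vec <- function(A,Q,n,Y){
  unnorm <- sapply(1:length(A),function(m){dbinom(Y,n,Q[m]) * A[m]})
  unnorm / rowSums(unnorm)
}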
# M-step
# update mixture fractions
A.update <- function(mu,N){apply(mu,2,sum)/N}
# update mixture distribution parameters
Q.update <- function(mu,n,Y){apply(mu * Y,2,sum)/apply(mu * n,2,sum)}
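# In symbols: A[m] = (1/N) * sum_i mu[i,m]  and  Q[m] = sum_i mu[i,m]*Y[i] / sum_i mu[i,m]*n[i],
# i.e. the mean responsibility per coin, and responsibility-weighted heads over responsibility-weighted tosses.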
# EM - R is the number of iterations, A and Q are the initial mixture fractions and mixture parameters respectively
doEM <- function(R,A,Q,n,Y){
  for(i in 1:R){
    mu <- mu.update(A,Q,n,Y)
    A <- A.update(mu,length(Y))
    Q <- Q.update(mu,n,Y)
  }
  list("mu"=mu, "A"=A, "Q"=Q)
}
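# Optional convergence check (a sketch, not part of the original gist): the observed-data
# log-likelihood must increase monotonically across EM iterations, so printing or storing
# this value inside doEM's loop is a cheap sanity check on the fixed iteration count R.
loglik <- function(A,Q,n,Y){
  dens <- sapply(1:length(A),function(m){dbinom(Y,n,Q[m]) * A[m]})
  sum(log(rowSums(dens)))
}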
# helper function to threshold the expected values of hidden variables into an estimate of which coin it was
thresh <- function(EM.result){
  Q.ranks <- rank(EM.result$Q, ties.method="first") # relabel coins in order of increasing estimated probability so they can be compared to the true data later
  apply(EM.result$mu,1,function(r){Q.ranks[which.max(r)]}) # most probable coin for each observation, under that relabelling
}
########### CREATE BINOMIAL MIXTURE DATA, APPLY EM, LOOK AT RESULTS
# Generate the data
M <- 3 # Number of coins in bag
P <- sort(rbeta(M,3,3)) # The probability of a head for each coin
# NB: if substituting other true probabilities, keep them sorted in increasing order (the comparison below assumes this)
N <- 3000 # Number of observed sets of coin tosses
gamms <- rgamma(M,2,1)
fracs <- gamms/sum(gamms) # Normalised gammas give a Dirichlet(2,...,2) draw for the mixture fractions
C.true <- sample(1:M,N,replace=T,prob=fracs) # The true coin used for each set of tosses - consider this 'hidden'
P.used <- sapply(C.true,function(i){P[i]}) # Corresponding probabilities for true coins - for convenience
n <- rpois(N,50) # Number of trials parameter for each binomial - given
Y <- rbinom(N,n,P.used) # The observed binomial count data
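# Quick sanity check on the simulated data (an aside, not part of the original gist):
# the empirical head fraction within each true coin should sit close to P.
tapply(Y/n, C.true, mean)
P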
# Initialisation of EM
J <- 3 # Number of coins _suspected_ to be in bag
Q0 <- runif(J,.25,.75) # Randomised corresponding starting probabilities for each coin
A0 <- rep(1/J,J) # Initial expectation of coin proportions
# Run EM
EM.result <- doEM(30,A0,Q0,n,Y)
# Look at results. The model is not identifiable in the coin labels, so we compare estimates and truth after sorting both by coin probability.
C.hat <- thresh(EM.result) # Thresholded imputation of the coin used, labelled in order of increasing probability.
cbind(C.true, C.hat, C.true==C.hat)
EM.Q <- sort(EM.result$Q,index.return=T) # Look at the sorted estimated coin probabilities and compare with sorted true
Q <- EM.Q$x
Q
sort(P)
A <- EM.result$A[EM.Q$ix] # Look at coin mixture fracs sorted by corresponding probabilities
A
fracs[sort(P,index.return=T)$ix]
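# Optional summary (an aside, not part of the original gist): overall agreement between
# the thresholded imputation and the true labels, plus the corresponding confusion table.
mean(C.true == C.hat)
table(C.true, C.hat)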