Skip to content

Instantly share code, notes, and snippets.

@benilton
Created May 30, 2017 13:53
Show Gist options
  • Save benilton/3929d703abeb39df5ff1879ec7a1c820 to your computer and use it in GitHub Desktop.
Save benilton/3929d703abeb39df5ff1879ec7a1c820 to your computer and use it in GitHub Desktop.
Algoritmo EM - Mistura de Regressões
## Simulated data ----
## Toy sample: two groups of 20 normal draws (means 2.5 and 5, common
## sd 1.25) stacked into Y. Note that Y is reassigned to faithful$eruptions
## in the next section, so the EM loop does not fit this sample.
set.seed(1)
y1 <- rnorm(n = 20, mean = 2.5, sd = 1.25)
y2 <- rnorm(n = 20, mean = 5, sd = 1.25)
Y <- c(y1, y2)
## Old Faithful data and EM starting values ----
data(faithful)
plot(faithful$waiting, faithful$eruptions)
## The mixture is fitted to the eruptions column.
Y <- faithful$eruptions
## Initial guesses: mixing weight of component 2 and the common variance...
lambda <- 0.8
sigma2 <- 2
## ...and the two component means.
mu1 <- 1
mu2 <- 7
## Log-likelihood of Y under the two-component normal mixture
##   (1 - lambda) * N(mu1, sigma2) + lambda * N(mu2, sigma2),
## i.e. both components share the variance sigma2.
## Y: observations; mu1, mu2: component means (recycled by dnorm, so
## per-observation vectors also work); sigma2: common variance;
## lambda: weight of component 2.
## Returns the scalar sum of log mixture densities.
LL <- function(Y, mu1, mu2, sigma2, lambda) {
  sd_common <- sqrt(sigma2)
  mixture_density <- (1 - lambda) * dnorm(Y, mean = mu1, sd = sd_common) +
    lambda * dnorm(Y, mean = mu2, sd = sd_common)
  sum(log(mixture_density))
}
## EM iterations ----
## Fits Y ~ (1 - lambda) * N(mu1, sigma2) + lambda * N(mu2, sigma2) by EM,
## starting from the current values of mu1, mu2, sigma2 and lambda.
## Each iteration runs the E-step (responsibilities z) and the M-step
## (mu1, mu2, lambda, sigma2), then redraws the scatter plot. The loop stops
## once the log-likelihood changes by at most `tol`, or when the iteration
## cap is reached (a warning is issued in that case).
## NOTE(review): fit1/fit2 (z-weighted regressions of eruptions on waiting)
## are only drawn on the plot; the E-step and LL() use the scalar means
## mu1/mu2, so `waiting` does not enter the fitted mixture. A true mixture
## of regressions would use mu1 <- fitted(fit1) and mu2 <- fitted(fit2).
tol <- 1e-6       # convergence threshold on |change in log-likelihood|
max_iter <- 1000  # loop runs while i < max_iter, i.e. at most max_iter - 1 iterations
pause <- 0.5      # seconds between plot frames (animation only)
erro <- 1
i <- 1
while (erro > tol && i < max_iter) {
  LL_ant <- LL(Y, mu1, mu2, sigma2, lambda)
  ## points(i, LL_ant)
  message("Iteracao: ", i)
  dens1 <- (1 - lambda) * dnorm(Y, mu1, sqrt(sigma2))
  dens2 <- lambda * dnorm(Y, mu2, sqrt(sigma2))
  ## E-step: z[k] is the posterior probability that Y[k] is from component 2.
  z <- dens2 / (dens1 + dens2)
  ## Previous means (not used elsewhere in this script).
  mu1_ant <- mu1
  mu2_ant <- mu2
  ## M-step.
  fit1 <- lm(eruptions ~ waiting, data = faithful, weights = 1 - z)
  fit2 <- lm(eruptions ~ waiting, data = faithful, weights = z)
  mu1 <- sum((1 - z) * Y) / sum(1 - z)
  mu2 <- sum(z * Y) / sum(z)
  ## Colour 1 (black) = component 1, colour 2 (red) = component 2.
  plot(faithful$waiting, faithful$eruptions, col = (z > .5) + 1, pch = 19)
  abline(fit1, col = 1, lty = 2, lwd = 2)
  abline(fit2, col = 2, lty = 2, lwd = 2)
  Sys.sleep(pause)
  lambda <- mean(z)
  ## Pooled variance. Because lambda == mean(z) here, the factors
  ## (1 - lambda) / sum(1 - z) and lambda / sum(z) both equal 1 / length(Y),
  ## so this is the usual M-step: total weighted squared deviation / n.
  sigma2 <- (1 - lambda) * sum((1 - z) * (Y - mu1)^2) / sum(1 - z) +
    lambda * sum(z * (Y - mu2)^2) / sum(z)
  i <- i + 1
  LL_nova <- LL(Y, mu1, mu2, sigma2, lambda)
  erro <- abs(LL_nova - LL_ant)
}
## i starts at 1 and is incremented once per pass, so i - 1 passes ran.
if (erro > tol) {
  warning("EM stopped after ", i - 1, " iterations without converging ",
          "(last |change in log-likelihood| = ", format(erro), ")",
          call. = FALSE)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment