@tdunning
Last active December 7, 2020 22:59
Implementation of Monte Carlo EM algorithm for reconstructing a standard distribution from censored observations
### This is a demonstration of a Monte Carlo Expectation Maximization
### algorithm that can recover the mean and standard deviation of
### censored normally distributed data. We draw 10,000 samples from
### a unit normal distribution, but every sample below 0.5 is clamped
### to that value, and every sample above 2.5 is likewise clamped to 2.5.
### These bounds were chosen to get quick and visually appealing
### convergence, but the algorithm should still converge for other choices.
### Convergence could be very, very slow if there is little information in
### the samples, and the final answer could have substantial uncertainty.
### For instance, if we clamped at 4 and 6, almost all samples would pile
### up at the lower limit, and all we could really know is that the mean is
### less than the lower limit and that the standard deviation is a small
### fraction of the distance from the mean to the lower limit.
### The way that this works is that we keep a current estimate of mean
### and standard deviation. At each step, we replace samples that were
### truncated with samples taken from our currently estimated
### distribution. Then we combine the uncensored samples with these
### synthetic samples and compute the new mean and standard deviation.
### This algorithm works very well even when our initial estimates are
### absurdly wrong.
### For a pretty show, try running demo().
### This code has a few nice tricks. One is the use of the CDF of the
### normal to get a quantile range. We can then take uniform samples
### from that range and use the inverse CDF to get suitably truncated
### samples from the normal. Without this, we might use an MCMC sampler
### based on Metropolis-Hastings. The CDF trick is just much faster and
### only takes a few lines of code.
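### A minimal sketch of the CDF trick described above, pulled out into a
### standalone helper. The name rnorm.between and its arguments are
### hypothetical and are not used anywhere else in this script. It maps the
### interval [lo, hi] to a quantile range with pnorm, samples uniformly
### from that range, and maps back with qnorm. Unlike step() below, this
### sketch does not guard against pnorm rounding to exactly 0 or 1 in the
### far tails, so treat it as an illustration rather than a robust sampler.
rnorm.between = function(n, lo, hi, m = 0, s = 1) {
  ## quantile range covered by [lo, hi] under N(m, s)
  p.lo = pnorm(lo, m, s)
  p.hi = pnorm(hi, m, s)
  ## uniform samples in that range, mapped back through the inverse CDF
  qnorm(runif(n, p.lo, p.hi), m, s)
}
### Example: rnorm.between(5, 0.5, 2.5) should return five values
### between 0.5 and 2.5.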
### Censoring bounds (samples outside [a, b] are clamped to the nearest bound)
a = 0.5
b = 2.5
### Raw data and censored data
z.0 = rnorm(10000)
z = pmin(b, pmax(a, z.0))
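### A rough way to see how much information the clamping leaves behind.
### The helper name censoring.summary is hypothetical and nothing below
### depends on it. It reports the fraction of samples sitting at each bound
### together with the naive mean and sd of the clamped data. For a unit
### normal with bounds 0.5 and 2.5 we would expect roughly pnorm(0.5), about
### 0.69, of the samples at the lower bound and 1 - pnorm(2.5), about 0.006,
### at the upper bound. With bounds 4 and 6 the lower fraction would be
### pnorm(4), about 0.99997.
censoring.summary = function(data, a, b) {
  c(at.lower = mean(data <= a),   # fraction clamped to the lower bound
    at.upper = mean(data >= b),   # fraction clamped to the upper bound
    naive.mean = mean(data),      # biased: ignores the censoring
    naive.sd = sd(data))          # biased: ignores the censoring
}
### Example: censoring.summary(z, a, b)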
### Demo example
demo = function() {
  ## start with crazy estimate
  ms = c(5, 1)
  for (i in 1:100) {
    ms = step(ms, z, a, b, TRUE)
    Sys.sleep(0.1)
  }
}
### Record frames to show the evolution and then stitch into a video
video = function() {
  ## start with crazy estimate
  ms = c(-5, 0.1)
  system("rm -rf frames")
  system("mkdir frames")
  for (i in 1:200) {
    png(sprintf("frames/f-%04d.png", i), 1920, 1080, pointsize=24)
    ms = step(ms, z, a, b, TRUE, lwd=15)
    dev.off()
  }
  ## -f so that a missing output file from a previous run is not an error
  system("rm -f mcem.mp4")
  system("ffmpeg -r 10 -f image2 -s 1920x1080 -i frames/f-%04d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p mcem.mp4")
}
### This is where the Monte Carlo E step and the M step are done
step = function(ms, data, a, b, plot=TRUE, lwd=5) {
  ## Unpack current mean and sd
  m = ms[1]
  s = ms[2]
  ## Find the censored samples
  low = data <= a
  high = data >= b
  uncensored = data[(!low) & (!high)]
  ## Find bounds for resampling the censored data.
  ## We have to avoid returning exactly 0 or 1 here.
  p.a = max(1e-8, min(1-1e-8, pnorm(a, m, s)))
  p.b = max(1e-8, min(1-1e-8, pnorm(b, m, s)))
  ## Transform uniform samples into normally distributed samples via the
  ## inverse CDF. These should fall in the censored regions, but we clip
  ## them back to the bounds if they encroach. That can happen when the
  ## probabilities above were clamped because our current distribution
  ## is crazy.
  re.a = pmin(a, qnorm(runif(sum(low), 0, p.a), m, s))
  re.b = pmax(b, qnorm(runif(sum(high), p.b, 1), m, s))
  re.data = c(re.a, uncensored, re.b)
  if (plot) {
    ## Plots a histogram on a constant x-axis range. Bins between a and b
    ## (the observed samples) are drawn in red, the resampled bins in black.
    brk = seq(-5.1, 5.1, by=0.1)
    h = hist(re.data[(re.data>-5) & (re.data<5)], breaks=brk, plot=FALSE)
    plot(c(), c(), xlim=c(-5,5), ylim=c(0,500), xlab="x", ylab="count")
    text(2.8, 300, sprintf("mean = %.2f", ms[1]), adj=0)
    text(2.8, 280, sprintf("sd = %.3f", ms[2]), adj=0)
    ## note: this line reads the global z.0 defined above
    text(2.8, 260, sprintf("(original m=%.2f, sd=%.3f)", mean(z.0), sd(z.0)), adj=0)
    legend(2.7, 370, legend=c("Observed data", "Monte Carlo data"),
           fill=c('red','black'))
    col = rgb((h$mids > a) & (h$mids < b), 0, 0, alpha=0.8)
    ## ylim is set by plot() above; lines() does not take it
    lines(h$mids, h$counts, type='h', lwd=lwd, col=col)
  }
  ## compute new estimates
  c(mean(re.data), sd(re.data))
}
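### A plot-free sketch of the whole loop, for when only the numbers matter.
### The name mcem.sketch, its defaults, and the returned trace matrix are
### assumptions made for illustration. The E and M steps mirror step()
### above, but are repeated here so the function stands on its own. It
### returns one row of (mean, sd) per iteration so convergence can be
### inspected.
mcem.sketch = function(data, a, b, ms = c(5, 1), n.iter = 300) {
  low = data <= a
  high = data >= b
  uncensored = data[!low & !high]
  trace = matrix(NA_real_, n.iter, 2, dimnames = list(NULL, c("mean", "sd")))
  for (i in seq_len(n.iter)) {
    ## Monte Carlo E step: redraw the censored samples from the current
    ## estimate, restricted to the censored regions
    p.a = max(1e-8, min(1 - 1e-8, pnorm(a, ms[1], ms[2])))
    p.b = max(1e-8, min(1 - 1e-8, pnorm(b, ms[1], ms[2])))
    re.a = pmin(a, qnorm(runif(sum(low), 0, p.a), ms[1], ms[2]))
    re.b = pmax(b, qnorm(runif(sum(high), p.b, 1), ms[1], ms[2]))
    ## M step: re-estimate from observed plus synthetic samples
    re.data = c(re.a, uncensored, re.b)
    ms = c(mean(re.data), sd(re.data))
    trace[i, ] = ms
  }
  trace
}
### Example: tr = mcem.sketch(z, a, b); colMeans(tail(tr, 50))
### Each iteration redraws the synthetic samples, so the estimates keep
### jittering after convergence. Averaging the last few rows is one simple
### way to smooth that out.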