Skip to content

Instantly share code, notes, and snippets.

@luiarthur
Last active February 16, 2023 02:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save luiarthur/32527c084e08d3979233833a43b899e9 to your computer and use it in GitHub Desktop.
Save luiarthur/32527c084e08d3979233833a43b899e9 to your computer and use it in GitHub Desktop.
Elliptical slice sampling in R
# Minimal console progress bar.
#
# Call `ProgressBar(niters)` once, then `pb$advance(i)` on every iteration
# i = 1..niters. The line is redrawn (with "\r") only when more than `freq`
# seconds have passed since the last redraw, and always on the final
# iteration, after which the total duration in seconds is printed.
#
# niters: total number of iterations.
# freq:   minimum number of seconds between redraws.
# Returns a list with one element, `advance(i)`.
ProgressBar = function(niters, freq=0.2) {
  # Elapsed seconds between two Sys.time() values. Plain `end - start` gives a
  # difftime whose units are auto-selected (secs, mins, hours, ...), so
  # as.numeric() of it would silently switch units on long runs.
  elapsed_secs = function(start, end) {
    as.numeric(difftime(end, start, units = "secs"))
  }
  init = Sys.time()
  tic = init  # time of the last redraw
  advance = function(i) {
    toc = Sys.time()
    if ((elapsed_secs(tic, toc) > freq) || (i == niters)) {
      total = elapsed_secs(init, toc)
      .speed = i / total
      # Report iterations/second when fast, seconds/iteration when slow.
      if (.speed > 1) {
        speed = round(.speed)
        speed_units = "it/s"
      } else {
        speed = round(1 / .speed)
        speed_units = "s/it"
      }
      cat("\rProgress: ", i, "/", niters, " | speed: ", speed, speed_units, sep = "")
      # `<<-` updates the redraw timestamp held in the enclosing closure.
      tic <<- toc
      if (i == niters) {
        cat("\nDuration: ", total, "seconds.\n")
      }
    }
  }
  list(advance = advance)
}
# One step of ESS.
#
# Performs one elliptical slice sampling update (Murray, Adams & MacKay, 2010).
# The steps are:
#   1. Draw a point `nu` from the prior; it defines an ellipse through `state`.
#   2. Draw a log-likelihood threshold below the current value.
#   3. Propose angles on the ellipse from a bracket that shrinks toward the
#      current state (angle 0) until a candidate exceeds the threshold.
#
# state:         current parameter vector.
# loglike_fn:    function(state) returning a scalar log-likelihood.
# prior_sampler: function() returning one prior draw, same length as state.
#                NOTE: the ellipse construction assumes a zero-mean Gaussian
#                prior.
#
# Returns the new state.
# Candidates whose log-likelihood is NA/NaN are treated as rejections.
# Errors if the current state's log-likelihood is NA/NaN.
ess_step = function(state, loglike_fn, prior_sampler) {
  nu = prior_sampler()
  # Slice threshold; log(runif(1)) < 0 because runif never returns 0 or 1.
  log_y = loglike_fn(state) + log(runif(1))
  if (is.na(log_y)) {
    stop("ess_step: loglike_fn(state) returned NA/NaN for the current state.",
         call. = FALSE)
  }
  # Initial angle and a bracket [theta - 2*pi, theta] that always contains 0.
  theta = runif(1, 0, 2 * pi)
  theta_min = theta - 2 * pi
  theta_max = theta
  repeat {
    cand = state * cos(theta) + nu * sin(theta)
    # isTRUE() turns an NA/NaN comparison into a rejection instead of an error.
    if (isTRUE(loglike_fn(cand) > log_y)) {
      return(cand)
    }
    # Shrink the bracket toward 0 (the current state) from the rejected side.
    if (theta < 0) {
      theta_min = theta
    } else {
      theta_max = theta
    }
    theta = runif(1, theta_min, theta_max)
  }
}
# ESS class.
#
# Bundles a log-likelihood and a prior sampler into an elliptical slice
# sampler. Returns a list with `fit`, `loglike_fn` and `prior_sampler`.
#
# fit(nmcmc, burn = 0, thin = 1, init = NA) runs nmcmc * thin + burn ESS steps.
# It discards the first `burn` steps, keeps every `thin`-th step afterwards,
# and returns an nmcmc x length(init) matrix with one kept state per row.
# When `init` is NA (or NULL) the starting state is drawn from prior_sampler().
# Otherwise `init` is used as the starting state; it may be a vector.
ESS = function(loglike_fn, prior_sampler) {
  fit = function(nmcmc, burn=0, thin=1, init=NA) {
    stopifnot(nmcmc >= 0, burn >= 0, thin >= 1)
    # Only a scalar NA / NULL means "sample from the prior"; a plain
    # `if (is.na(init))` errors when a full starting vector is passed.
    if (is.null(init) || (length(init) == 1L && is.na(init))) {
      init = prior_sampler()
    }
    state = init
    total_iters = nmcmc * thin + burn
    chain = matrix(0, nmcmc, length(init))
    j = 0
    pb = ProgressBar(total_iters)
    # seq_len() yields no iterations when total_iters is 0 (1:0 would run twice).
    for (i in seq_len(total_iters)) {
      state = ess_step(state, loglike_fn, prior_sampler)
      pb$advance(i)
      # Keep every `thin`-th post-burn-in state.
      if ((i > burn) && ((i - burn) %% thin == 0)) {
        j = j + 1
        chain[j, ] = state
      }
    }
    chain
  }
  list(
    fit = fit,
    loglike_fn = loglike_fn,
    prior_sampler = prior_sampler
  )
}
# Linear regression model.
#
# Builds the prior sampler and log-likelihood for
# y ~ Normal(X %*% beta, sigma^2).
# The state vector is c(beta, log(sigma)): the first ncol(X) entries are the
# regression coefficients and the last entry is log(sigma).
# The prior draws every state entry independently from Normal(0, prior_sd).
#
# X:        numeric design matrix (rows = observations, columns = features).
# y:        numeric response, one value per row of X.
# prior_sd: prior standard deviation for each state entry (default 10).
# Returns a list with `prior_sampler()` and `loglike_fn(state)`.
make_model = function(X, y, prior_sd = 10) {
  nfeatures = ncol(X)
  # Force lazy arguments so the closures capture their current values.
  force(y)
  force(prior_sd)
  # seq_len() gives an empty index when X has no columns (1:0 would not).
  beta_idx = seq_len(nfeatures)
  list(
    prior_sampler = function() rnorm(nfeatures + 1, 0, prior_sd),
    loglike_fn = function(state) {
      beta = state[beta_idx]
      sigma = exp(tail(state, 1))
      mu = X %*% beta
      sum(dnorm(y, mu, sigma, log = TRUE))
    }
  )
}
# Data generator.
#
# Simulates a linear-regression data set y = X %*% beta + noise.
# X (nobs x nfeatures) and beta have i.i.d. standard-normal entries, and the
# noise is Normal(0, sigma^2).
# Returns a list with elements y, X, sigma, beta and nfeatures.
make_data = function(nfeatures, nobs, sigma) {
  design = matrix(rnorm(nobs * nfeatures), nrow = nobs, ncol = nfeatures)
  coefs = rnorm(nfeatures)
  response = rnorm(nobs, mean = design %*% coefs, sd = sigma)
  list(
    y = response, X = design, sigma = sigma,
    beta = coefs, nfeatures = nfeatures
  )
}
# Run demo.
set.seed(0)
data = make_data(nfeatures = 5, nobs = 1000, sigma = 0.5)
model = make_model(data$X, data$y)
# fit() draws the initial state from the prior because `init` is left at NA.
ess = ESS(model$loglike_fn, model$prior_sampler)
print(system.time(chain <- ess$fit(1000, burn=1000, thin=5)))
# The last column holds log(sigma); report sigma on its natural scale.
sigma_col = data$nfeatures + 1
chain[, sigma_col] = exp(chain[, sigma_col])
colnames(chain) = c(paste0("beta", seq_len(data$nfeatures)), "sigma")
# Print results.
results = rbind(colMeans(chain), c(data$beta, data$sigma))
rownames(results) = c("Post. Mean", "Truth")
print(results)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment