Skip to content

Instantly share code, notes, and snippets.

@MatsuuraKentaro
Last active December 17, 2020 02:21
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MatsuuraKentaro/bccf13af3ba52c9d6c379c0032725b91 to your computer and use it in GitHub Desktop.
Save MatsuuraKentaro/bccf13af3ba52c9d6c379c0032725b91 to your computer and use it in GitHub Desktop.
replica exchange MCMC
library(rstan)
library(doParallel)
# Lower members r of the neighbour pairs (r, r + 1) proposed for swapping in
# exchange round e: pairs (1,2), (3,4), ... on even rounds and (2,3), (4,5),
# ... on odd rounds. Empty when there are fewer than two replicas; includes
# the final pair when n_rep is odd.
.rex_swap_pairs <- function(n_rep, e) {
  lower <- seq_len(max(n_rep - 1, 0))
  want_odd <- e %% 2 == 0
  lower[(lower %% 2 == 1) == want_odd]
}

# Turn the final draw (iteration `len`) of a single-chain stanfit into an
# rstan init list with one element per entry of `par_list`, shaped by the
# dimensions given there (a plain scalar when the dimension is 1).
.rex_last_draw_as_init <- function(fit, par_list, len) {
  draws <- rstan::extract(fit, permuted = FALSE, pars = names(par_list))
  last <- draws[len, 1L, ]
  # Take names from the array's dimnames instead of relying on drop() to keep them.
  names(last) <- dimnames(draws)[[3L]]
  init <- lapply(names(par_list), function(p) {
    # Matches "p" itself and flattened elements such as "p[1]" or "p[2,3]".
    pos <- grep(paste0('^', p, '(\\[.*\\])?$'), names(last))
    dims <- par_list[[p]]
    if (length(dims) == 1L && dims == 1) {
      unname(last[pos])
    } else {
      array(last[pos], dim = dims)
    }
  })
  names(init) <- names(par_list)
  init
}

#' Replica-exchange (parallel tempering) MCMC driven by rstan.
#'
#' Replica r samples `stanmodel` with data `Inv_T = inv_T[r]`. In each of the
#' `n_ex` rounds every replica draws `iter` iterations (of which `warmup` are
#' warmup) in parallel on the registered foreach backend. Then swaps between
#' neighbouring temperatures are proposed with the Metropolis criterion on the
#' last draw of `E`, and each temperature's next round starts from the final
#' draw of the replica whose state it received.
#'
#' @param inv_T Numeric vector of inverse temperatures. The draws returned in
#'   `ms_T1` come from replica 1, so `inv_T[1]` is expected to be 1.
#' @param n_ex Number of exchange rounds (>= 1).
#' @param stanmodel Compiled stanmodel with data `Inv_T` and a scalar `E`.
#' @param data Named list of model data (without `Inv_T`).
#' @param par_list Named list: parameter name -> dimensions (1 for a scalar).
#' @param init Named list of initial values used by every replica in round 1.
#' @param iter,warmup Iterations per round and how many of them are warmup.
#' @return A list with three elements:
#'   - `ms_T1`: post-warmup draws of replica 1 from all rounds, stacked.
#'   - `idx_tbl`: `idx_tbl[e, r]` is the replica whose state temperature r
#'     continues from after round e.
#'   - `E_tbl`: `E_tbl[e, r]` is the last `E` of replica r in round e.
replica.exchange.mcmc <- function (inv_T, n_ex, stanmodel, data, par_list, init, iter, warmup) {
  n_rep <- length(inv_T)
  len <- iter - warmup  # post-warmup draws kept per replica per round
  stopifnot(n_rep >= 1, n_ex >= 1, len >= 1)
  n_param <- sum(unlist(lapply(par_list, prod))) + 2  # parameters plus E and lp__
  ms_T1 <- matrix(0, len * n_ex, n_param)  # draws of replica 1 (inv_T[1])
  idx_tbl <- matrix(0, n_ex, n_rep)        # (exchange round, temperature) -> source replica
  E_tbl <- matrix(0, n_ex, n_rep)          # last E of each replica per round
  init_list <- rep(list(init), n_rep)

  for (e in seq_len(n_ex)) {
    # NOTE(review): seed = r is reused in every round, so replica r restarts the
    # same RNG stream each time; consider making the seed depend on e as well.
    fit_list <- foreach(r = seq_len(n_rep), .packages = 'rstan') %dopar% {
      data$Inv_T <- inv_T[r]
      rstan::sampling(
        stanmodel, data = data, pars = c(names(par_list), 'E'),
        init = list(init_list[[r]]), iter = iter, warmup = warmup,
        chains = 1, seed = r, refresh = -1
      )
    }
    rows <- ((e - 1) * len + 1):(e * len)
    ms_T1[rows, ] <- rstan::extract(fit_list[[1]], permuted = FALSE, inc_warmup = FALSE)[, 1, ]

    # Exchange step: accept a swap of neighbours r, r + 1 with probability
    # min(1, exp((inv_T[r] - inv_T[r + 1]) * (E[r] - E[r + 1]))).
    E <- vapply(fit_list, function(fit) {
      rstan::extract(fit, permuted = FALSE, pars = 'E')[len, 1, 1]
    }, numeric(1))
    idx <- seq_len(n_rep)
    for (r in .rex_swap_pairs(n_rep, e)) {
      w <- exp((inv_T[r] - inv_T[r + 1]) * (E[r] - E[r + 1]))
      if (runif(1, 0, 1) < w) {
        idx[c(r, r + 1)] <- c(r + 1, r)
      }
    }
    E_tbl[e, ] <- E
    idx_tbl[e, ] <- idx

    # Seed the next round: temperature r continues from replica idx[r]'s last draw.
    init_list <- lapply(idx, function(src) {
      .rex_last_draw_as_init(fit_list[[src]], par_list, len)
    })
  }
  colnames(ms_T1) <- names(fit_list[[1]])
  list(ms_T1 = ms_T1, idx_tbl = idx_tbl, E_tbl = E_tbl)
}
# Driver script: simulate data, compile the tempered Stan model, run replica
# exchange and plot the results.
source('generate-data.R')
data <- list(N = N, X = X, Y = Y)
# Start value for every replica. NOTE(review): presumably chosen far from the
# data-generating b = 0.6 on purpose to exercise the tempering; confirm.
init <- list(b = 24.17, s_y = 0.4048)
stanmodel <- stan_model(file = 'model/model.stan')
# fit <- sampling(stanmodel, data=c(data, Inv_T=1), seed=10)

N_rep <- 10  # number of replicas
N_ex <- 100  # number of exchanges
# Geometric ladder of inverse temperatures from 1 down to 0.02. The first
# replica (Inv_T = 1) is the one whose draws are returned in res$ms_T1.
Inv_T <- 0.5^seq(0, -log(0.02) / log(2), length.out = N_rep)

registerDoParallel(3)
# Release the implicit cluster even if sampling fails; plotting does not need it.
res <- tryCatch(
  replica.exchange.mcmc(
    inv_T = Inv_T, n_ex = N_ex, stanmodel = stanmodel, data = data,
    par_list = list(b = 1, s_y = 1), init = init, iter = 70, warmup = 50
  ),
  finally = stopImplicitCluster()
)
source('print-result.R')
# Simulated data: N observations of sin(b * X) plus Gaussian noise with
# standard deviation s_y, on an evenly spaced grid from 0.1 to 4 * pi.
set.seed(123)
N <- 50
b <- 0.6
s_y <- 0.4
X <- seq(from = 0.1, to = 4 * pi, length.out = N)
Y <- sin(b * X) + rnorm(N, mean = 0, sd = s_y)
// Tempered regression model for replica exchange: Y[n] ~ normal(sin(b * X[n]), s_y).
// E is the negative log prior plus negative log likelihood (the priors are not
// renormalised for the lower bounds at 0). The sampler targets
// exp(-Inv_T * E), so Inv_T = 1 gives the ordinary posterior.
data {
  int<lower=1> N;
  vector[N] Y;
  vector[N] X;
  real<lower=0> Inv_T;  // inverse temperature of this replica
}
parameters {
  real<lower=0> b;
  real<lower=0> s_y;
}
transformed parameters {
  real E;  // "energy" used by the exchange step
  {
    vector[N] mu = sin(b * X);
    E = -(normal_lpdf(b | 0, 50)
          + student_t_lpdf(s_y | 4, 0, 5)
          + normal_lpdf(Y | mu, s_y));
  }
}
model {
  target += -Inv_T * E;
}
library(ggplot2)

# Replica paths through temperature space. There is one two-point segment per
# (temperature r, exchange e): from Inv_T[r] at e to Inv_T[idx_tbl[e, r]] at e + 1.
# Each round swaps only disjoint neighbour pairs, so idx_tbl[e, ] is its own
# inverse, and the segment also shows where the state at r moved to.
dir.create('output', showWarnings = FALSE)
ix <- expand.grid(replica = seq_len(N_rep), exchange = seq_len(N_ex))
from_T <- Inv_T[ix$replica]
to_T <- Inv_T[res$idx_tbl[cbind(ix$exchange, ix$replica)]]
# Interleave start/end points so rows 2i - 1 and 2i form segment i.
d_ex <- data.frame(
  x = as.vector(rbind(ix$exchange, ix$exchange + 1)),
  y = as.vector(rbind(from_T, to_T)),
  g = rep(seq_len(nrow(ix)), each = 2)
)
p <- ggplot(data = d_ex, aes(x = x, y = y, group = g)) +
  theme(text = element_text(size = 18)) +
  labs(x = 'Exchange Time', y = 'Inverse T') +
  geom_line()
ggsave(filename = 'output/exchange-invT.png', plot = p, dpi = 300, width = 8, height = 6)
# Energy traces: the last E of every replica (column of E_tbl) per exchange round.
dir.create('output', showWarnings = FALSE)
d_E <- reshape2::melt(res$E_tbl)
colnames(d_E) <- c('Exchange', 'Replica', 'E')
d_E$Replica <- as.factor(d_E$Replica)
p <- ggplot(data = d_E, aes(x = Exchange, y = E, group = Replica, color = Replica)) +
  theme(text = element_text(size = 18)) +
  labs(x = 'Exchange Time', y = 'E') +
  geom_line()
ggsave(filename = 'output/energy-whole.png', plot = p, dpi = 300, width = 8, height = 6)
# Zoom through the coordinate system. ylim() would drop every point above 100,
# which breaks the lines instead of clipping them at the panel edge.
p <- p + coord_cartesian(ylim = c(min(d_E$E), 100))
ggsave(filename = 'output/energy-zoom.png', plot = p, dpi = 300, width = 8, height = 6)
# Trace plots of every stored quantity (parameters, E, lp__) of the first
# replica (Inv_T[1]), concatenated over all exchange rounds.
dir.create('output', showWarnings = FALSE)
d_ms <- reshape2::melt(res$ms_T1)
colnames(d_ms) <- c('Step', 'parameter', 'value')
p <- ggplot(data = d_ms, aes(x = Step, y = value)) +
  theme(text = element_text(size = 18)) +
  facet_wrap(~parameter, scales = 'free_y') +
  geom_line()
ggsave(filename = 'output/traceplot.png', plot = p, dpi = 300, width = 8, height = 6)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment