Skip to content

Instantly share code, notes, and snippets.

@roualdes
Last active November 11, 2022 00:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save roualdes/f6dc7284ca5f8678644529be3f0791a8 to your computer and use it in GitHub Desktop.
Save roualdes/f6dc7284ca5f8678644529be3f0791a8 to your computer and use it in GitHub Desktop.
BridgeStan: Multivariate Student-t
library(cmdstanr)
library(bridgestan)
library(mvtnorm)
# Project the rows of `z` onto a 2-D plane for plotting.
#
# Each row keeps its angle in the (z[, 1], z[, 2]) plane, but its radius is
# replaced by the quadratic-form norm sqrt(z_i' sigma^{-1} z_i) taken over
# all columns. The distance from the origin in the plot therefore reflects
# the full-dimensional distance of the point under `sigma`.
#
# Args:
#   z: numeric matrix (or something coercible by as.matrix), one point per
#     row; must have at least two columns.
#   sigma: ncol(z) x ncol(z) matrix used in the norm; defaults to the
#     identity (Euclidean norm).
# Returns:
#   list(x, y) of numeric vectors with one entry per row of `z`.
project_xy <- function(z, sigma = diag(dim(z)[2])) {
  z <- as.matrix(z)
  stopifnot(ncol(z) >= 2)
  # atan2 gives the angle in (-pi, pi]; same direction as
  # atan(z[, 2] / z[, 1]) + (z[, 1] < 0) * pi.
  theta <- atan2(z[, 2], z[, 1])
  # One solve for all points (columns of t(z)) instead of one per row.
  # colSums of the elementwise product gives each quadratic form.
  zt <- t(z)
  r <- sqrt(colSums(zt * solve(sigma, zt)))
  list(x = r * cos(theta), y = r * sin(theta))
}
# student-t
# Degrees of freedom used by the Stan data, the reference draws and the plot
# title. All three must agree, or the MCMC trajectories are drawn over a
# cloud from a different distribution than the one being sampled.
nu <- 10
mod <- cmdstan_model("mvt.stan")
D <- 40
sigma <- diag(D)
data <- list(D = D, nu = nu, mu = rep(0, D), Sigma = sigma)
write_stan_json(data, "mvt_data.json")
fit <- mod$sample(data = data, chains = 4, parallel_chains = 2)
# NOTE(review): "mvt_model.so" is assumed to be a BridgeStan library compiled
# from the same mvt.stan; this script does not build it — confirm.
# 204 and 1 are presumably the RNG seed and chain id expected by the
# installed bridgestan version — verify against its StanModel$new signature.
sm <- StanModel$new("mvt_model.so", "mvt_data.json", 204, 1)
png("mcmc.png")
# Independent draws from the target, plotted as a grey background cloud.
z <- rmvt(4000, sigma = sigma, df = nu)
xy <- project_xy(z, sigma)
x1S <- xy$x
x2S <- xy$y
plot(x1S, x2S, col = "grey80", xlab = "", ylab = "",
     main = bquote(paste(.(D), "D Isotropic Student-", t[.(nu)])))
# Starting point: the last row of the (z, ig) draws matrix, mapped to
# BridgeStan's unconstrained space.
zig <- fit$draws(c("z", "ig"), format = "draws_matrix")
M <- dim(zig)[1]
q <- sm$param_unconstrain(c(zig[M, ]))
# Diagonal of the adapted inverse metric, averaged over chains.
# do.call(cbind, ...) always yields a (dimension x chains) matrix; sapply()
# would collapse to a plain vector for a one-dimensional metric and break
# rowMeans().
metric <- rowMeans(do.call(cbind, lapply(fit$inv_metric(), diag)))
# Momentum p ~ N(0, diag(1 / metric)): the mass matrix is the inverse of the
# inverse metric. Sized from the metric itself instead of a hard-coded 2 * D,
# so a different parameter count cannot silently recycle values.
p <- rnorm(length(metric)) * sqrt(1 / metric)
# Smallest adapted step size across all chains.
stepsize <- min(fit$sampler_diagnostics(format = "df")$stepsize__)
## leapfrog
# Integrate S - 1 leapfrog steps from (q, p). After each step, record the
# last D values returned by param_constrain(include_tp = TRUE) (presumably
# the D-dimensional transformed parameter — confirm against mvt.stan).
S <- 50
Ts <- matrix(0, nrow = S, ncol = D)
Ts[1, ] <- tail(sm$param_constrain(q, include_tp = TRUE), D)
# The gradient at the end of one step is the gradient at the start of the
# next, so it is evaluated once per step rather than twice.
gr <- sm$log_density_gradient(q)$gradient
# seq_len(S)[-1] is empty when S == 1, whereas 2:S would iterate c(2, 1).
for (s in seq_len(S)[-1]) {
  p <- p + gr * stepsize / 2        # half step in momentum
  q <- q + stepsize * p * metric    # full step in position (M^{-1} = metric)
  gr <- sm$log_density_gradient(q)$gradient
  p <- p + gr * stepsize / 2        # second half step in momentum
  Ts[s, ] <- tail(sm$param_constrain(q, include_tp = TRUE), D)
}
xy <- project_xy(Ts, sigma)
x1L <- xy$x
x2L <- xy$y
points(x1L, x2L, col = "blue", pch = 20)
lines(x1L, x2L, col = "blue")
## Metropolis
# Random-walk Metropolis from the same starting draw, for comparison with the
# leapfrog trajectory. Proposals are N(q, diag(metric) / d), with d the
# number of unconstrained parameters. A diagonal Gaussian is drawn directly,
# so q stays a plain numeric vector and no dense d x d covariance is built on
# each step.
Ts2 <- matrix(0, nrow = S, ncol = D)
q <- sm$param_unconstrain(c(zig[M, ]))
Ts2[1, ] <- tail(sm$param_constrain(q, include_tp = TRUE), D)
n_unc <- length(q)
prop_sd <- sqrt(metric / n_unc)
# Cache the current log density; it only changes when a proposal is accepted.
lp <- sm$log_density(q)
# seq_len(S)[-1] is empty when S == 1, whereas 2:S would iterate c(2, 1).
for (s in seq_len(S)[-1]) {
  pr <- q + rnorm(n_unc) * prop_sd
  lp_pr <- sm$log_density(pr)
  # isTRUE() treats a NaN log density as a rejection instead of erroring.
  if (isTRUE(log(runif(1)) < lp_pr - lp)) {
    q <- pr
    lp <- lp_pr
  }
  Ts2[s, ] <- tail(sm$param_constrain(q, include_tp = TRUE), D)
}
xy <- project_xy(Ts2, sigma)
x1M <- xy$x
x2M <- xy$y
points(x1M, x2M, col = "red", pch = 20)
lines(x1M, x2M, col = "red")
legend("bottomleft",
       c("Metropolis", "Leapfrog"),
       col = c("red", "blue"), pch = 20)
dev.off()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment