Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Last active July 31, 2016 00:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brandonwillard/7a30ecbc73ce2bb809c2f360627574e7 to your computer and use it in GitHub Desktop.
Save brandonwillard/7a30ecbc73ce2bb809c2f360627574e7 to your computer and use it in GitHub Desktop.
# Inverse Mean Reparameterization Simulation
#
# Simple inspection of an inverse mean reparameterization for observations
# x ~ N(theta, 1)
# and theta = 1/u.
# Original discussion: https://xianblog.wordpress.com/2016/07/15/the-curious-incident-of-the-inverse-of-the-mean/
#
# Authors: Jyotishka Datta and Brandon T. Willard
#
library(rstan)
rstan_options(auto_write = TRUE)
# Stan program for the Cauchy-prior case.
#   data:        J observations Y.
#   parameters:  u[j], constrained to be positive.
#   transformed: u_inv[j] = 1 / u[j].
#   model:       u[j] ~ cauchy(0, 1) (restricted to u > 0 by the constraint),
#                Y[j] ~ normal(u_inv[j], 1).
# The Stan assignment operator is written `=`; the older `<-` form is
# deprecated by Stan and rejected by recent releases.
# NOTE(review): the `real<lower=0> u[J]` array declarations are still the
# legacy form; newer Stan releases expect `array[J] real<lower=0> u`.
# Left untouched here because that syntax is not accepted by older rstan.
cauchy.code <- "
data {
  int<lower=0> J;
  vector[J] Y;
}
parameters {
  real<lower=0> u[J];
}
transformed parameters {
  real<lower=0> u_inv[J];
  for (j in 1:J){
    u_inv[j] = 1/u[j];
  }
}
model {
  for (j in 1:J){
    u[j] ~ cauchy(0, 1);
    Y[j] ~ normal(u_inv[j], 1);
  }
}
"
# Compile once; the compiled model is reused for every observed value below.
cauchy.fit <- stan_model(model_code = cauchy.code, model_name = "Cauchy")
# Stan program for the gamma-prior case.
#   data:        J observations Y.
#   parameters:  u[j], constrained to be positive.
#   transformed: u_inv[j] = 1 / u[j].
#   model:       u[j] ~ gamma(3, 1./1000),
#                Y[j] ~ normal(u_inv[j], 1).
# The Stan assignment operator is written `=`; the older `<-` form is
# deprecated by Stan and rejected by recent releases.
# NOTE(review): the `real<lower=0> u[J]` array declarations are still the
# legacy form; newer Stan releases expect `array[J] real<lower=0> u`.
# Left untouched here because that syntax is not accepted by older rstan.
gamma.code <- "
data {
  int<lower=0> J;
  vector[J] Y;
}
parameters {
  real<lower=0> u[J];
}
transformed parameters {
  real<lower=0> u_inv[J];
  for (j in 1:J){
    u_inv[j] = 1/u[j];
  }
}
model {
  for (j in 1:J){
    u[j] ~ gamma(3, 1./1000);
    Y[j] ~ normal(u_inv[j], 1);
  }
}
"
# Compile once; the compiled model is reused for every observed value below.
gamma.fit <- stan_model(model_code = gamma.code, model_name = "gamma")
# Simulation settings: the seed handed to Stan, the grid of observed values,
# and the total number of iterations per chain (half of which are warmup).
seed.val <- 495
Ys <- seq(-2, 30, length.out = 20)
stan.iters <- 2000

# Draw from one compiled model for a single observation.
#   model: compiled Stan model (result of stan_model()).
#   data:  list with elements J and Y, as declared in the Stan data block.
#   algo:  sampling algorithm name passed through to rstan.
# Returns the fitted object produced by sampling().
fit_model <- function(model, data, algo, iters = stan.iters, seed = seed.val) {
  sampling(model,
           data = data,
           iter = iters,
           algorithm = algo,
           warmup = floor(iters / 2),
           thin = 2,
           pars = c("u", "u_inv"),
           # init = 0,
           seed = seed,
           chains = 1)
}

# Turn one fit into two plot-ready rows, one for `u` and one for `u_inv`.
#   fit:   fitted object returned by sampling().
#   y:     the observed value the fit was conditioned on (plotted on x).
#   prior: label of the prior used ("cauchy" or "gamma").
#   algo:  label of the sampling algorithm.
# Returns a two-row data.frame with columns x, var, mean, low, high, prior,
# algo.
summarise_fit <- function(fit, y, prior, algo) {
  # Summarise once and reuse. Rows 1 and 2 of the summary matrix are taken to
  # be `u` and `u_inv` (the order of `pars` in the sampling call; this relies
  # on J = 1 so each quantity occupies exactly one row).
  stats <- summary(fit)$summary[1:2, , drop = FALSE]
  # Columns are selected by name instead of by position (4, 6, 8).
  # NOTE(review): the `mean` column is filled with the "50%" quantile, i.e.
  # the posterior median, not the posterior mean -- confirm this is intended.
  data.frame(x = y,
             var = c("u", "u_inv"),
             mean = stats[, "50%"],
             low = stats[, "2.5%"],
             high = stats[, "97.5%"],
             prior = prior,
             algo = algo,
             row.names = NULL)
}

# Collect per-fit summaries in a list and bind them once at the end rather
# than re-copying a growing data frame on every iteration.
fit.rows <- list()
for (y in Ys) {
  # A single observation per fit; as.array() keeps Y a length-1 array so it
  # matches the `vector[J] Y` declaration.
  c.data <- list("J" = 1, "Y" = as.array(y))
  for (algo in c("NUTS")) {
    cauchy.res <- fit_model(cauchy.fit, c.data, algo)
    # rstan::extract(cauchy.res, pars=c("lp__"), permuted=TRUE)[[1]]
    fit.rows[[length(fit.rows) + 1]] <-
      summarise_fit(cauchy.res, y, "cauchy", algo)

    gamma.res <- fit_model(gamma.fit, c.data, algo)
    # rstan::extract(gamma.res, pars=c("u", 'u_inv'), permuted=TRUE)[[1]]
    fit.rows[[length(fit.rows) + 1]] <-
      summarise_fit(gamma.res, y, "gamma", algo)
  }
}
# NULL when no fits were run, otherwise one row per (y, prior, algo, var).
u.means.data <- do.call(rbind, fit.rows)
library(ggplot2)
library(scales)
# From: http://wresch.github.io/2013/03/08/asinh-scales-in-ggplot2.html
# Compute axis breaks suited to an asinh-transformed scale.
#
# Breaks are placed at powers of ten covering the positive values, mirrored
# powers of ten covering the negative values, and 0 when the data reach zero
# or below.
#
#   x: numeric vector of data values or axis limits. Missing and non-finite
#      entries are ignored.
# Returns a sorted numeric vector of breaks (length 0 if `x` has no finite
# values).
asinh_breaks <- function(x) {
  # Powers of ten spanning an ascending, strictly positive range r = c(lo, hi).
  decade_breaks <- function(r) {
    10^seq(round(log10(r[1])), round(log10(r[2])), by = 1)
  }

  # Drop NA/NaN/Inf up front so the comparisons below never see them.
  x <- x[is.finite(x)]
  breaks <- numeric(0)

  positives <- x[x > 0]
  if (length(positives) > 0) {
    breaks <- decade_breaks(range(positives))
  }

  if (length(x) > 0 && min(x) <= 0) {
    breaks <- c(0, breaks)
  }

  negatives <- x[x < 0]
  if (length(negatives) > 0) {
    # Work on magnitudes so the range stays ascending (seq() with by = 1
    # fails on a descending range), then flip the sign back.
    breaks <- c(breaks, -decade_breaks(range(-negatives)))
  }

  sort(breaks)
}
# Construct the "asinh" scale transformation: values are mapped with asinh(),
# mapped back with sinh(), and axis breaks come from asinh_breaks().
asinh_trans <- function() {
  asinh_scale <- trans_new(
    name = "asinh",
    transform = asinh,
    inverse = sinh,
    breaks = asinh_breaks
  )
  asinh_scale
}
# Plot the summarised value of each quantity against the observed value, one
# line and one interval ribbon per prior/algorithm combination, with a
# separate facet (and free y-axis) for `u` and `u_inv`.
u.means.plt <- ggplot(u.means.data,
                      aes(x = x, y = mean,
                          color = interaction(prior, algo))) +
  geom_line() +
  geom_ribbon(aes(ymin = low, ymax = high,
                  fill = interaction(prior, algo)),
              alpha = 0.5) +
  facet_grid(var ~ ., scales = "free_y") +
  # Black reference line y = x; its `var` value places it in the `u_inv`
  # facet only.
  geom_line(data = data.frame(var = "u_inv", x = Ys, mean = Ys,
                              prior = NA, algo = NA),
            aes(x = x, y = mean),
            color = "black", alpha = 0.5) +
  # Display the y-axis on the asinh scale.
  coord_trans(y = asinh_trans())

# Show the figure on the active device, then write it to disk.
print(u.means.plt)
ggsave("u_means_plot.pdf", u.means.plt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment