# Gist by @statwonk, created August 15, 2020:
# https://gist.github.com/statwonk/d3c9a747434bfd6d799e74b416c302b2
# A simulation to learn about variational inference and compare it with MCMC.
library(tidyverse)
library(brms)
library(tidybayes)
# Simulation settings ----
# N Bernoulli observations are assigned round-robin to K groups. Each group
# gets a standard-normal intercept `coef`, added to a shared baseline `p` on
# the log-odds scale (logit(0.1)). This matches the model y ~ (1 | K) fit below.
set.seed(20200815)  # make the simulated data reproducible
N <- 3e4
K <- 40
group_coefs <- rnorm(K)

d <- tibble(K = factor(rep(paste0("group_", seq_len(K)), length.out = N))) %>%
  # Recycle the coefficients exactly like the group labels above, so row i's
  # coef always belongs to row i's group for any N (the old `N/40` hard-coded
  # K and errored when N %% K != 0). `n()` is used instead of `N / K` because
  # inside mutate() a bare `K` refers to the factor column, not the scalar.
  mutate(coef = rep(group_coefs, length.out = n())) %>%
  # rbinom() is vectorised over `prob`, so one call draws all outcomes; it
  # returns an integer 0/1 vector, as the old map2_int() did.
  mutate(p = qlogis(0.1),
         y = rbinom(n(), 1, plogis(p + coef)))
# Fit the hierarchical logistic model with brms' default MCMC sampler
# (4 chains on 4 cores) and print the elapsed time. The assignment inside
# system.time() is evaluated in the calling environment, so `bfit` remains
# available afterwards.
system.time(
  bfit <- brm(
    formula = y ~ (1 | K),
    data = d,
    family = "bernoulli",
    chains = 4,
    cores = 4
  )
)
# MCMC interval coverage ----
# For each group, check whether the (Q2.5, Q97.5) interval of its random
# intercept from the MCMC fit contains the simulated true coefficient. Then
# count how many groups were caught (TRUE) and how many were missed (FALSE).
# Interval columns are selected by name rather than position; the third index
# picks the first (for y ~ (1 | K), the only) group-level coefficient.
mcmc_intervals <- ranef(bfit) %>%
  pluck("K") %>%
  { .[, c("Q2.5", "Q97.5"), 1] }

mcmc_intervals %>%
  as_tibble() %>%
  mutate(group = rownames(mcmc_intervals),
         id = as.integer(gsub("group_", "", group)),
         # Look up the truth by group number instead of relying on the row
         # order produced by a sort.
         coefs = group_coefs[id],
         group_caught = Q2.5 < coefs & coefs < Q97.5) %>%
  count(group_caught)
# Fit the same model with full-rank variational inference and print the
# elapsed time, for comparison with the MCMC fit. As above, `bfit1` is
# assigned in the calling environment.
# NOTE(review): chains/cores are most likely ignored by brms when
# algorithm != "sampling"; confirm against the installed brms version.
system.time(
  bfit1 <- brm(
    formula = y ~ (1 | K),
    data = d,
    family = "bernoulli",
    chains = 4,
    cores = 4,
    algorithm = "fullrank"
  )
)
# Variational-inference interval coverage ----
# For each group, check whether the (Q2.5, Q97.5) interval of its random
# intercept from the full-rank VI fit contains the simulated true coefficient.
# Then count how many groups were caught (TRUE) and how many were missed
# (FALSE). Interval columns are selected by name rather than position; the
# third index picks the first (for y ~ (1 | K), the only) group-level
# coefficient.
vi_intervals <- ranef(bfit1) %>%
  pluck("K") %>%
  { .[, c("Q2.5", "Q97.5"), 1] }

vi_intervals %>%
  as_tibble() %>%
  mutate(group = rownames(vi_intervals),
         id = as.integer(gsub("group_", "", group)),
         # Look up the truth by group number instead of relying on the row
         # order produced by a sort.
         coefs = group_coefs[id],
         group_caught = Q2.5 < coefs & coefs < Q97.5) %>%
  count(group_caught)
# Compare MCMC and variational posteriors ----
# For every group, overlay the ECDFs of the fitted-value draws from the MCMC
# fit and the full-rank VI fit. A vertical line marks the true simulated
# success probability.

# True per-group success probability, derived from the simulated data itself
# so it cannot drift from the simulation settings. fct_reorder() gives K a
# numeric level order (group_1, group_2, ..., group_40) instead of the
# alphabetical default (group_1, group_10, group_11, ...), which is what
# facet_wrap() uses to order the panels.
truth <- d %>%
  distinct(K, p, coef) %>%
  mutate(actuals = plogis(p + coef),
         id = as.integer(gsub("group_", "", K)),
         K = fct_reorder(K, id))

# NOTE(review): add_fitted_draws() is deprecated in tidybayes >= 3.0 in favour
# of add_epred_draws(), which names the draw column `.epred` instead of
# `.value`. Switch both together when upgrading.
draws <- bind_rows(
  d %>%
    distinct(K) %>%
    add_fitted_draws(bfit) %>%
    mutate(algo = "MCMC"),
  d %>%
    distinct(K) %>%
    add_fitted_draws(bfit1) %>%
    mutate(algo = "Variational inference - fullrank")
) %>%
  ungroup() %>%
  # Use the same numeric panel order as `truth`.
  mutate(K = factor(K, levels = levels(truth$K)))

ggplot(draws) +
  stat_ecdf(aes(x = .value, color = factor(algo))) +
  # One reference line per panel, drawn from the small `truth` table rather
  # than repeated once for every posterior draw.
  geom_vline(aes(xintercept = actuals), data = truth) +
  # geom_density() +
  facet_wrap(~ K) +
  scale_color_discrete(name = "algorithm") +
  theme(legend.position = "top") +
  ggtitle("hierarchical logistic regression: mcmc vs. variational inference")
# (End of gist; GitHub page footer omitted.)