Created
August 29, 2019 11:09
-
-
Save jackobailey/0982c89326c36b12d6fa6d6f182189be to your computer and use it in GitHub Desktop.
BRMS AMEs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load packages | |
library(brms) | |
library(tidybayes) | |
library(tidyverse) | |
library(magrittr) | |
library(reshape2) | |
library(margins) | |
# Function 1: Compute numerical derivatives for continuous variables ----------

#' Average marginal effect (AME) of a continuous predictor in a brms model
#'
#' Approximates the derivative of the fitted (expected) outcome with respect
#' to `variable` with a central finite difference for every observed covariate
#' pattern. It then averages those derivatives within each posterior draw,
#' weighting each pattern by how often it occurs in the data.
#'
#' @param model A fitted brms model; `model$data`, `model$formula$resp` and
#'   `model$ranef$group` are read from it.
#' @param data Optional data frame of covariates. Defaults to `model$data`.
#' @param variable Name (string) of the continuous predictor to differentiate.
#' @param stepsize Precision constant used to size the step: the step is
#'   `max(abs(x), 1) * sqrt(stepsize)` (same rule as margins' `setstep`).
#' @param re_formula Passed to `add_fitted_draws()`. When not `NULL`, the
#'   model's grouping-factor columns are removed from the data first.
#'   NOTE(review): dropping them only makes sense when `re_formula` excludes
#'   all group-level terms (e.g. `NA`) -- confirm before using other values.
#' @return A data.frame with one row per posterior draw and columns `var`
#'   (the variable name), `resp` (also the variable name) and `est` (the AME).
bayes_dydx.default <- function(model, data = NULL, variable, stepsize = 1e-7,
                               re_formula = NULL) {
  d <- if (is.null(data)) model$data else data

  if (!variable %in% names(d)) {
    stop("`", variable, "` is not a column of the data.", call. = FALSE)
  }

  # Drop the outcome. any_of() selects by *name*, so a data column that
  # happens to be called e.g. `resp` can't be dropped by mistake.
  d <- select(d, -any_of(model$formula$resp))

  # Drop grouping-factor columns so identical fixed-effect patterns collapse
  if (!is.null(re_formula)) {
    d <- select(d, -any_of(unique(model$ranef$group)))
  }

  # Collapse to unique covariate patterns with frequency weights `w`
  # to reduce the number of predictions needed
  d <- d %>%
    group_by_all() %>%
    count(name = "w") %>%
    ungroup()

  # One step size for the whole column, scaled to the data's magnitude;
  # `x + h - x` keeps the step exactly representable at each x
  setstep <- function(x) {
    x + (max(abs(x), 1, na.rm = TRUE) * sqrt(stepsize)) - x
  }

  d0 <- d1 <- d
  d0[[variable]] <- d0[[variable]] - setstep(d0[[variable]])
  d1[[variable]] <- d1[[variable]] + setstep(d1[[variable]])

  v0 <- paste0(variable, "_d0")
  v1 <- paste0(variable, "_d1")
  f0 <- add_fitted_draws(d0, model = model, re_formula = re_formula, value = v0)
  f1 <- add_fitted_draws(d1, model = model, re_formula = re_formula, value = v1)

  # The fitted draws carry the perturbed covariate on every row, so the
  # denominator is aligned row-by-row with the numerator (no recycling).
  me <- (f1[[v1]] - f0[[v0]]) / (f1[[variable]] - f0[[variable]])

  ame <- tibble(.draw = f0$.draw, w = f0$w, me = me) %>%
    group_by(.draw) %>%
    summarise(ame = sum(me * w) / sum(w)) %>%
    ungroup()

  data.frame(
    var = variable,
    resp = variable,
    est = ame$ame,
    stringsAsFactors = FALSE
  )
}
# Function 2: Compute Average Marginal Effect for Factor Variables ----------

#' Average marginal effects (contrasts) of a factor predictor in a brms model
#'
#' For each level of `variable`, fixes that predictor at the level for every
#' observed covariate pattern. It computes the frequency-weighted average
#' fitted value within each posterior draw, then subtracts the average for
#' the first (base) level.
#'
#' @param model A fitted brms model; `model$data`, `model$formula$resp` and
#'   `model$ranef$group` are read from it.
#' @param data Optional data frame of covariates. Defaults to `model$data`.
#' @param variable Name (string) of the factor predictor.
#' @param re_formula Passed to `add_fitted_draws()`. When not `NULL`, the
#'   model's grouping-factor columns are removed from the data first.
#'   NOTE(review): dropping them only makes sense when `re_formula` excludes
#'   all group-level terms (e.g. `NA`) -- confirm before using other values.
#' @return A data.frame with columns `var` (the variable name), `resp` (the
#'   non-base level, as character) and `est` (the contrast for one draw).
#'   Rows are ordered by level, then by draw. The shape is the same whatever
#'   the number of levels.
bayes_dydx.factor <- function(model, data = NULL, variable, re_formula = NULL) {
  d <- if (is.null(data)) model$data else data

  if (!variable %in% names(d)) {
    stop("`", variable, "` is not a column of the data.", call. = FALSE)
  }

  # Drop the outcome by name (any_of() cannot confuse it with a data column)
  d <- select(d, -any_of(model$formula$resp))

  # Drop grouping-factor columns so identical fixed-effect patterns collapse
  if (!is.null(re_formula)) {
    d <- select(d, -any_of(unique(model$ranef$group)))
  }

  # Collapse to unique covariate patterns with frequency weights `w`
  # to reduce computation time where n is large
  d <- d %>%
    group_by_all() %>%
    count(name = "w") %>%
    ungroup()

  levs <- levels(as.factor(d[[variable]]))
  if (length(levs) < 2L) {
    stop("`", variable, "` needs at least two levels to form contrasts.",
         call. = FALSE)
  }

  # Weighted average fitted value per draw (sorted by .draw) with
  # `variable` set to `lev` for every pattern
  avg_fitted <- function(lev) {
    d[[variable]] <- lev
    d %>%
      add_fitted_draws(model = model, re_formula = re_formula, value = "eff") %>%
      group_by(.draw) %>%
      summarise(eff_w = sum(eff * w) / sum(w)) %>%
      ungroup() %>%
      pull(eff_w)
  }

  base <- avg_fitted(levs[1L])
  cont <- levs[-1L]
  contrasts <- lapply(cont, function(lev) avg_fitted(lev) - base)

  data.frame(
    var = variable,
    resp = rep(cont, each = length(base)),
    est = unlist(contrasts),
    stringsAsFactors = FALSE
  )
}
# Demo: compare Bayesian and frequentist AMEs on mtcars ----------

# Treat cylinders as categorical. Use a new name so the built-in
# datasets::mtcars is not masked for the rest of the session.
car_data <-
  mtcars %>%
  mutate(cyl = as.factor(cyl))

# Fit frequentist model
freq <- lm(mpg ~ 1 + cyl + wt, data = car_data)

# Fit Bayesian model; a fixed seed makes the draws (and AMEs) reproducible
bayes <- brm(
  formula = mpg ~ 1 + cyl + wt,
  family = gaussian(),
  data = car_data,
  chains = 2,
  cores = 2,
  seed = 2019
)

# Compute Bayesian AMEs
wt_ame <- bayes_dydx.default(bayes, variable = "wt")
cyl_ame <- bayes_dydx.factor(bayes, variable = "cyl")

# Posterior median and 95% interval, compared with the frequentist AMEs
probs <- c(.5, .025, .975)
quantile(wt_ame$est, probs = probs)
quantile(cyl_ame$est[cyl_ame$resp == "6"], probs = probs)
quantile(cyl_ame$est[cyl_ame$resp == "8"], probs = probs)
summary(margins(freq))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi,
This is an incredible resource. I tried to adapt `bayes_dydx.factor` for an ordinal model. It seems to work great; however, it is very hard to check against the margins package, which does not report the average marginal effect for each outcome category. Do you know how to use `margins` in R to get AMEs for ordinal models at specific outcome levels? I was hoping to compare the two to make sure the output matches.
I tried to recreate something like what is discussed in this Statalist thread: https://www.statalist.org/forums/forum/general-stata-discussion/general/1465336-ordered-probit-marginal-effects, but could not get it to work.
Again, thank you so much for the code above; it is very useful and easy to follow. I have included an example below, based on https://stats.idre.ucla.edu/r/dae/ordinal-logistic-regression/.
```r
library(brms)
library(tidybayes)
library(tidyverse)
library(magrittr)
library(reshape2)
library(margins)
# Example:
```