Predictive margins for categorical models
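One step in the function below is a speed-up. It does not predict for every row. Instead, it collapses the data to its unique covariate combinations, stores how often each combination occurs in a weight column `w`, and takes a `w`-weighted average of the predictions. The minimal base-R sketch here shows why this is safe. It uses a made-up data frame and a hypothetical prediction function `pred()`, which stands in for a model's fitted probability of one outcome category. The weighted average over the collapsed rows matches the plain average over all rows.

```r
# Hypothetical toy data: two covariates with many repeated combinations.
set.seed(1)
raw <- data.frame(
  x = sample(c("a", "b", "c"), 200, replace = TRUE),
  z = sample(0:1, 200, replace = TRUE)
)

# Hypothetical stand-in for a model's fitted probability of one category.
pred <- function(x, z) plogis(-0.5 + 0.8 * (x == "b") + 1.2 * (x == "c") + 0.3 * z)

# Plain average of the predictions over every row of the data.
ame_raw <- mean(pred(raw$x, raw$z))

# Collapse to unique (x, z) combinations with a frequency weight w,
# the same idea as group_by_all() %>% count(name = "w").
comp <- aggregate(list(w = rep(1, nrow(raw))), by = raw, FUN = sum)

# Frequency-weighted average of the predictions over the collapsed data.
ame_comp <- sum(pred(comp$x, comp$z) * comp$w) / sum(comp$w)

c(raw = ame_raw, collapsed = ame_comp)
nrow(comp)  # at most 6 rows instead of 200
```

The saving grows with the data. The number of predictions depends on the number of distinct covariate combinations, not the number of respondents.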
```r
# Packages the function relies on: dplyr and tidyr for data wrangling,
# tidybayes for add_fitted_draws(). The model is a categorical brms fit.
library(dplyr)
library(tidyr)
library(tidybayes)

bayes_margin.factor_mn <- function(model, variable = NULL, data = NULL, draws = NULL, n = NULL, re_formula = NA){

  # Check that everything is running properly and that the
  # user has provided all of the relevant information.
  if(missing(model) || is.null(model)){
    stop("Please provide a model to the function using the 'model =' argument (e.g. model = m1)")
  } else if(is.null(variable)){
    stop("Please provide a variable name to compute average marginal effects for using the 'variable =' argument (e.g. variable = 'x')")
  }

  # If the user hasn't provided their own data to use
  # we use the data included in the brms object itself.
  if(is.null(data)){
    d <- model$data
  } else {
    d <- data
  }

  # If the user has specified a non-null value for n
  # we use this as the number of random samples the
  # user wants us to take from their data.
  if(!is.null(n)){
    d <- d %>% slice_sample(n = n)
  }

  # Next, we get the name of the outcome variable from
  # the brms object...
  resp <- model$formula$resp

  # ... then we omit it from the data (if it is present).
  d <- d %>% select(-any_of(resp))

  # Next, if the user has left the re_formula argument
  # as NA, we drop the random-effect grouping variables
  # from the data if the model has any. Otherwise, we keep them.
  if(identical(re_formula, NA) && NROW(model$ranef) > 0){

    # Get random effects
    rnfx <- unique(model$ranef$group)

    # Omit from data
    d <-
      d %>%
      select(-any_of(rnfx))
  }

  # Calculate observed combinations and frequencies
  # to reduce computation time where n is large
  d <-
    d %>%
    group_by_all() %>%
    count(name = "w") %>%
    ungroup()

  # Now we get all of the levels of the factor that
  # we are using (the reference level and the contrast
  # levels are stored too, but not used further below).
  levs <- levels(as.factor(d[[variable]]))
  base <- levs[1L]
  cont <- levs[-1L]

  # We need somewhere to put all of the fitted draws
  # that we're about to compute so we'll create an
  # empty list.
  f <- list()

  # Then we'll loop over each of the factor levels and
  # compute the fitted draws.
  for (i in seq_along(levs)){

    # First, we fix all cases of our variable of interest
    # to the index factor level.
    d[[variable]] <- levs[i]

    # Second, we compute the fitted draws, weight them by
    # the number of cases, and then summarise the effect
    # within each posterior draw and outcome category.
    # (tidybayes >= 3.0 deprecates add_fitted_draws() in favour
    # of add_epred_draws(), which takes `ndraws` instead of `n`.)
    f[[i]] <-
      d %>%
      add_fitted_draws(
        model = model,
        n = draws,
        re_formula = re_formula,
        value = "eff"
      ) %>%
      group_by(.draw, .category) %>%
      summarise(eff_w = sum(eff * w)/sum(w)) %>%
      ungroup() %>%
      select(eff_w, .category)

    # Finally, we rename the variables in each list
    # element using the level name and "resp".
    names(f[[i]]) <- c(levs[i], "resp")
  }

  # Now we can combine the individual datasets in the list
  # into a single dataset. Rows are sorted by draw and category,
  # so the "resp" column is identical in every element. We keep
  # it once and add one column of AMEs per factor level.
  out <- bind_cols(f[[1]]["resp"], lapply(f, function(x) x[1]))

  # Then we convert them to long format to make
  # them a little easier to deal with.
  out <-
    out %>%
    pivot_longer(
      cols = all_of(levs),
      names_to = "fct",
      values_to = "ame"
    ) %>%
    mutate(
      fct = fct %>% factor(levels = levs)
    )

  # Finally, we return the AMEs to the user
  return(out)
}
```
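This is a base-R sketch of the same procedure for a made-up three-category multinomial logit, and it needs no fitted model. It fixes the factor at each level for every row, averages the predicted category probabilities within each draw using the frequency weights `w`, and stacks the results in long format with `resp`, `fct` and `ame` columns. The coefficient "draws", the level and category names, and the covariate `z` are all hypothetical stand-ins for the fitted draws the function obtains from the brms model.

```r
set.seed(2)

# Hypothetical "posterior": S draws of a 3-category multinomial logit.
# Category A is the reference, so its linear predictor is fixed at 0.
S <- 500
levs <- c("low", "mid", "high")  # levels of the focal factor
cats <- c("A", "B", "C")         # outcome categories

alpha  <- cbind(0, rnorm(S, 0.2, 0.1), rnorm(S, -0.3, 0.1))  # intercepts, S x 3
beta_x <- list(                                             # factor effects, S x 3 each
  low  = matrix(0, S, 3),
  mid  = cbind(0, rnorm(S, 0.5, 0.1), rnorm(S, 0.1, 0.1)),
  high = cbind(0, rnorm(S, 1.0, 0.1), rnorm(S, 0.4, 0.1))
)
beta_z <- cbind(0, rnorm(S, -0.4, 0.1), rnorm(S, 0.3, 0.1))  # slope on covariate z

# Hypothetical observed data, already collapsed to unique z values with weights w.
d <- data.frame(z = c(0, 1, 2), w = c(50, 30, 20))

# Row-wise softmax: turns an S x 3 matrix of linear predictors into probabilities.
softmax <- function(eta) exp(eta) / rowSums(exp(eta))

out <- list()
for (lv in levs) {
  # Set the factor to this level for every row, then take the
  # w-weighted average of the predicted probabilities within each draw.
  eff_w <- matrix(0, S, 3)
  for (r in seq_len(nrow(d))) {
    eta <- alpha + beta_x[[lv]] + beta_z * d$z[r]
    eff_w <- eff_w + softmax(eta) * d$w[r]
  }
  eff_w <- eff_w / sum(d$w)

  # Long format: one row per draw x outcome category for this level.
  out[[lv]] <- data.frame(
    .draw = rep(seq_len(S), times = 3),
    resp  = rep(cats, each = S),
    fct   = factor(lv, levels = levs),
    ame   = as.vector(eff_w)
  )
}
out <- do.call(rbind, out)

# Posterior mean AME for each factor level (rows) and outcome category (columns).
ame_means <- tapply(out$ame, list(out$fct, out$resp), mean)
ame_means
```

Each row of `ame_means` sums to one. Within every draw, the weighted average probabilities across outcome categories still sum to one.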