Skip to content

Instantly share code, notes, and snippets.

@rubenarslan
Created January 25, 2016 12:16
Show Gist options
  • Save rubenarslan/80ecc20be0e0600b41dc to your computer and use it in GitHub Desktop.
Save rubenarslan/80ecc20be0e0600b41dc to your computer and use it in GitHub Desktop.
#' easier marginal effect plots from brms objects
#' ## ideas?
#' visualise uncertainty with violin plots instead of pointranges
#' (would mean getting rid of early-on summary)
#' ### shorthand for finding mode
#' from https://stackoverflow.com/questions/2547402/standard-library-function-in-r-for-finding-the-mode
#' Return the most frequent value of `x`.
#'
#' When several values share the highest count, the one that occurs first
#' in `x` is returned.
Mode <- function(x) {
  candidates <- unique(x)
  # Index of each element of x within the distinct values, counted per index.
  positions <- match(x, candidates)
  frequency <- tabulate(positions)
  winner <- which.max(frequency)
  candidates[winner]
}
#' ### Make newdata for marginal effects plots
# NOTE(review): global default; the functions in this file receive or assign
# `predictor` locally, so this looks like a leftover from interactive use --
# confirm nothing else relies on it before removing.
predictor <- "fertile"
#' Build a prediction grid for a marginal-effects plot.
#'
#' Keeps the distinct, non-missing observed values (or combinations of
#' values) of `predictor` and holds every other column of `data` at a single
#' fixed value.
#'
#' @param data A data frame holding the predictor(s) and all covariates.
#' @param predictor Character vector naming the column(s) that should vary.
#' @param ... Currently unused.
#' @return A data frame with one row per distinct predictor value (or
#'   combination), sorted by the predictor column(s) in the order given, with
#'   row names reset to `1..n`. The remaining columns are fixed at the mean
#'   (numeric), `FALSE` (logical), the first factor level (factor), or the
#'   most frequent value as found by `Mode()` (anything else).
make_newdata <- function(data, predictor, ...) {
  # maybe do interpolation by default, if plots wouldn't look smooth without?
  newdata <- na.omit(unique(data[, predictor, drop = FALSE]))
  n_rows <- nrow(newdata)

  for (varname in setdiff(names(data), names(newdata))) {
    # `[[` always yields the column itself; `data[, col]` / `data[col]` can
    # yield a data frame, which made the factor check below always FALSE.
    column <- data[[varname]]
    if (is.numeric(column)) {
      ## choose mean for numeric covariates
      fixed_value <- mean(column, na.rm = TRUE)
    } else if (is.logical(column)) {
      ## choose FALSE (reference) for bools
      fixed_value <- FALSE
    } else if (is.factor(column)) {
      ## choose reference for factors; built from the levels so that it also
      ## works when the reference level is not observed in the data
      fixed_value <- factor(
        levels(column)[1],
        levels = levels(column),
        ordered = is.ordered(column)
      )
    } else {
      ## choose the Mode for everything else
      fixed_value <- Mode(column)
    }
    newdata[[varname]] <- rep(fixed_value, length.out = n_rows)
  }

  # Sort by every predictor column in turn. Passing a multi-column data frame
  # straight to order() does not give a row permutation, so hand the columns
  # over one by one (unnamed, so no column name can match an order() argument).
  sort_keys <- unname(as.list(newdata[predictor]))
  newdata <- newdata[do.call(order, sort_keys), , drop = FALSE]
  rownames(newdata) <- seq_len(nrow(newdata))
  newdata
}
#' ### derive a newdata object for each predictor and each interaction with marginal effects
#' Marginal-effect predictions for the fixed-effect terms of a brms fit.
#'
#' For every term label of the model formula (after random-effect bars are
#' removed), a prediction grid is built with `make_newdata()`, `fitted()` is
#' evaluated on it, and the result is stored in a list.
#'
#' @param fit A fitted model object; `fit$data`, `fit$formula` and
#'   `fit$family` are read.
#' @param rug If `TRUE`, the observed values of the predictor column(s) are
#'   attached to each result as a `rug` attribute.
#' @param predictors Optional character vector of term labels to restrict the
#'   output to; labels not present in the model are dropped. `NULL` uses all.
#' @param re_formula Passed on to `fitted()`.
#' @param ... Currently unused.
#' @return A named list of data frames, one per term (interactions are named
#'   with `:`). Each data frame is the prediction grid plus the prediction
#'   summary and carries the attributes `outcome`, `predictors` and,
#'   optionally, `rug`. For ordinal families the summary columns are
#'   `Estimate`, `2.5%ile` and `97.5%ile`; for other families they are
#'   whatever `fitted(..., summary = TRUE)` returns.
allEffects.brmsfit <- function(fit, rug = FALSE, predictors = NULL, re_formula = NA, ...) {
  data <- fit$data
  # Names that are filtered out of both the variable list and the term parts.
  special_terms <- c("trait", "spec", "main")

  # here I'm doing some ugly stuff to get from the model call to the fixed
  # effects that should be in the marginal effects/new data object. could
  # probably be cleaner with some understanding of brms internal methods
  # does not yet work for ranefs that shouldn't be marginalised over
  fixef_vars <- delete.response(terms(lme4::nobars(fit$formula)))
  take <- setdiff(all.vars(fixef_vars), special_terms)
  # drop = FALSE: with a single covariate this must stay a data frame.
  predictor_data <- data[, take, drop = FALSE]
  available_predictors <- attr(terms(fixef_vars), "term.labels", exact = TRUE)

  if (is.null(predictors)) {
    predictors <- available_predictors
  } else {
    predictors <- intersect(predictors, available_predictors)
  }
  # Split interaction labels ("a:b") into their component variable names.
  predictors <- strsplit(as.character(predictors), ":", fixed = TRUE)

  effect_list <- list()
  # Iterating over the list itself is safe when no term was selected.
  for (predictor in predictors) {
    predictor <- predictor[!predictor %in% special_terms]
    if (length(predictor) == 0) {
      next
    }
    newd <- make_newdata(predictor_data, predictor)

    if (brms:::is.ordinal(fit$family)) { ## using an internal method from brms..
      # I want a different summary here (not probabilities by level)
      prediction <- fitted(fit, newdata = newd, re_formula = re_formula, summary = FALSE)
      dims <- dim(prediction)
      # take the ordinal levels from the response var
      # NOTE(review): assumes the first column of fit$data is the response
      # and that the third dimension of `prediction` indexes its levels --
      # confirm against the brms version in use.
      response_levels <- sort(unique(fit$data[, 1]))
      multiply_by_level <- array(
        rep(response_levels, each = dims[1] * dims[2]),
        dim = dims
      )
      # multiply each category by its level, then add the weighted levels up
      response <- apply(prediction * multiply_by_level, MARGIN = 1:2, FUN = sum)
      effect_data <- newd
      ## now get summary
      effect_data$Estimate <- apply(response, MARGIN = 2, FUN = mean)
      intervals <- apply(response, MARGIN = 2, FUN = function(x) {
        quantile(x, probs = c(0.025, 0.975), na.rm = TRUE)
      })
      effect_data[, "2.5%ile"] <- unlist(intervals[1, ])
      effect_data[, "97.5%ile"] <- unlist(intervals[2, ])
    } else {
      ## simpler for other families I've worked with so far
      prediction <- fitted(
        fit,
        newdata = newd, re_formula = re_formula,
        summary = TRUE, scale = "response"
      )
      effect_data <- cbind(newd, prediction)
    }

    # Plain function call instead of a magrittr pipe, which this file never loads.
    attr(effect_data, "outcome") <- as.character(fit$formula[2])
    attr(effect_data, "predictors") <- predictor
    if (rug) {
      attr(effect_data, "rug") <- data[, predictor, drop = FALSE]
    }
    effect_list[[paste0(predictor, collapse = ":")]] <- effect_data
  }
  effect_list
}
#' plot the object generated above
#' Plot the marginal effects produced by `allEffects.brmsfit()`.
#'
#' @param effect_list A list of effect data frames, or a single such data
#'   frame. Each data frame needs the columns `Estimate`, `2.5%ile` and
#'   `97.5%ile` and the attributes `outcome` and `predictors`; an optional
#'   `rug` attribute holds observed values for a rug on the x axis.
#' @param printPlot If `TRUE`, each plot is printed as it is built.
#' @return Invisibly, a list of ggplot objects at the same positions as the
#'   entries of `effect_list`. With one predictor it is mapped to x; with two,
#'   the second is mapped to x and the first to colour and fill. Numeric x
#'   gives a line with ribbon, anything else gives point ranges. Entries with
#'   any other number of predictors are skipped with a warning.
plot.efflist.brmsfit <- function(effect_list, printPlot = TRUE) {
  # Kept as an attach (not ggplot2::) so the session state after the call is
  # the same as before this rewrite.
  library(ggplot2)

  if (is.data.frame(effect_list)) {
    # A single data frame is wrapped into a one-element list named after the
    # expression it was passed as; collapse so long expressions stay one string.
    name <- paste(deparse(substitute(effect_list)), collapse = "")
    effect_data <- effect_list
    effect_list <- list()
    effect_list[[name]] <- effect_data
  }

  plot_list <- list()
  for (i in seq_along(effect_list)) {
    effect_data <- effect_list[[i]]
    outcome <- attr(effect_data, "outcome", exact = TRUE)
    predictor <- attr(effect_data, "predictors", exact = TRUE)
    # The rug data lives on the individual effect data frame, not on the list.
    rug_data <- attr(effect_data, "rug", exact = TRUE)

    if (!length(predictor) %in% c(1L, 2L)) {
      warning(
        "Skipping entry ", i, ": only one or two predictors can be plotted.",
        call. = FALSE
      )
      next
    }

    # The last predictor goes on the x axis; with two predictors the first
    # one distinguishes the groups by colour and fill.
    x_var <- predictor[length(predictor)]
    mapping_args <- list(
      x = x_var, y = "Estimate",
      ymin = "`2.5%ile`", ymax = "`97.5%ile`"
    )
    if (length(predictor) == 2) {
      mapping_args$colour <- predictor[1]
      mapping_args$fill <- predictor[1]
    }
    current_plot <- ggplot(data = effect_data, do.call(aes_string, mapping_args)) +
      ylab(outcome)

    if (is.numeric(effect_data[[x_var]])) {
      current_plot <- current_plot + geom_smooth(stat = "identity")
      if (!is.null(rug_data)) {
        current_plot <- current_plot +
          geom_rug(aes_string(x = x_var), sides = "b", data = rug_data, inherit.aes = FALSE)
      }
    } else {
      current_plot <- current_plot + geom_pointrange(stat = "identity")
    }

    plot_list[[i]] <- current_plot
    if (printPlot) {
      print(current_plot)
    }
  }
  invisible(plot_list)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment