Created
January 25, 2016 12:16
-
-
Save rubenarslan/80ecc20be0e0600b41dc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' easier marginal effect plots from brms objects | |
#' ## ideas? | |
#' visualise uncertainty with violin plots instead of pointranges | |
#' (would mean getting rid of early-on summary) | |
#' ### shorthand for finding mode
#' from https://stackoverflow.com/questions/2547402/standard-library-function-in-r-for-finding-the-mode
#' Returns the most frequent value of `x`; ties go to the value that occurs
#' first in `x`. `NA` is counted like any other value.
Mode <- function(x) {
  distinct_values <- unique(x)
  # position of every element of x within the distinct values
  positions <- match(x, distinct_values)
  # occurrences per distinct value, in order of first appearance
  counts <- tabulate(positions)
  distinct_values[which.max(counts)]
}
#' ### Make newdata for marginal effects plots | |
# NOTE(review): top-level default predictor name; the functions visible in this
# file all use their own local `predictor`, so this looks like a leftover from
# interactive use -- confirm before removing.
predictor <- "fertile"
# Build the `newdata` grid for a marginal-effects plot.
#
# data:      data frame holding the predictor(s) and all covariates.
# predictor: character vector of column name(s) to vary (more than one for an
#            interaction).
# ...:       currently unused.
#
# Returns a data frame with one row per observed, complete combination of the
# predictor values, sorted by the predictor column(s), with row names 1..n.
# Every other column of `data` is held constant: mean for numeric columns,
# FALSE for logicals, the first (reference) level for factors and the mode for
# everything else.
make_newdata <- function(data, predictor, ...) {
  # maybe do interpolation by default, if plots wouldn't look smooth without?
  newdata <- na.omit(unique(data[, predictor, drop = FALSE]))
  n_rows <- nrow(newdata)
  covariates <- setdiff(names(data), names(newdata))
  for (varname in covariates) {
    # `[[` extracts the column itself; `data[varname]` would be a one-column
    # data frame, for which is.factor() is never TRUE.
    column <- data[[varname]]
    if (is.numeric(column)) {
      ## choose mean for numeric covariates
      fill <- mean(column, na.rm = TRUE)
    } else if (is.logical(column)) {
      ## choose FALSE (reference) for bools
      fill <- FALSE
    } else if (is.factor(column)) {
      ## choose reference for factors; built from the levels so it also works
      ## when the reference level is unobserved or the column contains NA
      fill <- factor(
        levels(column)[1],
        levels = levels(column),
        ordered = is.ordered(column)
      )
    } else {
      ## choose the Mode for everything else
      fill <- Mode(column)
    }
    newdata[[varname]] <- rep(fill, length.out = n_rows)
  }
  # Sort by each predictor column in turn; calling order() on a multi-column
  # data frame directly does not return row indices.
  sort_keys <- unname(as.list(newdata[, predictor, drop = FALSE]))
  newdata <- newdata[do.call(order, sort_keys), , drop = FALSE]
  rownames(newdata) <- seq_len(nrow(newdata))
  newdata
}
#' ### derive a newdata object for each predictor and each interaction with marginal effects
#' `fit` is a brms fit (its `$data`, `$formula` and `$family` are read),
#' `rug = TRUE` attaches the raw predictor data to each effect as attribute "rug",
#' `predictors` optionally restricts the effects to these term labels (e.g. "a" or "a:b"),
#' `re_formula` is passed on to `fitted()`; `...` is currently unused.
#' Returns a named list of data frames (names like "a" or "a:b"), each holding the
#' newdata grid plus the fitted summary, with attributes "outcome" and "predictors".
allEffects.brmsfit <- function(fit, rug = FALSE, predictors = NULL, re_formula = NA, ...) {
  data <- fit$data
  # names that are dropped from the formula variables and from each term below
  # NOTE(review): presumably brms-internal variables -- confirm
  reserved <- c("trait", "spec", "main")

  # here I'm doing some ugly stuff to get from the model call to the fixed effects that should be in the marginal effects/new data object. could probably be cleaner with some understanding of brms internal methods for this
  # does not yet work for ranefs that shouldn't be marginalised over
  fixef_vars <- delete.response(terms(lme4::nobars(fit$formula)))
  take <- setdiff(all.vars(fixef_vars), reserved)
  # drop = FALSE keeps a data frame even when there is a single variable
  predictor_data <- data[, take, drop = FALSE]
  available_predictors <- attr(fixef_vars, "term.labels")

  if (!is.null(predictors)) {
    predictors <- as.character(intersect(predictors, available_predictors))
  } else {
    predictors <- available_predictors
  }
  # split interaction terms ("a:b") into their component variables
  predictors <- strsplit(predictors, ":", fixed = TRUE)

  effect_list <- list()
  for (i in seq_along(predictors)) {
    predictor <- predictors[[i]]
    predictor <- predictor[!predictor %in% reserved]
    if (length(predictor) == 0) {
      next
    }
    newd <- make_newdata(predictor_data, predictor)
    if (brms:::is.ordinal(fit$family)) { ## using an internal method from brms..
      # I want a different summary here (not probabilities by level)
      # NOTE(review): the code below treats `prediction` as a 3-d array,
      # presumably draws x newdata rows x response categories -- confirm
      prediction <- fitted(fit, newdata = newd, re_formula = re_formula, summary = FALSE)
      # take the ordinal levels from the response var
      # NOTE(review): assumes the response is the first column of fit$data and numeric
      multiply_by_level <- rep(
        sort(unique(fit$data[, 1])),
        each = dim(prediction)[1] * dim(prediction)[2]
      )
      dim(multiply_by_level) <- dim(prediction)
      # multiply each category by its level
      prediction <- prediction * multiply_by_level
      ## add weighted levels up
      response <- apply(prediction, MARGIN = c(1:2), FUN = sum)
      effect_data <- newd
      ## now get summary
      effect_data$Estimate <- apply(response, MARGIN = 2, FUN = mean)
      intervals <- apply(response, MARGIN = 2, FUN = function(x) {
        quantile(x, probs = c(0.025, 0.975), na.rm = TRUE)
      })
      effect_data[, "2.5%ile"] <- unlist(intervals[1, ])
      effect_data[, "97.5%ile"] <- unlist(intervals[2, ])
    } else {
      ## simpler for other families I've worked with so far
      prediction <- fitted(fit, newdata = newd, re_formula = re_formula, summary = TRUE, scale = "response")
      effect_data <- cbind(newd, prediction)
    }
    # plain call instead of `%>%`, which is only available if magrittr/dplyr
    # happens to be attached
    attr(effect_data, "outcome") <- as.character(fit$formula[2])
    attr(effect_data, "predictors") <- predictor
    if (rug) {
      attr(effect_data, "rug") <- data[, predictor, drop = FALSE]
    }
    effect_list[[paste0(predictor, collapse = ":")]] <- effect_data
  }
  effect_list
}
#' plot the object generated above
#' `effect_list` is a list of effect data frames, or a single such data frame.
#' Each needs the columns "Estimate", "2.5%ile" and "97.5%ile" and the attributes
#' "outcome" and "predictors" (one or two column names); an optional "rug"
#' attribute holds the raw predictor data for a rug.
#' With one predictor it goes on the x axis; with two, the second goes on the
#' x axis and the first is mapped to colour and fill. Numeric x gives a line
#' with ribbon, anything else gives point ranges.
#' `printPlot = TRUE` prints each plot; the list of plots is returned invisibly.
plot.efflist.brmsfit <- function(effect_list, printPlot = TRUE) {
  library(ggplot2)
  if (is.data.frame(effect_list)) {
    # wrap a single effect in a list named after the expression passed in
    name <- paste(deparse(substitute(effect_list)), collapse = " ")
    effect_list <- setNames(list(effect_list), name)
  }
  plot_list <- list()
  for (i in seq_along(effect_list)) {
    effect_data <- effect_list[[i]]
    outcome <- attr(effect_data, "outcome", exact = TRUE)
    predictor <- attr(effect_data, "predictors", exact = TRUE)
    # the rug data is stored on each effect's data frame, not on the list
    rug_data <- attr(effect_data, "rug", exact = TRUE)

    n_predictors <- length(predictor)
    if (n_predictors < 1 || n_predictors > 2) {
      warning(
        "Skipping effect ", i, ": can only plot one or two predictors, got ",
        n_predictors, ".",
        call. = FALSE
      )
      next
    }

    # the last predictor goes on the x axis
    x_var <- predictor[n_predictors]
    if (n_predictors == 1) {
      mapping <- aes_string(
        x = x_var, y = "Estimate",
        ymin = "`2.5%ile`", ymax = "`97.5%ile`"
      )
    } else {
      mapping <- aes_string(
        x = x_var, colour = predictor[1], fill = predictor[1], y = "Estimate",
        ymin = "`2.5%ile`", ymax = "`97.5%ile`"
      )
    }
    effect_plot <- ggplot(data = effect_data, mapping) + ylab(outcome)

    if (is.numeric(effect_data[[x_var]])) {
      effect_plot <- effect_plot + geom_smooth(stat = "identity")
      if (!is.null(rug_data)) {
        effect_plot <- effect_plot +
          geom_rug(aes_string(x = x_var), sides = "b", data = rug_data, inherit.aes = FALSE)
      }
    } else {
      effect_plot <- effect_plot + geom_pointrange(stat = "identity")
    }

    plot_list[[i]] <- effect_plot
    if (printPlot) {
      print(effect_plot)
    }
  }
  invisible(plot_list)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment