Skip to content

Instantly share code, notes, and snippets.

@gongcastro
Created November 10, 2022 11:53
Show Gist options
  • Save gongcastro/0450bb0058282ad23229997b1def4aad to your computer and use it in GitHub Desktop.
Save gongcastro/0450bb0058282ad23229997b1def4aad to your computer and use it in GitHub Desktop.
ROC curves for multinomial and binomial Bayesian models in brms
library(dplyr) # for data wrangling
library(tidyr) # same
library(purrr) # for functional programming
library(rlang) # for tidyeval
library(ggplot2) # for dataviz
library(ggsci) # for nice colours
library(scales) # for displaying percentages
library(brms) # for Bayesian models
library(tidybayes) # for extracting posterior draws and predictions
library(yardstick) # for generating ROC curves
# you might need to install cmdstanr (and CmdStan) too: it is set as the brms backend below
# set options ------------------------------------------------------------------
set.seed(888) # fix the RNG state so the script is reproducible
# compile and sample Stan models through cmdstanr (faster compilation and sampling)
options(brms.backend = "cmdstanr")
# run up to 4 chains in parallel
options(mc.cores = 4)
# use the ggdist theme as the default for every ggplot below
theme_set(theme_ggdist())
# create functions -------------------------------------------------------------
# generate mean posterior predictions and ROC values
#' Compute one ROC curve per posterior draw of a brms model
#'
#' Expected posterior predictions (`add_epred_draws()`) are compared with the
#' observed response in `newdata`, separately for every posterior draw.
#'
#' @param newdata Data frame with the predictors and the observed response
#'   used by `object`.
#' @param object A `brmsfit` whose family is one of "bernoulli", "binomial",
#'   "categorical", "cumulative", "sratio", "cratio" or "acat".
#' @param ... Further arguments passed to `add_epred_draws()`, e.g. `ndraws`.
#' @return The output of `yardstick::roc_curve()` for every draw, stacked, with
#'   a character `.draw` column. For multi-category families it also has a
#'   `.level` column (one-vs-all curve per category).
get_roc_curve <- function(newdata, object, ...) {
  # validate the model before reading any of its fields
  stopifnot(is.brmsfit(object))
  supported <- c(
    "bernoulli", "binomial", "categorical", "cumulative", "sratio", "cratio", "acat"
  )
  model_fam <- object[["family"]][["family"]]
  if (!(model_fam %in% supported)) {
    stop(
      "model family must be one of: ", paste(supported, collapse = ", "),
      call. = FALSE
    )
  }
  # left-hand side of the model formula (e.g. `rating`). It is already an
  # expression, so it is injected with !! directly; enquo() is only meant for
  # defusing function arguments. The factor version goes into `.truth`.
  resp_var <- formula(object)[["formula"]][[2]]
  preds <- add_epred_draws(newdata, object, ...) %>%
    ungroup()
  if (model_fam %in% c("binomial", "bernoulli")) {
    roc_values <- preds %>%
      mutate(.truth = as.factor(!!resp_var)) %>%
      # generate a ROC curve for each posterior draw; the second factor level
      # (e.g. 1 in a 0/1 response) is treated as the event
      split(.$.draw) %>%
      map_dfr(
        ~ roc_curve(., truth = .truth, .epred, event_level = "second"),
        .id = ".draw"
      )
  } else {
    # yardstick requires the probability columns to be given in the same order
    # as the levels of `truth`, so both are built from the same sorted
    # categories (unique() alone returns them in order of appearance)
    cat_levels <- as.character(sort(unique(get_y(object))))
    cat_symbols <- syms(cat_levels)
    roc_values <- preds %>%
      mutate(.truth = factor(!!resp_var, levels = cat_levels)) %>%
      # spread predictions for different categories across different columns
      pivot_wider(names_from = .category, values_from = .epred) %>%
      # generate a (one-vs-all) ROC curve for each posterior draw
      split(.$.draw) %>%
      map_dfr(~ roc_curve(., truth = .truth, !!!cat_symbols), .id = ".draw")
  }
  roc_values
}
# single ROC -------------------------------------------------------------------
# fit cumulative(logit) model
fit <- brm(
  rating ~ treat + period + (1 | subject),
  data = brms::inhaler,
  family = cumulative("logit"),
  chains = 4
)
roc <- get_roc_curve(brms::inhaler, fit, ndraws = 50)
roc_single_plot <- ggplot(roc, aes(1 - specificity, sensitivity, colour = .level)) +
  facet_wrap(~.level, labeller = labeller(.level = ~paste("Category", .))) +
  # chance-level reference diagonal
  geom_abline(intercept = 0, slope = 1, linetype = "dotted", colour = "black") +
  # one curve per posterior draw; geom_path keeps roc_curve()'s threshold order
  geom_path(aes(group = .draw), alpha = 0.1) +
  stat_summary(fun = mean, geom = "line", linewidth = 1) + # mean posterior prediction
  scale_colour_d3() +
  scale_x_continuous(labels = percent) +
  scale_y_continuous(labels = percent) +
  coord_equal() +
  labs(x = "1 - Specificity", y = "Sensitivity", colour = "Category") +
  theme_ggdist() +
  theme(legend.position = "none")
# ggsave() does not create missing directories on its own
dir.create("img", showWarnings = FALSE, recursive = TRUE)
ggsave("img/rocs-single.png", plot = roc_single_plot, height = 7, width = 9, dpi = 800)
# multinomial ROC --------------------------------------------------------------
# helper for labelling a brmsfit by its family and link, e.g. "cumulative(logit)"
get_family <- function(x) paste0(x[["family"]][["family"]], "(", x[["family"]][["link"]], ")")
# wrapper for fitting the inhaler model with a given family; `family` is passed
# to brm() by name (default matches brm's own default) instead of relying on
# positional matching through `...`
fit_model <- function(family = gaussian(), ...) {
  brm(rating ~ treat + period + (1 | subject), family = family, chains = 4, ...)
}
# fit multinomial models
multinomial_fits <- list(cumulative("logit"), sratio("logit"), cratio("logit"), categorical("logit")) %>%
  map(fit_model, data = brms::inhaler) %>%
  set_names(map_chr(., get_family))
roc_multinomial <- multinomial_fits %>%
  map(~get_roc_curve(brms::inhaler, ., ndraws = 50)) %>%
  bind_rows(.id = "model")
roc_multinomial_plot <- roc_multinomial %>%
  ggplot(aes(1 - specificity, sensitivity, colour = .level)) +
  facet_grid(model ~ .level, labeller = labeller(.level = ~paste("Category", .))) +
  # chance-level reference diagonal
  geom_abline(intercept = 0, slope = 1, linetype = "dotted", colour = "black") +
  # one curve per model and posterior draw, in roc_curve()'s threshold order
  geom_path(aes(group = interaction(model, .draw)), alpha = 0.1) +
  stat_summary(fun = mean, geom = "line", linewidth = 1) + # mean posterior prediction
  scale_colour_d3() +
  scale_x_continuous(labels = percent) +
  scale_y_continuous(labels = percent) +
  coord_equal() +
  labs(x = "1 - Specificity", y = "Sensitivity", colour = "Category") +
  theme_ggdist() +
  theme(
    axis.text = element_text(size = 7),
    legend.position = "none"
  )
ggsave("rocs-multinomial.png", plot = roc_multinomial_plot, height = 7, width = 9, dpi = 800)
# binomial ROC -----------------------------------------------------------------
# dichotomise the response: 1 if rating == 1, 0 otherwise
inhaler_binary <- mutate(brms::inhaler, rating = as.integer(rating == 1))
# fit bernoulli models with different link functions on the dichotomised rating
binomial_fits <- list(bernoulli("logit"), bernoulli("probit"), bernoulli("cloglog"), bernoulli("cauchit")) %>%
  map(fit_model, data = inhaler_binary) %>%
  set_names(map_chr(., get_family)) # name list elements with their model family
# assign only to roc_binomial so the multinomial results are not overwritten
roc_binomial <- binomial_fits %>%
  map(~get_roc_curve(inhaler_binary, ., ndraws = 50)) %>%
  bind_rows(.id = "model")
roc_binomial_plot <- roc_binomial %>%
  ggplot(aes(1 - specificity, sensitivity, colour = model)) +
  facet_wrap(~model) +
  # chance-level reference diagonal
  geom_abline(intercept = 0, slope = 1, linetype = "dotted", colour = "black") +
  # one curve per model and posterior draw, in roc_curve()'s threshold order
  geom_path(aes(group = interaction(model, .draw)), alpha = 0.1) +
  stat_summary(fun = mean, geom = "line", linewidth = 1) + # mean posterior prediction
  scale_colour_d3() +
  scale_x_continuous(labels = percent) +
  scale_y_continuous(labels = percent) +
  coord_equal() +
  labs(x = "1 - Specificity", y = "Sensitivity", colour = "Model") +
  theme_ggdist() +
  theme(
    axis.text = element_text(size = 9)
  )
ggsave("rocs-binomial.png", plot = roc_binomial_plot, height = 7, width = 9, dpi = 800)
@fusaroli
Copy link

Am I reading it wrong, or are you just focusing on the mean predicted value (epred)? To fully capture the uncertainty in predictions, I usually generate the full `_linpred` draws and then transform them into binary decisions for each sample (which is more cumbersome, but gives the full uncertainty).

@gongcastro
Copy link
Author

Continuing the conversation on Mastodon: 🐘 https://fediscience.org/@gongcastro/109353350340945974.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment