Skip to content

Instantly share code, notes, and snippets.

@ryanholbrook
Created October 3, 2019 08:18
Show Gist options
  • Save ryanholbrook/b5c7d44c0c7642eeee1a3034b48f29d7 to your computer and use it in GitHub Desktop.
Save ryanholbrook/b5c7d44c0c7642eeee1a3034b48f29d7 to your computer and use it in GitHub Desktop.
Convex Dirichlet Aggregation with brms and Parsnip
library(brms)
library(parsnip)
## Fits a brms model whose regression coefficients are constrained to be
## non-negative and to sum to one (a convex combination of the predictors).
##
## The coefficients are declared in Stan as `b = b_raw / sum(b_raw)`, with a
## `gamma(alpha_K, 1)` prior on every `b_raw[k]` and
## `alpha_K = alpha / K^gamma`, where K is the number of predictor terms.
##
## Arguments:
##   formula  model formula; the intercept is removed before fitting.
##   data     data passed to `brm()`; also used to expand `.` in `formula`.
##   family   response family passed to `brm()`.
##   alpha    numerator of the prior shape; a warning is raised if <= 0.
##   gamma    exponent applied to K; a warning is raised if <= 1.
##            Yang (2014) recommends alpha = 1, gamma = 2.
##   verbose  if > 0, the generated Stan code is emitted via `message()`.
##   ...      further arguments forwarded to `brm()`.
##
## Returns the object returned by `brm()`.
convex_regression <- function(formula, data,
                              family = "gaussian",
                              alpha = 1, gamma = 2,
                              verbose = 0,
                              ...) {
  if (gamma <= 1) {
    warning(paste("Parameter gamma should be greater than 1. Given:", gamma))
  }
  if (alpha <= 0) {
    warning(paste("Parameter alpha should be greater than 0. Given:", alpha))
  }

  ## Set up priors.
  ## `length()` of a terms object is the length of the formula call (3 for any
  ## two-sided formula), so count the predictor term labels instead. Passing
  ## `data` lets a `.` on the right-hand side be expanded.
  K <- length(attr(terms(formula, data = data), "term.labels"))
  if (K < 1) {
    stop("`formula` must contain at least one predictor term.", call. = FALSE)
  }
  alpha_K <- alpha / (K^gamma)

  stanvars <-
    stanvar(alpha_K,
      "alpha_K",
      block = "data",
      scode = "  real<lower = 0> alpha_K;  // dirichlet parameter"
    ) +
    stanvar(
      name = "b_raw",
      block = "parameters",
      scode = "  vector<lower = 0>[K] b_raw; "
    ) +
    stanvar(
      name = "b",
      block = "tparameters",
      scode = "  vector[K] b = b_raw / sum(b_raw);"
    )

  ## `check = FALSE` because `b_raw` is a user-declared parameter rather than
  ## one of the parameter classes brms generates itself.
  dirichlet_prior <- prior("target += gamma_lpdf(b_raw | alpha_K, 1)",
    class = "b_raw", check = FALSE
  )

  ## Drop the intercept from the model formula.
  f <- update.formula(formula, . ~ . - 1)

  if (verbose > 0) {
    ## Use the same family as the fit so the printed code matches the model.
    message(
      make_stancode(f,
        data = data,
        family = family,
        prior = dirichlet_prior,
        stanvars = stanvars
      )
    )
  }

  brm(f,
    prior = dirichlet_prior,
    family = family,
    data = data,
    stanvars = stanvars,
    ...
  )
}
## Parsnip Definition
## Declare the new model type, the single mode it supports, the engine that
## fits it, and the package that engine depends on.
set_new_model(model = "convex_reg")
set_model_mode(mode = "regression", model = "convex_reg")
set_model_engine(model = "convex_reg", mode = "regression", eng = "brms")
set_dependency(model = "convex_reg", eng = "brms", pkg = "brms")
## Map the parsnip-level argument `scale` onto the engine argument `alpha`.
set_model_arg(
  model = "convex_reg",
  eng = "brms",
  original = "alpha",
  parsnip = "scale",
  has_submodel = FALSE,
  func = list(fun = "alpha")
)
## Quantitative parameter object describing `alpha`: a double in [1, Inf)
## with default 1.
## NOTE(review): `new_quant_param()` is neither defined nor attached in this
## file; presumably it comes from dials -- confirm it is loaded by the caller.
## NOTE(review): `func = list(fun = "alpha")` above names this object; confirm
## whether parsnip expects a function of that name rather than a value.
alpha <- new_quant_param(
  type = "double",
  range = c(1, Inf),
  inclusive = c(TRUE, FALSE),
  default = 1,
  label = c(scale = "scale")
)
## Map the parsnip-level argument `penalty` onto the engine argument `gamma`.
set_model_arg(
  model = "convex_reg",
  eng = "brms",
  parsnip = "penalty",
  original = "gamma",
  func = list(fun = "gamma"),
  has_submodel = FALSE
)
## Quantitative parameter object describing `gamma`: a double in (1, Inf)
## with default 2. The lower bound is exclusive because the fitting function
## warns whenever `gamma <= 1`, so 1 itself is not a valid value.
## NOTE(review): `new_quant_param()` is neither defined nor attached in this
## file; presumably it comes from dials -- confirm it is loaded by the caller.
gamma <- new_quant_param(
  type = "double",
  range = c(1, Inf),
  inclusive = c(FALSE, FALSE),
  default = 2,
  label = c(penalty = "penalty")
)
## Model specification constructor for convex regression.
##
## Arguments:
##   mode     must be "regression"; anything else is an error.
##   scale    main argument, captured unevaluated as a quosure.
##   penalty  main argument, captured unevaluated as a quosure.
##
## Returns a list with elements `args`, `eng_args`, `mode`, `method` and
## `engine`, classed with `make_classes("convex_reg")`.
convex_reg <- function(mode = "regression", scale = 1, penalty = 2) {
  ## Only the regression mode is accepted.
  if (mode != "regression") {
    stop("`mode` should be 'regression'", call. = FALSE)
  }

  ## Capture the main arguments as quosures so they are not evaluated here.
  scale_quo <- rlang::enquo(scale)
  penalty_quo <- rlang::enquo(penalty)

  ## `eng_args`, `method` and `engine` are placeholder slots left empty here.
  structure(
    list(
      args = list(scale = scale_quo, penalty = penalty_quo),
      eng_args = NULL,
      mode = mode,
      method = NULL,
      engine = NULL
    ),
    class = make_classes("convex_reg")
  )
}
## Register how the model is fitted: through the formula interface, by calling
## `convex_regression()`. `formula` and `data` are protected arguments, and no
## default engine arguments are supplied.
set_fit(
  value = list(
    interface = "formula",
    protect = c("formula", "data"),
    func = c(fun = "convex_regression"),
    defaults = list()
  ),
  mode = "regression",
  eng = "brms",
  model = "convex_reg"
)
## Numeric predictions: call `predict()` on the fitted object with
## `type = "response"` and return the `Estimate` column as a plain vector.
num_info <-
  pred_value_template(
    pre = NULL,
    post = function(results, object) {
      ## Base R extraction, so prediction does not depend on tibble/dplyr,
      ## which this file never attaches.
      estimates <- as.data.frame(results)
      if (!"Estimate" %in% names(estimates)) {
        stop("Prediction results have no `Estimate` column.", call. = FALSE)
      }
      estimates[["Estimate"]]
    },
    func = c(fun = "predict"),
    object = quote(object$fit),
    newdata = quote(new_data),
    type = "response"
  )
set_pred(
  model = "convex_reg",
  eng = "brms",
  mode = "regression",
  type = "numeric",
  value = num_info
)
## Raw predictions: the same `predict()` call as for numeric predictions, but
## with no pre- or post-processing, so the result is passed through unchanged.
raw_info <- pred_value_template(
  func = c(fun = "predict"),
  pre = NULL,
  post = NULL,
  object = quote(object$fit),
  newdata = quote(new_data),
  type = "response"
)
set_pred(
  model = "convex_reg",
  eng = "brms",
  mode = "regression",
  value = raw_info,
  type = "raw"
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment