Skip to content

Instantly share code, notes, and snippets.

@nanxstats
Last active December 23, 2020 02:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nanxstats/4b64f81aa258959bef6ca06572307298 to your computer and use it in GitHub Desktop.
Save nanxstats/4b64f81aa258959bef6ca06572307298 to your computer and use it in GitHub Desktop.
Bayesian lasso with greta, compared to lasso and msaenet
# https://nanx.me/blog/post/bayesian-lasso-with-greta/
# generate synthetic data ------------------------------------------------------
# Simulate n * 2 observations on p correlated predictors, of which the first
# `pnz` carry a coefficient of `coef_size` and the rest are zero. p.train = 0.5
# keeps half for training (x, y); the other half is the test set
# (dat$x.te, dat$y.te).
library("msaenet")
n <- 500
p <- 1000
pnz <- 10
# single source of truth for the non-zero effect size, shared by the
# simulator call and the `beta` truth vector below
coef_size <- 5
dat <- msaenet.sim.gaussian(
  n = n * 2, p = p,
  rho = 0.5, coef = rep(coef_size, pnz), snr = 3,
  p.train = 0.5, seed = 42
)
x <- dat$x.tr
y <- dat$y.tr
# True coefficient vector used to score TP/FP: the non-zero effects occupy the
# first `pnz` positions, matching the `coef` passed to the simulator.
beta <- c(rep(coef_size, pnz), rep(0, p - pnz))
# msaenet ----------------------------------------------------------------------
library("doParallel")
# register a foreach backend so msaenet(parallel = TRUE) can tune in parallel
registerDoParallel(detectCores())
fit_msaenet <- msaenet(
  x, y,
  family = "gaussian",
  init = "ridge", alphas = seq(0.05, 0.95, 0.05),
  tune = "cv", nfolds = 10, rule = "lambda.min",
  nsteps = 20, tune.nsteps = "ebic",
  seed = 42, parallel = TRUE, verbose = FALSE
)
msaenet.nzv(fit_msaenet)
png("greta-msaenet-coef.png", res = 300, height = 1500, width = 2100)
# Save only the settable graphical parameters: restoring the full par() list
# triggers warnings for read-only entries such as "cin".
opar <- par(no.readonly = TRUE)
par(mar = c(2, 4, 1, 2) + 0.1)
# two stacked panels: coefficient paths (taller) above the EBIC criterion
layout(matrix(c(1, 2), 2, 1), heights = c(2, 1))
plot(fit_msaenet, type = "coef")
plot(fit_msaenet, type = "criterion", ylab = "EBIC")
par(opar)
dev.off()
# the truly non-zero coefficients are the first `pnz` predictors
(tp_msaenet <- msaenet.tp(fit_msaenet, seq_len(pnz)))
(fp_msaenet <- msaenet.fp(fit_msaenet, seq_len(pnz)))
pred_msaenet <- predict(fit_msaenet, dat$x.te)
(mse_msaenet <- msaenet.mse(dat$y.te, pred_msaenet))
# lasso ------------------------------------------------------------------------
library("glmnet")
set.seed(42)
cv_lasso <- cv.glmnet(x, y, family = "gaussian", alpha = 1, nfolds = 10)
png("greta-glmnet-cv.png", res = 300, height = 1500, width = 2100)
plot(cv_lasso)
dev.off()
# Read coefficients and predictions at lambda.min straight from the
# cross-validated object. glmnet's documentation advises against refitting
# with a single `lambda` value and recommends coef()/predict() instead.
coef_lasso <- as.vector(coef(cv_lasso, s = "lambda.min"))[-1] # drop intercept
selected_lasso <- abs(coef_lasso) > .Machine$double.eps
# TP: selected among truly non-zero coefficients
# FP: selected among truly zero coefficients
(tp_lasso <- sum(selected_lasso & beta != 0))
(fp_lasso <- sum(selected_lasso & beta == 0))
pred_lasso <- predict(cv_lasso, newx = dat$x.te, s = "lambda.min")
(mse_lasso <- msaenet.mse(dat$y.te, pred_lasso))
# bayesian lasso ---------------------------------------------------------------
library("greta")
# NOTE(review): confirm whether greta's sampler draws from R's RNG; if it does
# not, this seed alone may not make the mcmc() call below reproducible.
set.seed(42)
# define data model
# Priors: normal(0, 10) on the intercept, a half-Cauchy on the residual sd
# (cauchy(0, 3) truncated to [0, Inf)), and an independent Laplace(0, 1) prior
# on each of the ncol(x) coefficients. The Laplace prior is what makes this a
# Bayesian counterpart of the lasso.
intercept <- normal(0, 10)
sd <- cauchy(0, 3, truncation = c(0, Inf))
coefs <- laplace(0, 1, dim = ncol(x))
# linear predictor; y is modelled as normal around it with standard deviation sd
mu <- intercept + x %*% coefs
distribution(y) <- normal(mu, sd)
# get_betahat(), get_ci() and get_intercept() index the draws by position:
# column 1 = intercept, last column = sd, coefficients in between.
# NOTE(review): this assumes greta orders draw columns as listed here.
m <- model(intercept, coefs, sd)
plot(m)
# 8 chains, each with 1000 warmup iterations and 5000 retained samples
draws_blasso <- mcmc(m, warmup = 1000, n_samples = 5000, chains = 8)
# utility functions for posterior ----------------------------------------------
# get beta posterior estimate from a chain
# Posterior mean of each regression coefficient in one MCMC chain.
# `df` is a draws matrix laid out as model(intercept, coefs, sd): column 1 is
# the intercept, the last column is the residual sd, and every column in
# between is a coefficient. NA draws are ignored. Returns an unnamed numeric
# vector with one entry per coefficient column (empty if there are none).
get_betahat <- function(df) {
  # seq_len() prevents 2:(ncol - 1) from counting *down* into the intercept and
  # sd columns when there are no coefficients; drop = FALSE keeps a
  # single-coefficient subset as a matrix so colMeans() still applies.
  coef_cols <- seq_len(max(ncol(df) - 2L, 0L)) + 1L
  unname(colMeans(df[, coef_cols, drop = FALSE], na.rm = TRUE))
}
# get intercept posterior estimate from a chain
# Posterior mean of the intercept, taken from the first column of the draws
# matrix `df`, with NA draws ignored. Returns an unnamed numeric scalar.
get_intercept <- function(df) {
  intercept_draws <- df[, 1]
  unname(mean(intercept_draws, na.rm = TRUE))
}
# get credible interval from a chain
# Posterior quantile(s) `prob` of each regression coefficient in one chain,
# e.g. prob = 0.025 and 0.975 for the bounds of a 95% equal-tailed interval.
# `df` is a draws matrix with the intercept in column 1, the residual sd in the
# last column, and the coefficients in between. NA draws are ignored. For a
# scalar `prob` the result is an unnamed numeric vector with one entry per
# coefficient column (empty if there are none).
get_ci <- function(df, prob) {
  coef_cols <- seq_len(max(ncol(df) - 2L, 0L)) + 1L
  # drop = FALSE: a single coefficient column must stay a matrix for apply()
  ci <- apply(
    df[, coef_cols, drop = FALSE], 2, quantile,
    probs = prob, na.rm = TRUE
  )
  unname(ci)
}
# variable selection - checks whether 0 is contained in the credible interval
# ported from horseshoe::HS.var.select(method = "intervals")
# Returns 1 (selected) where the interval [lower_ci, upper_ci] lies strictly
# on one side of zero, 0 where it covers zero, and NA where a comparison is NA.
threshold <- function(lower_ci, upper_ci) {
  # De Morgan form of "not (lower <= 0 and upper >= 0)"
  excludes_zero <- lower_ci > 0 | upper_ci < 0
  as.numeric(excludes_zero)
}
# get MSE from each chain
# For every chain in `draws`, build a point estimate from that chain's
# posterior: posterior-mean coefficients, zeroed unless their 95% credible
# interval (0.025 / 0.975 quantiles) excludes zero, plus the posterior-mean
# intercept. Then predict from `x` and score against `y` with msaenet.mse().
# Returns a numeric vector with one MSE per chain.
mse_chain <- function(draws, x, y) {
  mse <- numeric(length(draws))
  # seq_along() yields nothing for empty `draws`; 1:k would visit 1 and 0
  for (i in seq_along(draws)) {
    chain <- draws[[i]]
    post_mean <- get_betahat(chain)
    lower_ci <- get_ci(chain, 0.025)
    upper_ci <- get_ci(chain, 0.975)
    # named beta_hat so it is not confused with the global true `beta`
    beta_hat <- threshold(lower_ci, upper_ci) * post_mean
    alpha_hat <- get_intercept(chain)
    pred <- x %*% as.matrix(beta_hat) + alpha_hat
    mse[i] <- msaenet.mse(y, pred)
  }
  mse
}
# select the chain with the minimal MSE on training set
# (x and y are the training split, dat$x.tr and dat$y.tr)
idx_chain <- which.min(mse_chain(draws_blasso, x, y))
chain_blasso <- draws_blasso[[idx_chain]]
# create data frame for plotting: one row per coefficient with its true value,
# posterior mean and 95% credible interval bounds from the selected chain
df_blasso <- data.frame(
  index = seq_len(ncol(x)),
  truth = beta,
  post_mean = get_betahat(chain_blasso),
  lower_ci = get_ci(chain_blasso, 0.025),
  upper_ci = get_ci(chain_blasso, 0.975)
)
# 1 = interval excludes zero (selected), 0 = interval covers zero
df_blasso$selected <- threshold(df_blasso$lower_ci, df_blasso$upper_ci)
library("ggplot2")
library("ggsci")
# Black points: true coefficients. Coloured points and error bars: posterior
# means and credible intervals, coloured by the 0/1 `selected` indicator.
p_blasso <- ggplot(data = df_blasso, aes(x = index, y = truth)) +
  geom_point(size = 2) +
  theme_classic(base_size = 24) +
  ylab("") +
  geom_point(aes(x = index, y = post_mean, col = factor(selected)), size = 2) +
  geom_errorbar(
    aes(ymin = lower_ci, ymax = upper_ci, col = factor(selected)),
    width = 0.1
  ) +
  theme(legend.position = "none") +
  scale_color_aaas() +
  ggtitle("black = truth, red = selected, blue = not selected")
p_blasso
# pass the plot explicitly instead of relying on ggsave()'s last_plot() default
ggsave(
  "greta-bayesian-lasso-coef.png",
  plot = p_blasso, dpi = 300, width = 24, height = 12
)
# Score the interval-based selection already stored in df_blasso$selected.
# TP: selected among truly non-zero coefficients
# FP: selected among truly zero coefficients
(tp_blasso <- sum(df_blasso$selected & beta != 0))
(fp_blasso <- sum(df_blasso$selected & beta == 0))
# thresholded point estimate: posterior mean where selected, 0 otherwise
beta_blasso <- df_blasso$selected * df_blasso$post_mean
pred_blasso <- dat$x.te %*% as.matrix(beta_blasso) +
  get_intercept(chain_blasso)
(mse_blasso <- msaenet.mse(dat$y.te, pred_blasso))
# summary table ----------------------------------------------------------------
# One row per method: true positives, false positives and test-set MSE,
# rendered with knitr and all numbers rounded to whole values.
tbl <- data.frame(
  Method = c("msaenet", "Lasso", "Bayesian Lasso"),
  TP = c(tp_msaenet, tp_lasso, tp_blasso),
  FP = c(fp_msaenet, fp_lasso, fp_blasso),
  MSE = c(mse_msaenet, mse_lasso, mse_blasso)
)
knitr::kable(x = tbl, digits = 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment