# Lasso for logistic regression
# GitHub gist oliviergimenez/cb281d3ba593c9b55d0c4c9343ac6a72
# (created October 25, 2023 13:57).
# GitHub notice: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
# Implements Lasso for logistic regression, both the classical (glmnet) and
# the Bayesian (JAGS) ways.

## 1. SIMULATION

# For reproducibility
set.seed(666)

# Sample size
n <- 100

# Explanatory variables: independent standard-normal draws
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)

# True regression parameters on the logit scale: x1 and x3 have nonzero
# effects, x2 has no effect at all (its coefficient is exactly 0)
beta <- c(0.5, 1, 0, 1) # beta0 (intercept), beta1, beta2, beta3

# Probability of success through the inverse logit
logit_prob <- beta[1] + beta[2] * x1 + beta[3] * x2 + beta[4] * x3
prob <- plogis(logit_prob)

# Binary (0/1) response
y <- rbinom(n, 1, prob)
## 2. MODEL FITTING

# Classical framework, no lasso: maximum-likelihood logistic regression
summary(glm(y ~ x1 + x2 + x3, family = "binomial"))

# Classical framework, lasso: penalised logistic regression, with the penalty
# lambda chosen by cross-validation
library(glmnet)
x.mat <- cbind(x1, x2, x3)
# type.measure = "class" selects lambda by misclassification error
cv.fit <- cv.glmnet(x.mat, y, family = "binomial", type.measure = "class")
plot(cv.fit)
cv.fit$lambda.min
coef(cv.fit, s = "lambda.min")

# Full regularisation path (uncomment to inspect). family = "binomial" is
# required here too; glmnet() otherwise defaults to a Gaussian fit.
# fit <- glmnet(x.mat, y, family = "binomial")
# plot(fit)
# Bayesian framework, no lasso: logistic regression with vague normal priors
library(R2jags)

# Model written in the BUGS/JAGS language. R2jags turns this function into a
# JAGS model file, so `1:n` below is JAGS syntax, not R.
model <- function() {
  for (i in 1:n) {
    logit(p[i]) <- beta[1] + beta[2] * X[i, 1] + beta[3] * X[i, 2] + beta[4] * X[i, 3]
    y[i] ~ dbern(p[i])
  }
  for (j in 1:4) {
    # JAGS parameterises dnorm by precision = 1 / variance,
    # so precision 0.1 means sd = 1 / sqrt(0.1), about 3.16
    beta[j] ~ dnorm(0, 0.1)
  }
}

# Data: response, design matrix (one column per covariate), sample size
data <- list(y = y, X = cbind(x1, x2, x3), n = n)

# Initial values for the 2 chains: standard-normal draws for beta
init <- list(list(beta = rnorm(4)),
             list(beta = rnorm(4)))

# Fit the model: 6000 iterations per chain, the first 1000 discarded as
# burn-in, no thinning
out <- jags(data = data,
            inits = init,
            parameters.to.save = c("beta"),
            model.file = model,
            n.chains = 2,
            n.iter = 6000,
            n.burnin = 1000,
            n.thin = 1)

# Posterior summaries
print(out)

# Check convergence visually
traceplot(out, ask = FALSE)

# Posterior density plots
library(lattice)
jagsfit.mcmc <- as.mcmc(out)
densityplot(jagsfit.mcmc)
# Bayesian framework, lasso
library(R2jags)

# Model (https://stats.stackexchange.com/questions/28609/regularized-bayesian-logistic-regression-in-jags)
# Written in the BUGS/JAGS language; `1:n` and `2:4` below are JAGS syntax.
model <- function() {
  for (i in 1:n) {
    logit(p[i]) <- beta[1] + beta[2] * X[i, 1] + beta[3] * X[i, 2] + beta[4] * X[i, 3]
    y[i] ~ dbern(p[i])
  }
  # Intercept is not penalised: vague normal prior
  # (JAGS uses precision = 1 / variance)
  beta[1] ~ dnorm(0, 0.1)
  # L1 regularisation == a Laplace (double exponential) prior on the slopes.
  # In JAGS, ddexp(mu, tau) has density tau / 2 * exp(-tau * |x - mu|),
  # so a larger lambda shrinks the slopes more strongly toward 0.
  for (j in 2:4) {
    beta[j] ~ ddexp(0, lambda)
  }
  # The penalty itself is estimated, with a uniform hyperprior
  lambda ~ dunif(0.001, 10)
}

# Data: response, design matrix (one column per covariate), sample size
data <- list(y = y, X = cbind(x1, x2, x3), n = n)

# Initial values for the 2 chains (lambda is left for JAGS to initialise)
init <- list(list(beta = rnorm(4)),
             list(beta = rnorm(4)))

# Fit the model, monitoring the coefficients and the penalty
out <- jags(data = data,
            inits = init,
            parameters.to.save = c("beta", "lambda"),
            model.file = model,
            n.chains = 2,
            n.iter = 6000,
            n.burnin = 1000,
            n.thin = 1)

# Posterior summaries
print(out)

# Check convergence visually
traceplot(out, ask = FALSE)

# Posterior density plots
library(lattice)
jagsfit.mcmc <- as.mcmc(out)
densityplot(jagsfit.mcmc)

# For the Bayesian lasso, see
# Park & Casella, "The Bayesian Lasso":
#   https://people.eecs.berkeley.edu/~jordan/courses/260-spring09/other-readings/park-casella.pdf
# Lykou & Ntzoufras (2011), WinBUGS tutorial (WIREs):
#   http://www2.stat-athens.aueb.gr/~jbn/papers2/23b_Lykou_Ntzoufras_2011_Wires_WinBUGS_final.pdf
# (GitHub page footer) Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.