Skip to content

Instantly share code, notes, and snippets.

@kaneplusplus
Last active January 17, 2016 20:40
Show Gist options
  • Save kaneplusplus/708d5d9ce43ddc4bba2b to your computer and use it in GitHub Desktop.
A minimal glmnet implementation in R
# Soft-thresholding operator S(x, g) = sign(x) * max(|x| - g, 0),
# applied elementwise. This is the proximal operator of the L1 penalty,
# used in the coordinate-descent update for the lasso/elastic net.
#
# Args:
#   x: numeric vector (or anything coercible via as.vector, e.g. a 1x1 matrix
#      produced by a matrix product).
#   g: non-negative threshold (scalar or vector recycled against x).
# Returns: numeric vector, same length as as.vector(x).
soft_thresh <- function(x, g) {
  x <- as.vector(x)
  # Equivalent to the three-case definition (|x| <= g -> 0; x > g -> x - g;
  # x < -g -> x + g) but in one vectorized expression.
  sign(x) * pmax(abs(x) - g, 0)
}
# Reference elastic-net GLM solver (minimal glmnet): IRLS outer loop with a
# coordinate-descent inner loop on the penalized weighted least-squares
# subproblem. Minimizes
#   (1/(2n)) * sum(W * (z - X beta)^2) +
#     lambda * ((1 - alpha)/2 * ||beta||_2^2 + alpha * ||beta||_1)
#
# Args:
#   X:      n x p design matrix.
#   y:      response vector (scale appropriate for `family`).
#   lambda: overall penalty strength (>= 0).
#   alpha:  elastic-net mixing, 1 = lasso, 0 = ridge.
#   family: a GLM family generator (default binomial).
#   maxit:  max iterations for both the outer IRLS and inner CD loops.
#   tol:    convergence tolerance on the change in beta across an outer step.
# Returns: list(beta = p x 1 coefficient matrix, iterations = outer count).
glmnet_ref <- function(X, y, lambda, alpha, family = binomial, maxit = 10,
                       tol = 1e-08) {
  beta <- matrix(rep(0, ncol(X)), ncol = 1)
  for (j in 1:maxit) {
    beta_outer_old <- beta
    # IRLS step: working response z and weights W at the current beta.
    eta <- as.matrix(X %*% beta)
    g <- family()$linkinv(eta)
    gprime <- family()$mu.eta(eta)
    z <- eta + (y - g) / gprime
    W <- as.vector(gprime^2 / family()$variance(g))
    wx_norm <- colSums(W * X^2)
    # BUG FIX: was `quad_loss = Inf`, leaving `quad_loss_old` undefined when
    # first read below ("object 'quad_loss_old' not found" in a fresh session).
    quad_loss_old <- Inf
    # Coordinate descent on the penalized weighted least-squares subproblem.
    for (k in 1:maxit) {
      beta_inner_old <- beta
      for (l in 1:length(beta)) {
        # Partial residual for coordinate l, soft-thresholded by the L1 part.
        beta[l] <- soft_thresh(
          sum(W * X[, l] * (z - X[, -l] %*% beta_inner_old[-l])),
          sum(W) * lambda * alpha)
      }
      # Ridge shrinkage from the L2 part of the penalty.
      beta <- beta / (wx_norm + lambda * (1 - alpha))
      # Penalized objective for the subproblem. BUG FIX: the quadratic term was
      # negated and the L1 term was `alpha * sum(beta)` (no lambda, no abs);
      # the elastic-net objective needs lambda * alpha * ||beta||_1.
      quad_loss <- 1 / 2 / nrow(X) * sum(W * (z - X %*% beta)^2) +
        lambda * ((1 - alpha) * sum(beta^2) / 2 + alpha * sum(abs(beta)))
      # Keep iterating while the objective still decreases (`<`, since we
      # minimize; the original `>` test never continued from a finite start).
      if (quad_loss < quad_loss_old) quad_loss_old <- quad_loss
      else break
    }
    # Outer convergence: Euclidean change in beta over one IRLS step.
    if (sqrt(as.double(crossprod(beta - beta_outer_old))) < tol) break
  }
  list(beta = beta, iterations = j)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment