@syadlowsky
Created March 19, 2019 18:42
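# Weighted logistic regression fit by Newton-Raphson (IRLS), with an optional
# L1 penalty handled by a local quadratic approximation (LQA). Approximately
# maximizes sum(w * loglik) - sum(w) * l1_penalty * sum(abs(beta)), leaving the
# intercept unpenalized. beta is on the standardized scale of x, and
# coefficients shrunk by the penalty end up near, not exactly at, zero.
simple.logistic = function(x, y, w = rep(1, nrow(x)), iters = 30, l1_penalty = 0) {
  x = as.matrix(x)
  d = ncol(x)
  # Center and scale each column; x_c and x_s are returned for later predictions.
  x_c = colMeans(x)
  x = sweep(x, 2, x_c)
  x_s = sqrt(colMeans(x^2))
  x = sweep(x, 2, x_s, "/")
  xa = cbind(1, x)  # design matrix with an intercept column
  # Start from the intercept-only fit: logit of the weighted mean of y.
  beta = rep(0, d)
  beta_0 = qlogis(weighted.mean(y, w = w))
  for (iter_ in seq_len(iters)) {
    pred = plogis(drop(x %*% beta) + beta_0)
    irls_w = w * pred * (1 - pred)
    # LQA of the L1 term: curvature sum(w) * l1_penalty / |beta|, with matching
    # gradient pen * beta. The first step is unpenalized so the approximation is
    # not built around beta = 0, where every coefficient would stay frozen.
    pen = if (iter_ == 1) rep(0, d) else sum(w) * l1_penalty / (abs(beta) + 1e-5)
    H = crossprod(xa, xa * irls_w) + diag(c(0, pen))
    g = crossprod(xa, w * (y - pred)) - c(0, pen * beta)
    # Joint Newton step for the intercept and the slopes.
    step = drop(qr.solve(H, g))
    beta_0 = beta_0 + step[1]
    beta = beta + step[-1]
    if (mean(step^2) < 1e-6) {
      break
    }
  }
  obj = list(beta = beta, beta_0 = beta_0, x_c = x_c, x_s = x_s)
  class(obj) = "simple.logistic"
  return(obj)
}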
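The function returns an object of class "simple.logistic" whose coefficients are on the standardized scale, but no methods for that class are defined. The sketch below is a hypothetical pair of S3 methods (the names predict.simple.logistic and coef.simple.logistic are assumptions introduced here) that reuse the stored x_c and x_s. predict() standardizes new rows before applying beta and beta_0. coef() maps the fit back to the original units using beta / x_s for the slopes and beta_0 - sum(beta * x_c / x_s) for the intercept. The demo object uses made-up numbers, not a real fit.

```r
# Hypothetical S3 methods for objects returned by simple.logistic(); the
# method names are assumptions, not part of the original gist.

# Predict for new rows, standardizing them with the training x_c and x_s.
predict.simple.logistic = function(object, newdata, type = c("link", "response"), ...) {
  type = match.arg(type)
  x = as.matrix(newdata)
  x = sweep(x, 2, object$x_c)       # center with the training column means
  x = sweep(x, 2, object$x_s, "/")  # scale with the training column SDs
  eta = drop(x %*% object$beta) + object$beta_0
  if (type == "response") plogis(eta) else eta
}

# Coefficients on the original (unstandardized) scale of x.
coef.simple.logistic = function(object, ...) {
  b = drop(object$beta) / object$x_s
  c("(Intercept)" = object$beta_0 - sum(b * object$x_c), b)
}

# Demo with hand-picked, hypothetical values (not a fitted model).
demo_fit = structure(
  list(beta = c(1, -1), beta_0 = 0.5, x_c = c(1, 2), x_s = c(2, 4)),
  class = "simple.logistic"
)
new_x = matrix(c(3, 6), nrow = 1)  # standardizes to (1, 1)
predict(demo_fit, new_x)                     # 1 - 1 + 0.5 = 0.5
predict(demo_fit, new_x, type = "response")  # plogis(0.5)
coef(demo_fit)                               # 0.5, 0.5, -0.25
```

With a real fit, e.g. fit = simple.logistic(x, y), the same calls apply: predict(fit, new_x, type = "response") gives probabilities, and coef(fit) gives coefficients comparable to an unpenalized glm(..., family = binomial) when l1_penalty = 0.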