Skip to content

Instantly share code, notes, and snippets.

@aotimme
Last active August 22, 2018 16:23
Show Gist options
  • Save aotimme/74f0fd08313d21cd5f24 to your computer and use it in GitHub Desktop.
Save aotimme/74f0fd08313d21cd5f24 to your computer and use it in GitHub Desktop.
ADMM for LASSO
import numpy as np
def l2(x):
    """Return the Euclidean (L2) norm of ``x``.

    Squares and sums every element (no axis is given), so for a 2-D input
    this is the Frobenius norm.
    """
    return np.sqrt(np.sum(np.square(x)))
def lasso_admm(X, y, lambduh, rho=1.0, tol=1e-10, maxiter=1000):
    """Fit a LASSO regression with the alternating direction method of multipliers.

    Solves

        minimize  0.5 * ||y - X @ beta||_2^2 + lambduh * ||beta||_1

    by splitting it into ``beta - z = 0`` and alternating three steps: a
    ridge-like least-squares update for ``beta``, a soft-thresholding update
    for ``z``, and a dual ascent step for the multiplier ``a``. No intercept
    is fitted and ``X`` is used as given (no centering or scaling).

    Parameters
    ----------
    X : array_like, shape (n, p)
        Design matrix.
    y : array_like, shape (n,) or (n, 1)
        Response vector. A column vector is flattened.
    lambduh : float
        L1 penalty weight (spelled this way because ``lambda`` is a keyword).
    rho : float, optional
        Augmented-Lagrangian penalty parameter. Must be positive.
    tol : float, optional
        Stop once the Euclidean norm of the change in ``beta`` between two
        consecutive iterations falls below this value.
    maxiter : int, optional
        Maximum number of ADMM iterations.

    Returns
    -------
    numpy.ndarray, shape (p,)
        The ``beta`` iterate from the last least-squares step. Its entries
        for zero coefficients are only approximately zero; round or
        threshold them if exact sparsity is needed. All zeros when
        ``maxiter`` is 0.

    Raises
    ------
    ValueError
        If ``rho`` is not positive or ``y`` does not have ``n`` entries.
    """
    X = np.asarray(X, dtype=float)
    n, p = X.shape
    # Accept (n,) as well as (n, 1) responses. A column vector would otherwise
    # broadcast against the (p,) iterates below into a (p, p) matrix.
    y = np.asarray(y, dtype=float).reshape(-1)
    if y.shape[0] != n:
        raise ValueError("y must have %d entries, got %d" % (n, y.shape[0]))
    if rho <= 0:
        raise ValueError("rho must be positive, got %r" % (rho,))

    # The beta-update solves (X^T X + rho I) beta = X^T y + rho z - a. The
    # system matrix never changes, so it is inverted once up front.
    inv = np.linalg.inv(X.T.dot(X) + rho * np.eye(p))
    xy = X.T.dot(y)
    z = np.zeros(p)      # split copy of beta that carries the L1 penalty
    a = np.zeros(p)      # Lagrange multiplier for the constraint beta == z
    beta = np.zeros(p)   # returned unchanged when maxiter == 0
    for _ in range(maxiter):
        beta_old = beta
        beta = inv.dot(xy + rho * z - a)
        if np.linalg.norm(beta - beta_old) < tol:
            break
        # z-update: soft-threshold beta + a/rho at lambduh/rho
        # (the proximal operator of the scaled L1 norm).
        v = beta + a / rho
        z = np.sign(v) * np.maximum(np.abs(v) - lambduh / rho, 0.0)
        # Dual ascent on the multiplier.
        a = a + rho * (beta - z)
    return beta
def main():
    """Demo: recover a sparse coefficient vector from simulated data.

    Draws ``n = 1000`` rows of independent standard-normal features and
    builds ``y`` from a known sparse ``beta`` plus unit-variance Gaussian
    noise. Prints the true coefficients followed by the LASSO estimate
    rounded to 3 decimals.
    """
    np.random.seed(20)  # fixed seed so the printed output is reproducible
    beta = np.array([3.0, -2.0, 0.0, 0.0, 2.0, -1.0, 0.0, 0.0])
    p = len(beta)
    n = 1000
    X = np.random.multivariate_normal(np.zeros(p), np.eye(p), n)
    y = X.dot(beta) + np.random.normal(0, 1, n)
    lambduh = 10.0
    print("beta:")
    print(beta)
    print("beta lasso:")
    beta_lasso = lasso_admm(X, y, lambduh, rho=1.0, tol=1e-10, maxiter=1000)
    print(np.round(beta_lasso, 3))
    # Unpenalized OLS fit, for comparison with the LASSO estimate:
    # print("beta ols:")
    # beta_ols = np.linalg.solve(X.T.dot(X), X.T.dot(y))
    # print(np.round(beta_ols, 3))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment