Skip to content

Instantly share code, notes, and snippets.

@phydev
Last active October 21, 2020 08:44
Show Gist options
  • Save phydev/bef1fdec35ea5ad55b1387d9a1d2d426 to your computer and use it in GitHub Desktop.
Save phydev/bef1fdec35ea5ad55b1387d9a1d2d426 to your computer and use it in GitHub Desktop.
Compute LASSO regression using TensorFlow gradients.
"""
LASSO regression (L1 regularization) with gradient descent
TODO: estimate intercept
phydev.github.io
"""
def predict(X, beta):
    """
    Evaluate the linear model on every row of ``X``.

    Computes the matrix product ``X @ beta`` and then drops every size-1
    dimension with ``tf.squeeze``. With the (N, p) features and (p, 1)
    coefficient column built in the ``__main__`` block, the result is a
    rank-1 tensor of N predictions (a scalar when N == 1).
    """
    scores = X @ beta
    return tf.squeeze(scores)
def lagrangian_lasso(X, y, beta, lambda_):
    """
    Evaluate the Lagrangian (penalized) form of the LASSO loss.

        loss = mean((y - predict(X, beta))**2) / 2 + lambda_ * sum(|beta|)

    Parameters
    ----------
    X : feature matrix passed straight to ``predict``.
    y : targets holding one value per prediction; either a flat vector or a
        column vector is accepted (it is reshaped to the prediction shape).
    beta : coefficient matrix passed to ``predict``.
    lambda_ : weight of the L1 penalty (Lagrange multiplier).

    Returns
    -------
    Scalar tensor: half the mean squared error plus the L1 penalty.
    """
    yhat = predict(X, beta)
    # ``predict`` squeezes its output, so a column-vector y of shape (N, 1)
    # would broadcast against yhat of shape (N,) into an (N, N) matrix and
    # silently yield a wrong loss. Align y with yhat explicitly; a genuine
    # size mismatch now raises instead of broadcasting.
    y = tf.reshape(y, tf.shape(yhat))
    residuals = y - yhat
    squared_error = tf.reduce_mean(tf.square(residuals)) / 2
    l1_penalty = lambda_ * tf.reduce_sum(tf.abs(beta))
    return squared_error + l1_penalty
def compute_gradients(X, Y, beta, lambda_):
    """
    Differentiate the LASSO Lagrangian with respect to ``beta``.

    Parameters
    ----------
    X : feature matrix.
    Y : targets.
    beta : coefficients; a ``tf.Variable`` or a plain tensor.
    lambda_ : weight of the L1 penalty.

    Returns
    -------
    A one-element list holding d(loss)/d(beta), as returned by
    ``tape.gradient(loss, [beta])``.
    """
    with tf.GradientTape() as tape:
        # The tape only records trainable Variables automatically; watching
        # beta explicitly also lets plain tensors (e.g. tf.constant) be
        # differentiated instead of producing [None]. Harmless for Variables.
        tape.watch(beta)
        loss = lagrangian_lasso(X, Y, beta, lambda_)
    return tape.gradient(loss, [beta])
if __name__ == '__main__':
    # Demo: fit a LASSO model by gradient descent on synthetic data, fit
    # sklearn's Lasso on the same data, and plot both fitted lines.
    import numpy as np  # used throughout below but was never imported
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from sklearn.linear_model import Lasso

    N = 100  # number of points
    p = 1  # number of features

    # Synthetic data: feature 0 and the target both follow 1..N, each with
    # independent uniform noise in [-5, 5); every feature column also starts
    # from np.random.rand values in [0, 1).
    X = np.random.rand(N, p)
    X[:, 0] += np.linspace(1, N, N) + 10*(np.random.rand(N)-0.5)
    y = tf.constant(np.linspace(1, N, N) + 10*(np.random.rand(N)-0.5))
    X = tf.constant(X)

    beta = tf.Variable(np.zeros((p, 1)))  # coefficients, shape (p, 1)
    lambda_value = 1.0  # L1 penalty weight, shared with sklearn below
    lambda_ = tf.constant(lambda_value, dtype=np.float64)  # lagrange multiplier
    steps = 500
    learning_rate = .0001
    printout = "Step {step} - loss: {loss:2f}, beta: {beta:2f} \n"

    # minimisation: plain gradient descent on the Lagrangian; steps + 1
    # iterations so that the final step (a multiple of 50) is also reported
    for step in range(0, steps + 1):
        grad = compute_gradients(X, y, beta, lambda_=lambda_)[0]
        beta.assign_sub(tf.multiply(grad, learning_rate))
        if step % 50 == 0:
            loss = lagrangian_lasso(X, y, beta, lambda_=lambda_)
            # format NumPy scalars rather than relying on EagerTensor
            # supporting a float format spec
            print(printout.format(step=step, loss=loss.numpy(),
                                  beta=beta[0, 0].numpy()))

    # Reference fit. Pass alpha as a plain Python float (recent sklearn
    # versions validate it as a real number) and feed NumPy arrays.
    X_np = X.numpy()
    y_np = y.numpy()
    model = Lasso(alpha=lambda_value)
    model.fit(X_np, y_np)

    x0 = X_np[:, 0]
    plt.plot(x0, x0*model.coef_[0] + model.intercept_,
             label='Sklearn LASSO', color='blue', linewidth=4)
    # the scratch model has no intercept (see the TODO in the module docstring)
    plt.plot(x0, x0*beta[0, 0].numpy(),
             label='Scratch LASSO', ls=':', color='red', linewidth=4)
    plt.scatter(x0, y_np, label='Test dataset', color='grey')
    plt.legend()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment