Skip to content

Instantly share code, notes, and snippets.

@mathurinm
Created December 17, 2020 14:40
Show Gist options
  • Save mathurinm/587bab99909976a04a908a0aaf13f41b to your computer and use it in GitHub Desktop.
Save mathurinm/587bab99909976a04a908a0aaf13f41b to your computer and use it in GitHub Desktop.
from celer import Lasso
import matplotlib.pyplot as plt
from scipy.optimize import fmin_bfgs
from numpy.linalg import norm
import numpy as np
from celer.datasets import make_correlated_data
import seaborn as sns
# Synthetic regression problem with a fixed seed so every run is identical.
# A is the (40, 50) design matrix and b the observations; x_true is presumably
# the generating coefficient vector (it is not used below). `rho` is presumably
# the inter-column correlation parameter -- confirm against celer's docs.
A, b, x_true = make_correlated_data(
    n_samples=40,
    n_features=50,
    rho=0.5,
    random_state=0,
)

# Colour-blind-friendly palette; entries 0 and 1 colour the two solutions.
c_list = sns.color_palette("colorblind")

# For the objective 0.5 * ||A x - b||^2 + alpha * ||x||_1, x = 0 is optimal
# exactly when alpha >= ||A^T b||_inf, so this is the largest useful alpha.
alpha_max = np.abs(A.T @ b).max()
# A much smaller penalty, so the Lasso solution has some nonzero entries.
alpha = alpha_max / 20
def obj(x):
    """Evaluate the Lasso objective 0.5 * ||A x - b||_2^2 + alpha * ||x||_1.

    Uses the module-level design matrix ``A``, observations ``b`` and
    regularisation strength ``alpha``.

    Parameters
    ----------
    x : ndarray
        Coefficient vector, one entry per column of ``A``.

    Returns
    -------
    float
        Value of the penalised least-squares objective at ``x``.
    """
    residual = A @ x - b
    data_fit = norm(residual) ** 2 / 2.
    penalty = alpha * norm(x, ord=1)
    return data_fit + penalty
def grad(x):
    """Return A^T (A x - b) + alpha * sign(x), a (sub)gradient of ``obj``.

    The L1 term is not differentiable where a coordinate of ``x`` is exactly
    zero. ``np.sign`` returns 0 there, which is one valid subgradient. This is
    therefore not a true gradient everywhere, and the smoothness assumptions
    of quasi-Newton methods do not hold for ``obj``.

    Uses the module-level ``A``, ``b`` and ``alpha``.

    Parameters
    ----------
    x : ndarray
        Coefficient vector, one entry per column of ``A``.

    Returns
    -------
    ndarray
        Vector with the same length as ``x``.
    """
    residual = A @ x - b
    return A.T @ residual + alpha * np.sign(x)
# Minimise the same objective with two solvers, both starting from x = 0.
# gtol=0 disables BFGS's gradient-norm stopping test, so it runs until
# maxiter or until its line search gives up.
x_bfgs = fmin_bfgs(
    obj,
    np.zeros(A.shape[1]),
    fprime=grad,
    gtol=0,
    maxiter=2_000,
)
# alpha is divided by the number of samples. Presumably celer's Lasso scales
# its quadratic term by 1 / n_samples, so this targets the same minimiser as
# `obj` -- confirm against the celer Lasso docs.
x_cd = Lasso(fit_intercept=False, alpha=alpha / len(b)).fit(A, b).coef_

plt.figure()
# Stem-plot only the nonzero coefficients of each solution.
# CD is drawn first with thick lines; BFGS is drawn over it at default width.
for x_hat, name, props in (
    (x_cd, "CD", dict(color=c_list[1], linewidth=6)),
    (x_bfgs, "BFGS", dict(color=c_list[0])),
):
    support = np.flatnonzero(x_hat)
    m, s, _ = plt.stem(support, x_hat[support], label=name)
    plt.setp([m, s], **props)
print(f"10th position, BFGS: {x_bfgs[9]}, CD: {x_cd[9]}")
plt.legend()
plt.show(block=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment