
@kamath
Last active February 5, 2021 23:13
Gist: kamath/d2416fb37d6bfe3f0bae033611846db5

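import numpy as np
import matplotlib.pyplot as plt

# Toy data: one feature (as a column vector) and five targets
X = np.array([1.2, 3.2, 5.1, 3.5, 2.6]).reshape(-1, 1)
y = np.array([7.8, 1.2, 6.4, 2.6, 8.1])

# Append a column of ones so the second coefficient acts as the intercept
ints = np.ones(shape=y.shape)[..., None]
X = np.concatenate((X, ints), 1)

# Closed-form ridge fit: solves (X^T X + l**2 * I) beta = X^T y and returns [slope, intercept].
# When l = 0 this is ordinary least squares (it minimizes the RSS); otherwise the
# effective ridge penalty is lambda = l**2. Note that the identity also penalizes the intercept.
def regress(l):
    return np.linalg.solve(X.T.dot(X) + (l**2) * np.identity(2), X.T.dot(y))

x = np.linspace(1, 6)
plt.scatter(X[:, 0], y)
for l in [0, 1, 10]:
    m, b = regress(l)
    print(m, b, l)
    plt.plot(x, m*x + b, label=f'l={l} (lambda={l**2})')
plt.legend()
plt.show()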
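The snippet above minimizes ||y - X b||^2 + l^2 ||b||^2, whose minimizer is b = (X^T X + l^2 I)^-1 X^T y. The sketch below is a way to sanity-check that formula on the same data; it is not the author's code. It checks that lambda = 0 reproduces ordinary least squares and that the gradient of the ridge objective vanishes at the closed-form solution. The function names (`penalty_matrix`, `ridge`, `gradient`) and the `penalize_intercept` flag are hypothetical. The flag assumes the intercept sits in the last column and adds the more common ridge convention of leaving the intercept unpenalized, which the gist does not do.

```python
import numpy as np

# Same toy data as the gist: a feature column plus a column of ones for the intercept.
x_raw = np.array([1.2, 3.2, 5.1, 3.5, 2.6])
y = np.array([7.8, 1.2, 6.4, 2.6, 8.1])
X = np.column_stack((x_raw, np.ones_like(x_raw)))


def penalty_matrix(n_cols, penalize_intercept):
    """Identity (the gist's choice), or identity with the last (intercept) entry zeroed."""
    P = np.identity(n_cols)
    if not penalize_intercept:
        P[-1, -1] = 0.0  # assumption: the intercept is the last column
    return P


def ridge(X, y, lam, penalize_intercept=True):
    """Closed-form minimizer of ||y - X b||^2 + lam * b^T P b."""
    P = penalty_matrix(X.shape[1], penalize_intercept)
    return np.linalg.solve(X.T @ X + lam * P, X.T @ y)


def gradient(X, y, beta, lam, penalize_intercept=True):
    """Gradient of the ridge objective; it should be zero at the minimizer."""
    P = penalty_matrix(X.shape[1], penalize_intercept)
    return -2 * X.T @ (y - X @ beta) + 2 * lam * P @ beta


# lam = 0 is ordinary least squares, i.e. plain RSS minimization.
ols = np.linalg.lstsq(X, y, rcond=None)[0]
assert np.allclose(ridge(X, y, 0.0), ols)

for l in [0, 1, 10]:
    lam = l ** 2  # the gist's l maps to lambda = l**2
    for pen in (True, False):
        beta = ridge(X, y, lam, pen)
        assert np.allclose(gradient(X, y, beta, lam, pen), 0.0, atol=1e-8)
        print(f"l={l:>2} lambda={lam:>3} penalize_intercept={pen!s:<5} "
              f"slope={beta[0]: .4f} intercept={beta[1]: .4f}")
```

With `penalize_intercept=False` and a very large lambda, the slope shrinks toward 0 and the intercept tends to the mean of y. The gist's identity penalty instead pulls both coefficients toward 0, so its lines for large l drift toward the horizontal axis.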