# Gist by @alvinwan, created December 4, 2017 11:59
# (GitHub Gist alvinwan/ea990e63bae78f4945b92bb8b9f8ca24).
#
# Plotting errors for ridge regression vs. ordinary least squares as the
# magnitude of the noise increases.
import numpy as np
import matplotlib
matplotlib.use('Agg') # wizardry to prevent Tkinter error
import matplotlib.pyplot as plt
def line(x):
    """Return the true line ``2*x + 3`` (works on scalars and NumPy arrays)."""
    return 2 * x + 3


n = 100  # number of samples, taken at x = 0, 1, ..., n - 1
reg = 1e2  # ridge penalty strength (lambda)
es = [i * 40 for i in range(9)]  # noise widths: 0, 40, ..., 320


def make_data(n, e, seed=0):
    """Sample ``n`` noisy points from :func:`line`.

    Returns ``(xs, ys)`` where ``xs`` is ``0, 1, ..., n - 1`` and each
    ``ys[i]`` is ``line(xs[i])`` plus zero-mean uniform noise drawn from
    ``[-e/2, e/2)``.

    A fresh generator seeded with ``seed`` is used, so every noise level sees
    the same underlying draws (only their scale changes). This also leaves
    the global NumPy random state untouched. The draws are identical to
    calling ``np.random.seed(seed)`` followed by ``n`` calls to
    ``np.random.random()``.
    """
    xs = np.arange(n)
    noise = np.random.RandomState(seed).random_sample(n)
    ys = line(xs) + noise * e - e / 2.0
    return xs, ys


def design_matrix(xs):
    """Return the ``len(xs) x 2`` float matrix whose rows are ``[x, 1]``.

    The first column multiplies the slope and the second the intercept.
    """
    xs = np.asarray(xs, dtype=float)
    return np.column_stack([xs, np.ones_like(xs)])


def fit_ridge(A, y, alpha=0.0):
    """Return ``w`` minimizing ``||A w - y||^2 + alpha * ||w||^2``.

    ``alpha=0`` gives ordinary least squares. The normal equations
    ``(A^T A + alpha * I) w = A^T y`` are solved with ``np.linalg.solve``
    rather than by forming an explicit inverse, which is cheaper and
    numerically more stable.

    NOTE: the penalty applies to every coefficient, including the intercept
    column of ``A``. This matches the original experiment; standard ridge
    usually leaves the intercept unpenalized.

    Raises ``numpy.linalg.LinAlgError`` if the system is singular (e.g. OLS
    with fewer than two distinct x values).
    """
    A = np.asarray(A, dtype=float)
    gram = A.T @ A + alpha * np.eye(A.shape[1])
    return np.linalg.solve(gram, A.T @ y)


def main(noise_levels=None, alpha=reg, num_samples=n):
    """Run the OLS-vs-ridge experiment and write two PNG figures.

    For each noise width ``e`` in ``noise_levels`` (default: module-level
    ``es``), this samples ``num_samples`` noisy points and fits OLS and ridge
    (penalty ``alpha``). Two figures are written:

    * ``errors_ols_v_ridge_points.png`` -- one subplot per noise level with
      the samples, the true line, and both fitted lines.
    * ``errors_ols_v_ridge.png`` -- the L2 distance between each fit's
      predictions and the true line, plotted against the noise width.

    Returns ``(ols_errors, rr_errors)``, the two lists of errors in the
    order of ``noise_levels``.
    """
    if noise_levels is None:
        noise_levels = es

    # Lay the subplots out three per row, adding rows as needed.
    ncols = 3
    nrows = max(1, -(-len(noise_levels) // ncols))

    ols_errors = []
    rr_errors = []

    points_fig = plt.figure(figsize=(14, 8))
    points_fig.subplots_adjust(wspace=0.2, hspace=0.3)
    for idx, e in enumerate(noise_levels):
        xs, ys = make_data(num_samples, e)
        A = design_matrix(xs)
        true_y = line(xs)

        ols_pred = A @ fit_ridge(A, ys)
        rr_pred = A @ fit_ridge(A, ys, alpha)

        ax = points_fig.add_subplot(nrows, ncols, idx + 1)
        ax.set_title('%2.fx noise' % e)
        ax.scatter(xs, ys, s=2)
        ax.plot(xs, true_y, label='true')
        ax.plot(xs, ols_pred, label='ols')
        ax.plot(xs, rr_pred, label='rr')
        ax.legend()

        # Error is measured against the noiseless line, not the noisy samples.
        ols_errors.append(np.linalg.norm(ols_pred - true_y))
        rr_errors.append(np.linalg.norm(rr_pred - true_y))

    # Save once, after every subplot has been drawn.
    points_fig.savefig('errors_ols_v_ridge_points.png')
    plt.close(points_fig)

    errors_fig = plt.figure()
    ax = errors_fig.add_subplot(1, 1, 1)
    ax.plot(noise_levels, ols_errors, label='ols')  # plot ols performance
    ax.plot(noise_levels, rr_errors, label='rr')  # plot ridge regression performance
    ax.set_xlabel('noise width e')
    ax.set_ylabel('||prediction - true line||')
    ax.legend()
    errors_fig.savefig('errors_ols_v_ridge.png')
    plt.close(errors_fig)

    return ols_errors, rr_errors


if __name__ == "__main__":
    main()
# Source: GitHub Gist alvinwan/ea990e63bae78f4945b92bb8b9f8ca24 (discussion happens on the gist page).