Skip to content

Instantly share code, notes, and snippets.

@BrambleXu
Created June 16, 2019 14:02
Show Gist options
  • Save BrambleXu/1adbbceff0da62a6d5193e5aefb00952 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
# Fix the RNG seed so every run produces the same synthetic data set.
np.random.seed(0)

def g(x):
    """True underlying curve the model should recover: 0.1 * (x + x^2 + x^3)."""
    poly = x + x**2 + x**3
    return 0.1 * poly

# Fake training set: 8 evenly spaced samples of the true curve plus
# small Gaussian noise (std 0.05).
train_x = np.linspace(-2, 2, 8)
train_y = g(train_x) + np.random.randn(len(train_x)) * 0.05
# Optional preview: plot the noisy samples against the true curve (uncomment to run).
# x = np.linspace(-2, 2, 100)
# plt.plot(train_x, train_y, 'o')
# plt.plot(x, g(x), linestyle='dashed')
# plt.ylim(-1, 2)
# plt.show()
# Standardization: center and scale inputs using the training set statistics.
mu = np.mean(train_x)
std = np.std(train_x)

def standardizer(x):
    """Map x to zero mean / unit spread using the training mean and std."""
    return (x - mu) / std

# Standardized training inputs used for all feature construction below.
std_x = standardizer(train_x)
# Build the polynomial design matrix for linear regression.
def to_matrix(x, degree=10):
    """Return the (len(x), degree + 1) polynomial design matrix.

    Column i holds x**i, with an explicit ones column for the bias term,
    so a dot product with a weight vector evaluates a degree-`degree`
    polynomial at every sample.

    Parameters
    ----------
    x : np.ndarray, shape (n,)
        1-D sample positions.
    degree : int, optional
        Highest polynomial power (default 10, matching the original
        hard-coded eleven columns).
    """
    # np.ones keeps the bias column float even for integer x,
    # exactly as the original hand-written version did.
    columns = [np.ones(x.size)] + [x ** i for i in range(1, degree + 1)]
    return np.vstack(columns).T
# Design matrix of polynomial features for the standardized training inputs.
mat_x = to_matrix(std_x)
# initialize parameter
# Random starting weights, one per feature column (11 with the default degree).
theta = np.random.randn(mat_x.shape[1])
# predict function: linear model over the polynomial feature columns
def f(x):
    """Predict targets for design-matrix rows x with the current global theta."""
    return x @ theta

# cost function: half the sum of squared residuals
def E(x, y):
    """Squared-error objective of f on the data set (x, y)."""
    residual = y - f(x)
    return 0.5 * np.sum(residual ** 2)
# learning rate
ETA = 1e-4
# regularization parameter (L2 / ridge penalty strength)
LAMBDA = 1
# initialize difference between two epochs (forces at least one iteration)
diff = 1
# initialize error
error = E(mat_x, train_y)
######## training with L2 regularization ########
# NOTE(review): the original banner said "without regularization", but the
# loop below adds reg_term = LAMBDA * theta[1:] to every gradient step,
# so this is the regularized fit.
while diff > 1e-6:
    # notice we don't use regularization for theta 0 (the bias weight)
    reg_term = LAMBDA * np.hstack([0, theta[1:]])
    # update parameter: one batch gradient-descent step on the penalized cost
    theta = theta - ETA * (np.dot(mat_x.T, f(mat_x) - train_y) + reg_term)
    current_error = E(mat_x, train_y)
    # stop once the per-epoch improvement in the cost drops below 1e-6
    diff = error - current_error
    error = current_error
# save parameters of the regularized fit for the plot below
theta2 = theta
########## plot the regularized fit ##########
# Dense evaluation grid over the input range, standardized the same way
# as the training data so it matches the feature space of the model.
grid = standardizer(np.linspace(-2, 2, 100))

# Noisy training samples plus the fitted curve from the regularized weights.
plt.plot(std_x, train_y, 'o')
theta = theta2  # restore the trained weights read by f()
plt.plot(grid, f(to_matrix(grid)))
plt.ylim(-1, 2)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment