Last active
October 28, 2017 17:32
-
-
Save yvan/1655d98b55587fe326a8ec88281567e2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from matplotlib import animation, rc | |
from IPython.display import HTML | |
# Point matplotlib at the installed ffmpeg binary (used for the animation's
# HTML5-video rendering). Adjust this path for your machine.
plt.rcParams['animation.ffmpeg_path'] = '/usr/local/bin/ffmpeg'

# Load the data set: 'speed' is the feature, 'dist' the regression target.
data = pd.read_csv('cars.csv')

# Fixed starting value for the slope w1; a fixed start keeps the animation
# reproducible from run to run.
w1 = -10
x_data = data['speed']
y_data = data['dist']
def init():
    """Blank out the fitted line before the animation's first frame.

    Used as ``init_func`` for ``FuncAnimation``. Returns the modified artists
    as a tuple, which ``FuncAnimation`` needs when ``blit=True``.
    """
    blank = ([], [])
    line.set_data(*blank)
    return line,
# Build the figure: fixed axis ranges and labels, a scatter of the raw data,
# and an initially empty line artist that the animation updates.
fig, ax = plt.subplots()
ax.set(xlim=(0, 30), ylim=(0, 130), xlabel='speed', ylabel='dist')
_ = ax.scatter(data['speed'], data['dist'])
(line,) = ax.plot([], [], lw=2)
def calc_regression_simple(w1, frames, x_data, y_data, learn_rate=0.001,
                           x_min=0, x_max=30, num=50):
    """Fit ``y ~ w1 * x`` by batch gradient descent, recording each step.

    The model has a single slope parameter and no intercept. Each iteration
    takes one gradient step on the mean squared error
    ``mean((w1 * x - y) ** 2)``. Its derivative with respect to ``w1`` is
    ``2 * mean(x * (w1 * x - y))``.

    Args:
        w1: Starting slope.
        frames: Number of gradient steps to take. One line is recorded per
            step.
        x_data: Feature values. Any array-like accepted by ``np.asarray``,
            e.g. a pandas Series or a list.
        y_data: Target values, the same length as ``x_data``.
        learn_rate: Step size for each update.
        x_min, x_max, num: The fitted line is evaluated on
            ``np.linspace(x_min, x_max, num)``. The defults reproduce the
            original fixed grid ``np.linspace(0, 30)``.

    Returns:
        A list of ``frames`` arrays. Element ``i`` is ``w1 * grid`` after
        ``i + 1`` updates. The starting slope itself is not included, so
        ``frames == 0`` yields an empty list.
    """
    x = np.asarray(x_data, dtype=float)
    y = np.asarray(y_data, dtype=float)
    # The evaluation grid never changes, so build it once rather than per step.
    grid = np.linspace(x_min, x_max, num)
    ys = []
    for _ in range(frames):
        # nanmean mirrors np.mean on a pandas Series, which skips NaNs.
        gradient = 2 * np.nanmean(x * (w1 * x - y))
        w1 = w1 - learn_rate * gradient
        ys.append(w1 * grid)
    return ys
# Precompute one fitted line per animation frame, then title the plot.
frames = 15
ys = calc_regression_simple(w1=w1, frames=frames, x_data=x_data, y_data=y_data)
ax.set_title('gradient descent regression values')
# Per-frame callback: draws a line that was already computed (no fitting here).
def animate_fit(i):
    """Display the precomputed regression line for animation frame ``i``.

    ``ys[i]`` is plotted against ``np.linspace(0, 30)``, the grid the lines
    were evaluated on. Returns the updated artists as a tuple so blitting works.
    """
    xs = np.linspace(0, 30)
    line.set_data(xs, ys[i])
    return line,
# Animate the fit: init() clears the line, then animate_fit(i) draws frame i.
anim = animation.FuncAnimation(
    fig,
    func=animate_fit,
    frames=frames,
    init_func=init,
    blit=True,
)
# Display animations as HTML5 video in notebooks.
rc('animation', html='html5')
# Writes a GIF at 2 frames per second; the 'imagemagick' writer needs
# ImageMagick installed.
anim.save('gd.gif', writer='imagemagick', fps=2)
# Trailing bare expression so a notebook cell displays the animation inline.
anim
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment