Skip to content

Instantly share code, notes, and snippets.

@utahka
Last active September 27, 2017 05:48
Show Gist options
  • Save utahka/8752a40f6ebd452962119578495c12eb to your computer and use it in GitHub Desktop.
Save utahka/8752a40f6ebd452962119578495c12eb to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
# Global Matplotlib appearance for every figure this script draws:
# the "ggplot" style sheet, a 13pt base font and a 16x8 inch canvas.
plt.style.use("ggplot")
plt.rcParams.update({
    "font.size": 13,
    "figure.figsize": (16, 8),
})
class GradientDescent(object):
    """Fixed-step gradient descent that records the path it takes.

    Every iteration applies ``x <- x - learning_ratio * grad(x)`` and stores
    the visited point, the objective value there and the step that was taken,
    so the run can be plotted or animated afterwards.

    Parameters
    ----------
    f : callable
        Objective function, called as ``f(x)``.
    grad : callable
        Gradient of the objective, called as ``grad(x)``.
    init_x : scalar or numpy.ndarray
        Starting point. It is never modified by :meth:`solve`.
    n_iter : int, optional
        Maximum number of iterations.
    learning_ratio : float, optional
        Multiplier applied to the gradient to obtain the step.
    delta : float, optional
        Convergence threshold: iteration stops as soon as every component of
        the step has a magnitude below this value.

    Attributes
    ----------
    hist_x, hist_y, hist_dx : list
        Visited points, objective values at those points and applied steps
        of the most recent :meth:`solve` call.
    """

    def __init__(self, f, grad, init_x, n_iter=100, learning_ratio=0.01, delta=0.1**8):
        self.f = f
        self.grad = grad
        self.n_iter = n_iter
        self.learning_ratio = learning_ratio
        self.init_x = init_x
        self.delta = delta
        self.hist_x = []
        self.hist_y = []
        self.hist_dx = []

    def solve(self):
        """Run the descent from ``init_x`` and return the recorded history.

        Returns
        -------
        tuple of (list, list, list)
            ``(hist_x, hist_y, hist_dx)``; all three have the same length,
            one entry per executed iteration.
        """
        f = self.f
        grad = self.grad

        # Start from fresh lists so a repeated call does not append a second
        # run to the history of the first one.
        self.hist_x = []
        self.hist_y = []
        self.hist_dx = []

        x = self.init_x
        for _ in range(self.n_iter):
            self.hist_x.append(x)
            self.hist_y.append(f(x))
            dx = self.learning_ratio * grad(x)
            # Rebind rather than `x -= dx`: the in-place form would overwrite
            # the array object already stored in hist_x (and init_x itself).
            x = x - dx
            self.hist_dx.append(dx)
            # Compare the step's magnitude; a signed comparison would treat
            # every negative step as "converged" and stop immediately.
            # np.all makes the same test work for array-valued steps.
            if np.all(np.abs(dx) < self.delta):
                break
        return self.hist_x, self.hist_y, self.hist_dx
def animate(nframe):
    """Draw animation frame ``nframe`` on the current figure.

    Reads the module-level names ``x`` and ``f1`` (the curve to draw) and
    ``hist_x`` / ``hist_y`` (the recorded points); entry ``nframe`` of the
    history is drawn as a single marker on top of the curve.
    """
    # Clear the figure first so markers from earlier frames do not pile up.
    plt.clf()

    curve_y = f1(x)
    plt.plot(x, curve_y)

    point_x, point_y = hist_x[nframe], hist_y[nframe]
    plt.scatter(point_x, point_y, s=100)

    # Fixed axis limits keep the view identical from frame to frame.
    plt.ylim(-20, 500)
    plt.xlim(-8.6, 8.6)
    plt.tight_layout()
if __name__ == "__main__":
    # These names are deliberately module-level: animate() reads x, f1,
    # hist_x and hist_y as globals.
    x = np.linspace(-8.6, 8.6, 100)

    def f1(x):
        """Objective: -(1/12) x^4 + (25/2) x^2."""
        return -(1/12)*x**4 + (25/2)*x**2

    def grad(x):
        """Derivative of f1: -(1/3) x^3 + 25 x."""
        return -(1/3)*x**3 + 25*x

    gd = GradientDescent(f1, grad, init_x=8)
    hist_x, hist_y, hist_dx = gd.solve()

    # One animation frame per recorded iterate.
    max_frame = len(hist_x)
    fig = plt.gcf()
    ani = animation.FuncAnimation(
        fig,
        animate,
        frames=max_frame,
        interval=50,
    )
    ani.save('gradient_descent1.gif', writer='imagemagick', fps=5, dpi=64)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment