@AndreiBarsan
Created April 23, 2020 03:00
Compares a hacky gradient descent with momentum against Nelder-Mead (simplex) search on the 2-D Rosenbrock function.
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

matplotlib.rcParams['xtick.direction'] = 'out'
matplotlib.rcParams['ytick.direction'] = 'out'


def rosenbrock(x, y, a=1.0, b=100.0):
    """Classic banana-shaped test function, minimized at (a, a**2)."""
    return (a - x) ** 2 + b * (y - x * x) ** 2


def rosenbrock_grad(x, y, a=1.0, b=100.0):
    """Analytic gradient of the Rosenbrock function, as [d/dx, d/dy]."""
    return [
        -2.0 * (a - x) + 2 * b * (y - x * x) * (-2.0) * x,
        2 * b * (y - x * x),
    ]
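
# Not part of the original gist: a quick sanity check of the analytic gradient.
# The global minimum of Rosenbrock is at (a, a**2) = (1, 1) with value 0, and
# away from it the analytic gradient should match a central finite difference.
assert rosenbrock(1.0, 1.0) == 0.0
assert np.allclose(rosenbrock_grad(1.0, 1.0), [0.0, 0.0])


def _check_grad(x, y, h=1e-6):
    """Compares rosenbrock_grad against a central finite difference."""
    fd = [
        (rosenbrock(x + h, y) - rosenbrock(x - h, y)) / (2 * h),
        (rosenbrock(x, y + h) - rosenbrock(x, y - h)) / (2 * h),
    ]
    assert np.allclose(fd, rosenbrock_grad(x, y), rtol=1e-4), (fd, rosenbrock_grad(x, y))


_check_grad(-1.0, 2.0)
_check_grad(0.5, 0.5)
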
def gradient_descent(alpha=0.005, eps=1e-8):
    """Simple GD with momentum (an exponential moving average of the gradient)."""
    theta = [-1.0, 2.0]
    vals = []
    thetas = [list(theta)]
    mom = [0.0, 0.0]
    for i in range(1000):
        val = rosenbrock(theta[0], theta[1])
        vals.append(val)
        # Stop once the objective barely changes between iterations.
        if len(vals) > 1 and abs(vals[-1] - vals[-2]) < eps:
            break
        grad = rosenbrock_grad(theta[0], theta[1])
        mom[0] = mom[0] * 0.9 + grad[0] * 0.1
        mom[1] = mom[1] * 0.9 + grad[1] * 0.1
        theta[0] -= alpha * mom[0]
        theta[1] -= alpha * mom[1]
        thetas.append(list(theta))
    return vals, thetas
def contour(function, x_vals=None, y_vals=None):
    """Filled contour plot of a 2-D function over a default grid."""
    if x_vals is None:
        x_vals = np.linspace(-2.0, 2.0, num=500)
    if y_vals is None:
        y_vals = np.linspace(-5.0, 5.0, num=500)
    xx, yy = np.meshgrid(x_vals, y_vals)
    zz = function(xx, yy)
    # Hand-picked levels so the narrow Rosenbrock valley stays visible.
    contour_plot = plt.contourf(xx, yy, zz, levels=[0.1, 0.5, 1.0, 10.0, 100.0, 250, 1000.0, 3000.0])
    # plt.imshow(zz, extent=[-5, 5, -5, 5])
    plt.colorbar()
    # plt.clabel(contour_plot, fontsize=9, inline=1)

# Run gradient descent and overlay its iterates on the Rosenbrock contours.
vals, thetas = gradient_descent()
thetas = np.array(thetas)
contour(rosenbrock)
plt.scatter(thetas[:, 0], thetas[:, 1])
plt.title("Final val {} | {} steps".format(vals[-1], len(vals)))
print(vals[-1])
print(np.array(vals).min())
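
# Not part of the original gist: the code above reads like flattened notebook
# cells, so when running it as a plain script you would also want an explicit
# plt.show() (or savefig) to see the figure. It can also be informative to sweep
# the step size, since the narrow Rosenbrock valley makes momentum GD quite
# sensitive to alpha; this sweep only reports outcomes, it does not assume them.
for sweep_alpha in (0.001, 0.002, 0.005):
    sweep_vals, _ = gradient_descent(alpha=sweep_alpha)
    print("alpha={:.3f}: final val {:.6g} after {} steps".format(
        sweep_alpha, sweep_vals[-1], len(sweep_vals)))
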

def nelder_mead(fn_2d, **kwargs):
    # Initial simplex: three vertices (a triangle) in 2-D.
    simplex = np.array([
        [1.0, -3.0],
        [2.0, 0.0],
        [1.0, 4.0],
    ])
    alpha = kwargs.get('alpha', 1.0)   # reflection coefficient
    gamma = kwargs.get('gamma', 2.0)   # expansion coefficient
    rho = kwargs.get('rho', 0.5)       # contraction coefficient
    sigma = kwargs.get('sigma', 0.5)   # shrink coefficient

    def draw():
        plt.figure()
        contour(fn_2d)
        # Append the first vertex again so the plotted outline closes the triangle.
        xs, ys = zip(*np.vstack((simplex, simplex[:1, :])))
        plt.plot(xs, ys)

    # TODO proper termination condition
    for iteration in range(20):
        draw()
        costs = fn_2d(simplex[:, 0], simplex[:, 1])
        sorted_args = np.argsort(costs)
        sorted_costs = costs[sorted_args]
        sorted_simplex = simplex[sorted_args, :]
        # Centroid of all vertices except the worst one.
        centroid = np.mean(sorted_simplex[:-1, :], axis=0)
        # Reflect the worst vertex through the centroid.
        reflected = centroid + alpha * (centroid - sorted_simplex[-1, :])
        # Only for display purposes
        centroid_val = fn_2d(centroid[0], centroid[1])
        # print(centroid_val)
        reflected_val = fn_2d(reflected[0], reflected[1])
        if reflected_val <= sorted_costs[0]:
            # Reflection beat the current best vertex, so try expanding further.
            expanded = centroid + gamma * (reflected - centroid)
            expanded_val = fn_2d(expanded[0], expanded[1])
            if expanded_val < reflected_val:
                # Expansion worked!
                sorted_simplex[-1, :] = expanded
            else:
                # Expansion failed
                sorted_simplex[-1, :] = reflected
        elif sorted_costs[0] < reflected_val < sorted_costs[1]:
            # OK point: better than the second-worst, so accept the reflection.
            sorted_simplex[-1, :] = reflected
        else:
            # Try to contract the worst vertex towards the centroid.
            contraction = centroid + rho * (sorted_simplex[-1, :] - centroid)
            contraction_val = fn_2d(contraction[0], contraction[1])
            if contraction_val < sorted_costs[-1]:
                # Eh, at least we didn't make things worse...
                sorted_simplex[-1, :] = contraction
            else:
                # Crap... shrink: re-compute simplex points around the best vertex.
                sorted_simplex[1:, :] = sorted_simplex[0, :] + sigma * (sorted_simplex[1:, :] - sorted_simplex[0, :])
        simplex = np.array(sorted_simplex)
        plt.title("it = {} | centroid val = {:.4f}".format(iteration, centroid_val))
        plt.savefig('/tmp/amoeba-{:04d}.png'.format(iteration))


nelder_mead(rosenbrock)
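
# Not part of the original gist: SciPy's reference Nelder-Mead implementation,
# as a sanity check for the hand-rolled versions above. Assumes scipy is
# installed; nothing else in this snippet needs it.
from scipy.optimize import minimize

result = minimize(lambda p: rosenbrock(p[0], p[1]), x0=[-1.0, 2.0], method='Nelder-Mead')
print("scipy Nelder-Mead: x = {}, f(x) = {:.6g}, {} iterations".format(
    result.x, result.fun, result.nit))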