@erinkhoo
Created March 28, 2021 09:21
ADAM Gradient Descent
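Adam (adaptive moment estimation) applied to a two-parameter fit, presumably the line y_hat = a*x + b given the separate gradients in a and b. The mse and gradient helpers it calls are not included in the gist and are assumed to be defined elsewhere; a sketch of both follows the function.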
import numpy as np

def adam_gradient_descent(a, b, x, y, lr=1e-5, b1=0.9, b2=0.999, epsilon=1e-4):
    # Adam optimizer for two scalar parameters a and b.
    # Note: epsilon here is the convergence tolerance on the MSE;
    # Adam's numerical-stability constant is the hard-coded 1e-8 below.
    prev_error = 0
    m_a = v_a = m_b = v_b = 0  # first/second moment accumulators
    t = 0                      # timestep, used for bias correction
    error = np.array([])       # MSE history, one entry per iteration
    while True:
        gradient_a, gradient_b = gradient(a, b, x, y)
        current_error = mse(a, b, x, y)
        if abs(current_error - prev_error) < epsilon:  # stop once the MSE plateaus
            break
        t += 1
        prev_error = current_error
        error = np.append(error, prev_error)
        # Exponential moving averages of the gradients and their squares
        m_a = b1 * m_a + (1 - b1) * gradient_a
        v_a = b2 * v_a + (1 - b2) * gradient_a ** 2
        m_b = b1 * m_b + (1 - b1) * gradient_b
        v_b = b2 * v_b + (1 - b2) * gradient_b ** 2
        # Bias-corrected moment estimates
        moment_m_a = m_a / (1 - b1 ** t)
        moment_v_a = v_a / (1 - b2 ** t)
        moment_m_b = m_b / (1 - b1 ** t)
        moment_v_b = v_b / (1 - b2 ** t)
        # Adam update: step scaled by the inverse root of the second moment
        a -= lr * moment_m_a / (moment_v_a ** 0.5 + 1e-8)
        b -= lr * moment_m_b / (moment_v_b ** 0.5 + 1e-8)
    return a, b, error
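A minimal sketch of the assumed mse and gradient helpers, under the assumption that the model is the line y_hat = a*x + b and that x and y are NumPy arrays; the names and signatures match the calls above, but the bodies are an assumption, not part of the gist:

def mse(a, b, x, y):
    # Mean squared error of the assumed linear model y_hat = a*x + b
    return np.mean((y - (a * x + b)) ** 2)

def gradient(a, b, x, y):
    # Partial derivatives of the MSE above with respect to a and b
    residual = y - (a * x + b)
    return -2 * np.mean(residual * x), -2 * np.mean(residual)

Under those assumptions, a toy run on synthetic data (lr=1e-2 is a hypothetical choice for this example; the default 1e-5 converges very slowly at this data scale):

rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 200)
y = 3.0 * x + 1.5 + rng.normal(0.0, 0.5, 200)
a, b, error = adam_gradient_descent(0.0, 0.0, x, y, lr=1e-2)
print(a, b, len(error))  # should land near slope 3 and intercept 1.5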