Reformulating the Adam optimizer to gain an intuition about what it's doing.
import jax.numpy as jnp
from jax.example_libraries.optimizers import optimizer, make_schedule
# (OptimizerResult, used below only as a return annotation, is assumed to be
# defined elsewhere in the author's environment.)

def lerp(a, b, t):
  return (b - a) * t + a

def bias(i, x, beta):
  return 1 - jnp.asarray(beta, x.dtype) ** (i + 1)
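
# Sanity check: the helpers above are just a reformulation of the textbook Adam terms.
#   lerp(g, m, b1) == (m - g) * b1 + g == (1 - b1) * g + b1 * m
#   bias(i, x, b1) == 1 - b1 ** (i + 1)
# For example, lerp(2.0, 10.0, 0.9) == (1 - 0.9) * 2.0 + 0.9 * 10.0 == 9.2.
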
@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8) -> OptimizerResult:
  """Construct optimizer triple for Adam.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    M0 = jnp.zeros_like(x0)
    return x0, m0, M0
  def update(i, g, state):
    # Think of the gradient like a force (in the Newtonian physics sense).
    # It accelerates each weight in the direction of the gradient.
    # The larger the gradient, the faster the acceleration.
    #
    # Per-weight state:
    #
    # Lowercase letters represent values.
    # Uppercase letters represent squared values.
    #
    # x is the position (i.e. the current weight value).
    #
    # g is the gradient.
    # G is the gradient squared.
    #
    # m is a measurement of the gradient over time (i.e. a moving average of the gradient).
    # M is a measurement of the squared gradient over time (i.e. a moving average of the squared gradient).
    #
    # b1 controls how quickly m accelerates towards g. Defaults to 0.9.
    # b2 controls how quickly M accelerates towards G. Defaults to 0.999.
    #
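    # As a rough guide, an exponential moving average with decay rate beta averages
    # over roughly the last 1 / (1 - beta) steps: b1 = 0.9 tracks about the last 10
    # gradients, and b2 = 0.999 tracks about the last 1000 squared gradients.
    #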
    x, m, M = state
    G = jnp.square(g)
    # Accelerate each weight by pushing each weight's velocity along its gradient vector.
    m = lerp(g, m, b1)  # Push the velocity (m) toward its gradient (i.e. first moment estimate).
    M = lerp(G, M, b2)  # Push the squared velocity (M) toward its squared gradient (i.e. second moment estimate).
    m_ = m / bias(i, m, b1)  # Bias correction, since the moving averages start at zero.
    M_ = M / bias(i, M, b2)
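    # A worked example of why the correction helps: at step i = 0, with m starting
    # at zero, m = lerp(g, 0, b1) = (1 - b1) * g = 0.1 * g, which underestimates the
    # gradient. Dividing by bias(0, m, b1) = 1 - b1 = 0.1 rescales it to exactly g.
    # The correction factor approaches 1 as i grows, so it mainly matters early on.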
    # Calculate each weight's new velocity by measuring the gradient (m) and squared gradient (M) over time.
    # Velocity is a change in position (dx) over a change in time (dt).
    dx = m_ / (jnp.sqrt(M_) + eps)  # A change in position (dx).
    dt = -step_size(i)  # A change in time (dt).
    # Calculate each weight's new position by pushing each weight's position along its velocity vector.
    # The position offset is calculated by multiplying the change in position by the change in time (dx * dt).
    # Since it's an offset, we can just add it to the old position to get the new position.
    x = x + dx * dt
    # We're done; return the new state.
    return x, m, M
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params
# A second revision, with fewer comments but more descriptive variable names:
@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8) -> OptimizerResult:
  """Construct optimizer triple for Adam.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    M0 = jnp.zeros_like(x0)
    return x0, m0, M0
  def update(i, g, state):
    x, m, M = state
    G = jnp.square(g)
    # Calculate acceleration.
    m = lerp(g, m, b1)  # First moment estimate.
    M = lerp(G, M, b2)  # Second moment estimate.
    m_ = m / bias(i, m, b1)
    M_ = M / bias(i, M, b2)
    # m_ is velocity (a vector)
    # M_ is squared speed (a directionless quantity)
    # sqrt(M_) is average speed over time
    velocity = m_
    speed = jnp.sqrt(M_)
    # Divide velocity by speed to get a normalized direction.
    normal = velocity / (speed + eps)
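    # A consequence of this normalization: if a weight's gradient has been
    # consistently near some value c, then velocity ~ c and speed ~ |c|, so
    # `normal` is close to +/-1 and the weight moves by roughly step_size per
    # update, regardless of the raw gradient's magnitude. Noisy, sign-flipping
    # gradients average toward zero velocity, giving |normal| < 1 and a smaller step.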
    # Push the weights in the direction of the (normalized) gradient.
    scale = -step_size(i)
    offset = normal * scale
    x = x + offset
    # Return the new state.
    return x, m, M
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params
# The original Adam code, for comparison. It's functionally identical to the two versions above; it's what I started with.
@optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8) -> OptimizerResult:
  """Construct optimizer triple for Adam.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = jnp.zeros_like(x0)
    v0 = jnp.zeros_like(x0)
    return x0, m0, v0
  def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
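    # Note: these two lines compute the same thing as lerp(g, m, b1) and
    # lerp(jnp.square(g), v, b2) in the reformulated versions above.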
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
    return x, m, v
  def get_params(state):
    x, _, _ = state
    return x
  return init, update, get_params
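
# A minimal usage sketch of the (init_fun, update_fun, get_params) triple returned
# by the @optimizer-decorated functions above. `loss_fn`, `params0`, and `num_steps`
# are hypothetical placeholders, not defined in this gist.
#
#   import jax
#
#   init_fun, update_fun, get_params = adam(step_size=1e-3)
#   opt_state = init_fun(params0)
#   for i in range(num_steps):
#     grads = jax.grad(loss_fn)(get_params(opt_state))
#     opt_state = update_fun(i, grads, opt_state)
#   trained_params = get_params(opt_state)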