
@shawwn
Created February 9, 2023 18:44
AdamSP optimizer
# Assumes the @optimizer decorator and make_schedule come from JAX's example
# optimizers module (jax.example_libraries.optimizers), which the gist uses
# but does not import explicitly.
import jax.numpy as jnp
from jax.example_libraries.optimizers import optimizer, make_schedule


def lerp(a, b, t):
    # Linear interpolation: returns a when t == 0 and b when t == 1.
    return (b - a) * t + a


@optimizer
def adamsp(step_size=1e-1, b1=0.5):
    """Construct optimizer triple for AdamSP.

    Args:
      step_size: positive scalar, or a callable representing a step size schedule
        that maps the iteration index to a positive scalar (default 1e-1).
      b1: optional, a positive scalar value for beta_1, the exponential decay rate
        for the first moment estimates (default 0.5).

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        # m0 = jnp.zeros_like(x0, dtype=jnp.float32)
        m0 = jnp.zeros_like(x0)
        return x0, m0

    def update(i, g, state):
        x, m = state
        # Apply acceleration to velocity.
        m = lerp(g, m, b1)  # First moment estimate.
        # m is velocity (a vector).
        velocity = m
        # Take the sign of the velocity to get a normalized direction.
        normal = jnp.sign(velocity)
        # Use the magnitude of the gradient as the speed.
        speed = jnp.abs(g)
        # Push the weights in the direction of the (normalized) gradient, scaled by speed.
        scale = -step_size(i) * speed
        offset = normal * scale
        x = x + offset
        # Return the new state.
        return x, m

    def get_params(state):
        x, m = state
        return x

    return init, update, get_params
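
A minimal usage sketch (not part of the gist), assuming the returned triple is used the same way as other jax.example_libraries.optimizers optimizers; the quadratic loss and parameter values below are purely illustrative.

import jax

init_fun, update_fun, get_params = adamsp(step_size=1e-1, b1=0.5)

def loss(params):
    # Illustrative quadratic loss; any differentiable function of the params works.
    return jnp.sum(params ** 2)

params = jnp.array([1.0, -2.0, 3.0])
opt_state = init_fun(params)
for i in range(100):
    g = jax.grad(loss)(get_params(opt_state))
    opt_state = update_fun(i, g, opt_state)

print(get_params(opt_state))  # params driven toward zero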
shawwn commented Feb 15, 2023

It turns out that Google just published this as Lion: https://twitter.com/theshawwn/status/1625681629074137088

Pretty cool!
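
For reference, a rough sketch of the Lion update as described in that paper ("Symbolic Discovery of Optimization Algorithms"); the function name and defaults here are illustrative, not the paper's code. The main differences from adamsp above: Lion applies a fixed step size to the sign (no scaling by |g|), carries its momentum with a second, slower decay rate, and adds decoupled weight decay.

def lion_update(x, m, g, lr=1e-4, b1=0.9, b2=0.99, wd=0.0):
    # Interpolate gradient and momentum, then step by the sign of the result.
    c = lerp(g, m, b1)
    x = x - lr * (jnp.sign(c) + wd * x)
    # The stored momentum uses a separate, slower decay rate.
    m = lerp(g, m, b2)
    return x, m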
