Created February 9, 2023 18:44
AdamSP optimizer
import jax.numpy as jnp
from jax.example_libraries.optimizers import optimizer, make_schedule


def lerp(a, b, t):
    return (b - a) * t + a


@optimizer
def adamsp(step_size=1e-1, b1=0.5):
    """Construct optimizer triple for AdamSP.

    Args:
      step_size: positive scalar, or a callable representing a step size schedule
        that maps the iteration index to a positive scalar (default 1e-1).
      b1: optional, a positive scalar value for beta_1, the exponential decay rate
        for the first moment estimates (default 0.5).

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)
    def init(x0):
        # m0 = jnp.zeros_like(x0, dtype=jnp.float32)
        m0 = jnp.zeros_like(x0)
        return x0, m0
    def update(i, g, state):
        x, m = state
        # Apply acceleration to velocity.
        m = lerp(g, m, b1)  # First moment estimate.
        # m is velocity (a vector).
        velocity = m
        # Take the sign of the velocity to get a normalized direction.
        normal = jnp.sign(velocity)
        # Use the magnitude of the gradient as the speed.
        speed = jnp.abs(g)
        # Push the weights in the direction of the (normalized) velocity, scaled by the speed.
        scale = -step_size(i) * speed
        offset = normal * scale
        x = x + offset
        # Return the new state.
        return x, m
    def get_params(state):
        x, m = state
        return x
    return init, update, get_params
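
For reference, here is a minimal usage sketch (not part of the original gist), assuming @optimizer and make_schedule come from jax.example_libraries.optimizers and using a hypothetical toy quadratic loss in place of a real objective:

import jax
import jax.numpy as jnp

def loss(x):
    # Hypothetical toy objective; any differentiable scalar loss works the same way.
    return jnp.sum(x ** 2)

init_fun, update_fun, get_params = adamsp(step_size=1e-1, b1=0.5)
state = init_fun(jnp.array([3.0, -2.0]))
for i in range(100):
    g = jax.grad(loss)(get_params(state))
    state = update_fun(i, g, state)
print(get_params(state))  # Should end up close to zero.

Because adamsp is wrapped with @optimizer, the same triple also works on arbitrary pytrees of parameters, not just single arrays.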
It turns out that Google just published this as Lion: https://twitter.com/theshawwn/status/1625681629074137088
Pretty cool!
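
For comparison, here is a rough sketch of Lion's core update rule as I read the published paper (the function and the lr, b1, b2, wd names are mine, not from the gist). It is the same sign-of-momentum idea as adamsp above; the main differences are that Lion takes the sign of an interpolation between the current gradient and the momentum, drops the |g| magnitude scaling, applies decoupled weight decay, and tracks the momentum with a second decay rate:

def lion_update(x, m, g, lr=1e-4, b1=0.9, b2=0.99, wd=0.0):
    # Direction: sign of an interpolation between the momentum and the current gradient.
    u = jnp.sign(b1 * m + (1.0 - b1) * g)
    # Parameter step plus decoupled weight decay.
    x = x - lr * (u + wd * x)
    # Momentum is updated with a separate decay rate b2.
    m = b2 * m + (1.0 - b2) * g
    return x, m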