Skip to content

Instantly share code, notes, and snippets.

@OhadRubin
Created July 22, 2024 15:39
Show Gist options
  • Save OhadRubin/189d6e06f17969d39e2127e1e44bb88c to your computer and use it in GitHub Desktop.
Save OhadRubin/189d6e06f17969d39e2127e1e44bb88c to your computer and use it in GitHub Desktop.
The inverse-square-root learning-rate schedule (`inverse_square_root_schedule`) used in PaliGemma and Scaling Vision Transformers.
import numpy as np
import jax
import matplotlib.pyplot as plt
import jax.numpy as jnp
def inverse_square_root_schedule(
    max_lr: float,
    warmup_steps: int,
):
    """Build a linear-warmup, inverse-square-root-decay learning-rate schedule.

    With ``frac = (count + 1) / warmup_steps`` the returned schedule is

    * ``max_lr * frac`` while ``frac <= 1`` (linear warmup that reaches
      ``max_lr`` at ``count == warmup_steps - 1``), and
    * ``max_lr * sqrt(1 / frac)`` afterwards (inverse-square-root decay).

    The two pieces meet at ``frac == 1``, so the schedule is continuous.

    Args:
        max_lr: Peak learning rate, reached at the end of warmup.
        warmup_steps: Length of the linear warmup in steps; must be positive.

    Returns:
        A function mapping a step count (scalar or array) to the learning
        rate as a JAX array with the same shape as the input.

    Raises:
        ValueError: If ``warmup_steps`` is not positive.
    """
    if warmup_steps <= 0:
        raise ValueError(f"warmup_steps must be positive, got {warmup_steps}")

    def schedule(count):
        # Convert to a JAX array first so 1 / frac produces inf (not a Python
        # ZeroDivisionError) in the degenerate frac == 0 case.
        frac = (jnp.asarray(count) + 1) / warmup_steps
        decay = jnp.sqrt(1.0 / frac)
        # Select the active branch instead of multiplying each branch by a 0/1
        # mask: the mask form turns the unused inf branch into inf * 0 = nan.
        return max_lr * jnp.where(frac <= 1, frac, decay)

    return schedule
# Demo: compare a 500-step warmup with a 100-step warmup.
# max_val is the 100-step schedule's value at step 500. Using it as the peak of
# the 500-step schedule makes the two curves (nearly) coincide once the longer
# warmup has finished, so the plot isolates the effect of warmup length.
plt.figure(figsize=(10, 3.5))
x = np.arange(1000)
max_val = float(inverse_square_root_schedule(1, 100)(500))
plt.plot(
    x,
    jax.device_get(inverse_square_root_schedule(max_val, 500)(x)),
    label=f"warmup_steps=500, max_lr={max_val:.3f}",
)
plt.plot(
    x,
    jax.device_get(inverse_square_root_schedule(1, 100)(x)),
    label="warmup_steps=100, max_lr=1",
)
plt.xlabel("step")
plt.ylabel("learning rate")
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment