@murphyk
Last active May 21, 2021 20:43
JL training pseudocode
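import jax
import optax

model = jaxLightningModule()  # hypothetical Lightning-style module exposing `params`

def fit(model, train_loader, num_epochs):
    # train_loader and num_epochs are assumed arguments standing in for `for iter` / `for batch`
    opt = model.configure_optimizers()  # expected to return an optax.GradientTransformation
    opt_state = opt.init(model.params)
    for epoch in range(num_epochs):
        for batch in train_loader:
            # jax.grad transforms a function, so model.step(params, batch) is assumed to return a scalar loss
            grads = jax.grad(model.step)(model.params, batch)
            grads = model.reduce_gradients(grads)  # e.g. average gradients across devices
            updates, opt_state = opt.update(grads, opt_state, model.params)
            model.params = optax.apply_updates(model.params, updates)
    return model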