@mwitiderrick
Last active December 18, 2023 13:57
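import jax

# CNN, cross_entropy_loss and compute_metrics are assumed to be defined
# elsewhere; they are not part of this snippet.
def compute_loss(params, images, labels):
    """Return the loss, plus the logits as an auxiliary output."""
    logits = CNN().apply({'params': params}, images)
    loss = cross_entropy_loss(logits=logits, labels=labels)
    return loss, logits


@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    images, labels = batch
    # has_aux=True: compute_loss returns (loss, logits) and only loss is differentiated.
    (_, logits), grads = jax.value_and_grad(compute_loss, has_aux=True)(
        state.params, images, labels)
    # Apply the optimizer update carried by the train state (e.g. a Flax TrainState).
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=labels)
    return state, metrics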
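Since `CNN`, `cross_entropy_loss` and `compute_metrics` are not defined in this gist, the training step cannot run on its own. Below is a minimal sketch of the same pattern in plain JAX. It assumes a 10-class problem with 28×28×1 inputs. The linear `apply_model` (standing in for `CNN().apply`), both helper implementations, and the fixed-step SGD update (standing in for `state.apply_gradients`) are all hypothetical choices, not the author's code.

```python
import jax
import jax.numpy as jnp

NUM_CLASSES = 10      # assumption: 10-way classification
LEARNING_RATE = 0.1   # assumption: plain SGD step size


def cross_entropy_loss(*, logits, labels):
    # Hypothetical stand-in: mean softmax cross-entropy over integer labels.
    one_hot = jax.nn.one_hot(labels, NUM_CLASSES)
    return -jnp.mean(jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))


def compute_metrics(*, logits, labels):
    # Hypothetical stand-in: batch loss and accuracy.
    return {
        "loss": cross_entropy_loss(logits=logits, labels=labels),
        "accuracy": jnp.mean(jnp.argmax(logits, -1) == labels),
    }


def apply_model(params, images):
    # Hypothetical linear classifier standing in for CNN().apply.
    x = images.reshape((images.shape[0], -1))
    return x @ params["w"] + params["b"]


def compute_loss(params, images, labels):
    logits = apply_model(params, images)
    loss = cross_entropy_loss(logits=logits, labels=labels)
    # (loss, aux) pair, as required by value_and_grad(..., has_aux=True).
    return loss, logits


@jax.jit
def train_step(params, batch):
    images, labels = batch
    (_, logits), grads = jax.value_and_grad(compute_loss, has_aux=True)(
        params, images, labels)
    # Plain SGD in place of TrainState.apply_gradients.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    # Metrics use the logits from the forward pass, i.e. before the update.
    metrics = compute_metrics(logits=logits, labels=labels)
    return params, metrics


# Usage on a random batch (illustrative data only).
k_img, k_lbl = jax.random.split(jax.random.PRNGKey(0))
images = jax.random.normal(k_img, (32, 28, 28, 1))
labels = jax.random.randint(k_lbl, (32,), 0, NUM_CLASSES)
params = {"w": jnp.zeros((28 * 28, NUM_CLASSES)), "b": jnp.zeros((NUM_CLASSES,))}

for step in range(5):
    params, metrics = train_step(params, (images, labels))
    print(step, float(metrics["loss"]), float(metrics["accuracy"]))
```

Because `compute_loss` returns `(loss, logits)`, `has_aux=True` makes `value_and_grad` return `((loss, logits), grads)`. The metrics can therefore reuse the forward-pass logits without calling the model a second time.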