

@mwitiderrick
Last active December 18, 2023 13:59
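import jax

def evaluate_model(state, batch):
    """Evaluate on the validation set."""
    # Unpack the validation batch into images and labels.
    test_imgs, test_lbls = batch
    # eval_step (defined elsewhere, not in this gist) returns a dict of metrics.
    metrics = eval_step(state, test_imgs, test_lbls)
    # Copy the metric arrays from the accelerator back to host memory.
    metrics = jax.device_get(metrics)
    # Turn each 0-d array into a plain Python scalar (jax.tree_map is deprecated).
    metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return metrics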
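The gist does not include `eval_step`, so here is a minimal sketch of one it could call, assuming a hypothetical toy linear classifier whose parameters live in a plain dict (standing in for whatever `state` object the author actually uses). It returns a dict of scalar arrays for loss and accuracy, which is the shape `evaluate_model` expects before it calls `jax.device_get` and `.item()`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def eval_step(params, imgs, lbls):
    """Hypothetical eval step for a toy linear classifier (not the gist author's model)."""
    # Toy model: logits = imgs @ w + b, with params stored in a plain dict.
    logits = imgs @ params["w"] + params["b"]
    # Mean cross-entropy between the logits and the integer labels.
    log_probs = jax.nn.log_softmax(logits)
    loss = -jnp.mean(jnp.take_along_axis(log_probs, lbls[:, None], axis=-1))
    # Fraction of examples whose highest-scoring class matches the label.
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == lbls)
    # Return a dict of 0-d arrays, the structure evaluate_model converts.
    return {"loss": loss, "accuracy": accuracy}


# Usage: all-zero weights make every class equally likely.
params = {"w": jnp.zeros((4, 3)), "b": jnp.zeros(3)}
imgs = jnp.ones((4, 4))
lbls = jnp.array([0, 1, 2, 0])
# Same two conversion steps that evaluate_model performs.
metrics = jax.device_get(eval_step(params, imgs, lbls))
metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
print(metrics)
```

With all-zero weights the loss is ln 3 ≈ 1.0986, and accuracy is 0.5 because `argmax` breaks the ties in favor of class 0. Passing `params` as `state` and `(imgs, lbls)` as `batch` to `evaluate_model` should give the same plain-float dict.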