@mwitiderrick
Created July 14, 2022 04:58
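import jax.numpy as jnp
import optax

# Predictions and targets. optax.huber_loss(predictions, targets, delta=1.0)
# returns the element-wise Huber loss; it does not reduce to a scalar.
logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])
optax.huber_loss(logits, labels)
# DeviceArray([0.045 , 0.045 , 0.17999998, 0.005 , 0.00125 ], dtype=float32)
# (recent JAX versions print Array(...) instead of DeviceArray(...))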
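The numbers above can be reproduced by hand. This is a minimal sketch, assuming the standard piecewise Huber definition (0.5 * r**2 when |r| <= delta, otherwise delta * (|r| - 0.5 * delta)) and optax's default delta=1.0. The `huber_loss` function below is a hypothetical re-implementation for illustration, not optax's own source.

```python
import jax.numpy as jnp


def huber_loss(predictions, targets, delta=1.0):
    """Hypothetical element-wise Huber loss following the standard piecewise definition."""
    # Absolute residual between each prediction and its target.
    abs_error = jnp.abs(predictions - targets)
    # Quadratic branch, used for small residuals: 0.5 * r^2.
    quadratic = 0.5 * abs_error ** 2
    # Linear branch, used for large residuals: delta * (|r| - 0.5 * delta).
    linear = delta * (abs_error - 0.5 * delta)
    return jnp.where(abs_error <= delta, quadratic, linear)


logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])

# Every residual is at most 0.6, so with delta=1.0 all elements use the quadratic branch.
per_element = huber_loss(logits, labels)
# With delta=0.5 the residual of 0.6 (third element) falls into the linear branch.
per_element_small_delta = huber_loss(logits, labels, delta=0.5)
# The loss is returned per element; reduce explicitly when a scalar training loss is needed.
mean_loss = per_element.mean()
```

With delta=1.0 each output is simply half the squared difference (for example 0.5 * 0.6**2 = 0.18 for the third element). Lowering delta to 0.5 switches that element to the linear branch, giving 0.5 * (0.6 - 0.25) = 0.175, which shows how delta caps the influence of large residuals.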