@mwitiderrick
Last active July 14, 2022 04:44
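```python
import jax

def custom_sigmoid_binary_cross_entropy(logits, labels):
    # log(sigmoid(x)), computed stably without forming sigmoid(x) first
    log_p = jax.nn.log_sigmoid(logits)
    # log(1 - sigmoid(x)) == log(sigmoid(-x)), also computed stably
    log_not_p = jax.nn.log_sigmoid(-logits)
    # Element-wise binary cross-entropy: -[y * log(p) + (1 - y) * log(1 - p)]
    return -labels * log_p - (1. - labels) * log_not_p
```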
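The function evaluates the standard binary cross-entropy directly from the logit $x$ and label $y$. It relies on the identity $1 - \sigma(x) = \sigma(-x)$, so both log terms can come from `log_sigmoid`:

$$
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad 1 - \sigma(x) = \frac{e^{-x}}{1 + e^{-x}} = \frac{1}{1 + e^{x}} = \sigma(-x)
$$

$$
\ell(x, y) = -y \log \sigma(x) - (1 - y) \log \sigma(-x)
$$

For the example call in this gist ($x = 0.5$, $y = 0$), only the second term survives:

$$
\ell(0.5, 0) = -\log \sigma(-0.5) = \log\left(1 + e^{0.5}\right) \approx \log(2.6487) \approx 0.9741
$$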
```python
custom_sigmoid_binary_cross_entropy(0.5, 0.0)
# DeviceArray(0.974077, dtype=float32, weak_type=True)
# (newer JAX releases print this as Array(0.974077, dtype=float32, weak_type=True))
```
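The function also works on whole arrays, and it is differentiable with `jax.grad`. The sketch below illustrates both properties. It compares the loss with a naive version that applies `sigmoid` before taking the log, and it checks that the gradient of the mean loss matches the closed form $(\sigma(x) - y)/N$. The helpers `naive_bce` and `mean_loss` and the batch values are hypothetical and exist only for illustration. They are not part of the original gist.

```python
import jax
import jax.numpy as jnp

def custom_sigmoid_binary_cross_entropy(logits, labels):
    # Same loss as in the gist, repeated so this block runs on its own
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    return -labels * log_p - (1. - labels) * log_not_p

def naive_bce(logits, labels):
    # Hypothetical comparison: computes sigmoid first, then takes logs.
    # For large |logits| in float32, sigmoid saturates to 0 or 1 and log(0) = -inf.
    p = jax.nn.sigmoid(logits)
    return -labels * jnp.log(p) - (1. - labels) * jnp.log(1. - p)

@jax.jit
def mean_loss(logits, labels):
    # Hypothetical helper: average the element-wise losses over a batch
    return jnp.mean(custom_sigmoid_binary_cross_entropy(logits, labels))

# Illustrative batch; the last entry is a confidently wrong prediction
logits = jnp.array([0.5, -2.0, 3.0, 100.0])
labels = jnp.array([0.0, 1.0, 1.0, 0.0])

per_example = custom_sigmoid_binary_cross_entropy(logits, labels)
naive = naive_bce(logits, labels)

# Gradient of the mean loss with respect to the logits
grads = jax.grad(mean_loss)(logits, labels)
# Closed-form gradient of binary cross-entropy w.r.t. logits, divided by batch size
expected_grads = (jax.nn.sigmoid(logits) - labels) / logits.shape[0]

print(per_example)
print(naive)
print(grads)
```

For the logit of 100 with label 0, the `log_sigmoid` formulation gives a finite loss of about 100, while the naive version evaluates `log(0)` and returns `inf`. This is presumably why the gist builds both terms from `jax.nn.log_sigmoid`.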