PyTorch equivalence for softmax_cross_entropy_with_logits
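tf.nn.softmax_cross_entropy_with_logits accepts soft labels (each row a probability distribution over classes) and returns one loss value per row: -Σⱼ labelsⱼ · log_softmax(logits)ⱼ. The helper below computes the same quantity in PyTorch. (Since PyTorch 1.10, F.cross_entropy also accepts probability targets directly, but the explicit form is useful for older versions and for clarity.)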
import torch
import torch.nn.functional as F
import tensorflow as tf


def softmax_cross_entropy_with_logits(labels, logits, dim=-1):
    """Cross-entropy between soft label distributions and logits,
    one loss per row, matching tf.nn.softmax_cross_entropy_with_logits."""
    return (-labels * F.log_softmax(logits, dim=dim)).sum(dim=dim)


# Soft labels: each row is a probability distribution over classes.
logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]

out_th = softmax_cross_entropy_with_logits(torch.tensor(labels), torch.tensor(logits))

########### Equivalent to ############
out_tf = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)

print(out_th)
print(out_tf)
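As a quick sanity check (a minimal sketch, assuming TensorFlow 2.x eager execution so .numpy() is available on out_tf), the two results can be compared numerically:

import numpy as np

# The per-row losses from PyTorch and TensorFlow should agree
# up to floating-point tolerance.
assert np.allclose(out_th.numpy(), out_tf.numpy(), atol=1e-6)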