@tejaskhot
Created July 14, 2018 22:51
PyTorch softmax cross entropy with logits
# PyTorch function to replicate TensorFlow's tf.nn.softmax_cross_entropy_with_logits
# works for soft targets or one-hot encodings; `target` must be a probability
# distribution over classes (rows summing to 1), same shape as `logits`
import torch
import torch.nn.functional as F

logits = model(input)  # unnormalized scores, shape (batch, num_classes)
# per-example cross entropy: -sum_c target_c * log softmax(logits)_c
loss = torch.sum(-target * F.log_softmax(logits, -1), -1)
mean_loss = loss.mean()
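
A minimal, self-contained check (model, input, and target above are placeholders, so this sketch substitutes random logits and one-hot targets): with one-hot targets, the expression agrees with PyTorch's built-in F.cross_entropy on integer class labels.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)                          # batch of 4 examples, 10 classes
labels = torch.randint(0, 10, (4,))                  # integer class labels
target = F.one_hot(labels, num_classes=10).float()   # one-hot "soft" targets

loss = torch.sum(-target * F.log_softmax(logits, -1), -1).mean()
reference = F.cross_entropy(logits, labels)          # built-in cross entropy on class indices
print(torch.allclose(loss, reference))               # True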
@LiuNull

LiuNull commented Apr 10, 2019

Excuse me, it seems you want to replicate TensorFlow's tf.nn.softmax_cross_entropy_with_logits, but you use F.log_softmax rather than F.softmax?

@zz-jacob

zz-jacob commented Jun 2, 2019

Excuse me, it seems you want to replicate TensorFlow's tf.nn.softmax_cross_entropy_with_logits, but you use F.log_softmax rather than F.softmax?

The log is part of the cross entropy itself: the loss is -sum(target * log(softmax(logits))), and F.log_softmax computes log(softmax(x)) directly in a numerically stable way.
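
To illustrate (a small sketch with extreme logits, not from the thread): log_softmax equals log(softmax(x)) mathematically, but the fused version avoids the underflow you get by taking the log of a plain softmax.

import torch
import torch.nn.functional as F

x = torch.tensor([1000.0, 0.0, -1000.0])
print(torch.log(F.softmax(x, -1)))  # [0., -inf, -inf]: small probabilities underflow to exact zeros
print(F.log_softmax(x, -1))         # [0., -1000., -2000.]: stable log-probabilities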

@samyak0210

Will this work if my logits shape is (32, 1, 128, 128) and my target has the same shape, with all values zero except a single position [x, y] set to 1.0?
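
One possible reading (a sketch, assuming the target is a one-hot map over the 128x128 spatial positions): the softmax and the sum then have to run over those positions, for example by flattening the spatial dimensions into a single class axis.

import torch
import torch.nn.functional as F

logits = torch.randn(32, 1, 128, 128)
target = torch.zeros(32, 1, 128, 128)
target[:, 0, 64, 64] = 1.0                     # one "hot" pixel per example

logits_flat = logits.view(32, -1)              # (32, 16384): spatial grid as the class axis
target_flat = target.view(32, -1)
loss = torch.sum(-target_flat * F.log_softmax(logits_flat, -1), -1).mean()
print(loss)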
