KL Divergence (PyTorch)
'''
As with NLLLoss, the input given is expected to contain log-probabilities…
The targets are given as probabilities (i.e. without taking the logarithm).

https://discuss.pytorch.org/t/kldivloss-returns-negative-value/62148
https://discuss.pytorch.org/t/kl-divergence-produces-negative-values/16791/16
https://discuss.pytorch.org/t/kullback-leibler-divergence-loss-function-giving-negative-values/763/16
'''
import torch
import torch.nn as nn

# Input to KLDivLoss must be log-probabilities; the target must be probabilities.
a = torch.log_softmax(torch.tensor([[0.8000, 0.1500, 0.0500]]), dim=1)
b = torch.softmax(torch.tensor([[0.8000, 0.1500, 0.0500]]), dim=1)

criterion = nn.KLDivLoss()
loss = criterion(a, b)
# a == log(b), i.e. both describe the same distribution, so the KL divergence is 0.
assert loss == 0.0
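# ---------------------------------------------------------------------------
# A minimal sketch of the failure mode discussed in the threads linked above:
# KLDivLoss can return a negative value when the input is plain probabilities
# (softmax) instead of log-probabilities (log_softmax). The logit values here
# are arbitrary, chosen only for illustration.
# ---------------------------------------------------------------------------
p_logits = torch.tensor([[0.8000, 0.1500, 0.0500]])
q_logits = torch.tensor([[0.1000, 0.6000, 0.3000]])

# reduction="batchmean" divides by the batch size and matches the mathematical
# definition of KL divergence.
criterion_bm = nn.KLDivLoss(reduction="batchmean")

# Correct usage: input = log-probabilities, target = probabilities -> result >= 0.
correct = criterion_bm(torch.log_softmax(q_logits, dim=1),
                       torch.softmax(p_logits, dim=1))

# Mistake: input = probabilities. The formula treats the input as if it were
# already a log, so the result can come out negative.
wrong = criterion_bm(torch.softmax(q_logits, dim=1),
                     torch.softmax(p_logits, dim=1))

print(correct.item())  # non-negative
print(wrong.item())    # negative for these values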