Understanding the difference between cross entropy and negative log-likelihood loss as implemented in PyTorch
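In short, PyTorch's CrossEntropyLoss is equivalent to LogSoftmax followed by NLLLoss, and BCEWithLogitsLoss is equivalent to Sigmoid followed by BCELoss; the script below verifies both equivalences numerically.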
import torch
import torch.nn.functional as F
torch.manual_seed(0)
# Binary setting ##############################################################
print(f"{'Setting up binary case':-^80}")
z = torch.randn(5)
yhat = torch.sigmoid(z)
y = torch.tensor([0.0, 1.0, 1.0, 0.0, 1.0])
print(f"{z = }")
print(f"{yhat = }")
print(f"{y = }")
print("-" * 80)
# First compute the negative log likelihoods using the derived formula
l = -(y * yhat.log() + (1 - y) * (1 - yhat).log())
print(f"{l = }")
# Observe that BCELoss and BCEWithLogitsLoss can produce the same results
l_BCELoss_nored = torch.nn.BCELoss(reduction="none")(yhat, y)
l_BCEWithLogitsLoss_nored = torch.nn.BCEWithLogitsLoss(reduction="none")(z, y)
print(f"{l_BCELoss_nored = }")
print(f"{l_BCEWithLogitsLoss_nored = }")
print("-" * 80)
# The default reduction is mean
l_mean = l.mean()
l_BCELoss = torch.nn.BCELoss()(yhat, y)
l_BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()(z, y)
print(f"{l_mean = }")
print(f"{l_BCELoss = }")
print(f"{l_BCEWithLogitsLoss = }")
print("-" * 80)
# Optionally, one can use equivalent functions from torch.nn.functional
print(f"{torch.nn.functional.binary_cross_entropy(yhat, y) = }")
print(f"{torch.nn.functional.binary_cross_entropy_with_logits(z, y) = }")
# Can recover BCELoss using NLLLoss
# Note that the first column holds the negative-class probability, so that the
# targets y=0 and y=1 index the negative and positive classes, respectively.
yhat_mat = torch.vstack((1 - yhat, yhat)).T
print(f"{torch.nn.functional.nll_loss(yhat_mat.log(), y.long()) = }")
print("=" * 80)
# Multiclass setting ##########################################################
print(f"{'Setting up multiclass case':-^80}")
z2 = torch.randn(5, 3)
yhat2 = torch.softmax(z2, dim=-1)
y2 = torch.tensor([0, 2, 1, 1, 0])
print(f"{z2 = }")
print(f"{yhat2 = }")
print(f"{y2 = }")
print("-" * 80)
# First compute the negative log likelihoods using the derived formula
l2 = -yhat2.log()[torch.arange(5), y2]  # pick out the log-probability of the true class for each sample
print(f"{l2 = }")
print(f"{-torch.log_softmax(z2, dim=-1)[torch.arange(5), y2] = }")
# Observe that NLLLoss and CrossEntropyLoss can produce the same results
l2_NLLLoss_nored = torch.nn.NLLLoss(reduction="none")(yhat2.log(), y2)
l2_CrossEntropyLoss_nored = torch.nn.CrossEntropyLoss(reduction="none")(z2, y2)
print(f"{l2_NLLLoss_nored = }")
print(f"{l2_CrossEntropyLoss_nored = }")
print("-" * 80)
# The default reduction is mean
l2_mean = l2.mean()
l2_NLLLoss = torch.nn.NLLLoss()(yhat2.log(), y2)
l2_CrossEntropyLoss = torch.nn.CrossEntropyLoss()(z2, y2)
print(f"{l2_mean = }")
print(f"{l2_NLLLoss = }")
print(f"{l2_CrossEntropyLoss = }")
print("-" * 80)
# Optionally, one can use equivalent functions from torch.nn.functional
print(f"{torch.nn.functional.nll_loss(yhat2.log(), y2) = }")
print(f"{torch.nn.functional.cross_entropy(z2, y2) = }")
print("=" * 80)