Skip to content

Instantly share code, notes, and snippets.

@jaymody
Last active December 9, 2022 22:22
Show Gist options
  • Save jaymody/04a98633d53a4d140e46060edfc22bf3 to your computer and use it in GitHub Desktop.
Save jaymody/04a98633d53a4d140e46060edfc22bf3 to your computer and use it in GitHub Desktop.
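Quick notebook showing that `nll_loss`, `cross_entropy`, and `kl_div` produce the same loss value for categorical data when the target is a single class (i.e. a one-hot target distribution, for which the target-entropy term of the KL divergence is zero).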
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def cross_entropy(logits, target):\n",
" return -logits.log_softmax(axis=-1)[target]\n",
"\n",
"def nll_loss(log_probs, target):\n",
" return -log_probs[target]\n",
"\n",
"def kl_div(log_probs, target_probs, eps=1e-9):\n",
" return (target_probs * ((target_probs+eps).log() - log_probs)).sum()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"logits = torch.randn(5)\n",
"log_probs = logits.log_softmax(axis=-1)\n",
"\n",
"target = torch.tensor(2)\n",
"target_one_hot = torch.tensor([0, 0, 1, 0, 0])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.6640)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_entropy(logits, target)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.6640)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.nn.functional.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.6640)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nll_loss(log_probs, target)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.6640)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.nn.functional.nll_loss(log_probs.unsqueeze(0), target.unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.6640)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kl_div(log_probs, target_one_hot)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.6640)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.nn.functional.kl_div(log_probs.unsqueeze(0), target_one_hot.unsqueeze(0), reduction=\"batchmean\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.10 64-bit ('3.9.10')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fd09b19eb83f586d348350b5c89c7a987a0d039b02a538583d56ff9c88f80cb0"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment