@jaymody
Last active December 11, 2022 18:39
Studying entropy based loss functions.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Entropy Based Loss Functions\n",
"The goal of this notebook is to provide an understanding of:\n",
"* `torch.nn.functional.cross_entropy`\n",
"* `torch.nn.functional.nll_loss`\n",
"* `torch.nn.functional.bce_loss`\n",
"* `torch.nn.functional.kl_div`\n",
"\n",
"References:\n",
"- Information Theory Stuff: https://www.youtube.com/watch?v=ErfnhcEV1O8\n",
"- PyTorch Docs for Loss Functions: https://pytorch.org/docs/stable/nn.html#loss-functions"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example Distributions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1500x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"distributions = {\n",
" \"normal_dist\": np.array([0.1, 0.2, 0.4, 0.2, 0.1]),\n",
" \"uniform_dist\": np.array([0.2, 0.2, 0.2, 0.2, 0.2]),\n",
" \"random_dist\": np.array([0.4, 0.1, 0.05, 0.20, 0.25]),\n",
"}\n",
"possible_events = list(range(5))\n",
"\n",
"fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))\n",
"for i, (name, dist) in enumerate(distributions.items()):\n",
" assert np.allclose(sum(dist), 1.0)\n",
" assert len(dist) == 5\n",
" axes[i].set_title(name)\n",
" axes[i].set_ylim(bottom=0, top=1)\n",
" axes[i].bar(x=possible_events, height=dist)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Information\n",
"Information can be interpreted as a measure for the rarity of an event (the reduction of uncertainty when that event is transmitted):\n",
"- High information for low probability events\n",
"- Low information for high probability events\n",
"$$\n",
"h(x) = -\\log P(x)\n",
"$$\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"normal_dist\n",
"h(0) = 2.3025850929940455\n",
"h(1) = 1.6094379124341003\n",
"h(2) = 0.916290731874155\n",
"h(3) = 1.6094379124341003\n",
"h(4) = 2.3025850929940455\n",
"\n",
"uniform_dist\n",
"h(0) = 1.6094379124341003\n",
"h(1) = 1.6094379124341003\n",
"h(2) = 1.6094379124341003\n",
"h(3) = 1.6094379124341003\n",
"h(4) = 1.6094379124341003\n",
"\n",
"random_dist\n",
"h(0) = 0.916290731874155\n",
"h(1) = 2.3025850929940455\n",
"h(2) = 2.995732273553991\n",
"h(3) = 1.6094379124341003\n",
"h(4) = 1.3862943611198906\n",
"\n"
]
}
],
"source": [
"h = lambda p, x: -np.log(p[x])\n",
"\n",
"for name, dist in distributions.items():\n",
" print(name)\n",
" for x in possible_events:\n",
" print(f\"h({x}) = {h(dist, x)}\")\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Entropy\n",
"Entropy can be interpreted as a measure of uncertainty of a distribution (the average amount of information from sampling the distribution once):\n",
"\n",
"- High entropy if there is lots of uncertainty (uniform distribution)\n",
"- Low entropy if there is little uncertainty (skewed distribution)\n",
"$$\n",
"H(P) = - \\sum P(x) * \\log(P(x))\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"normal_dist = 1.4708084763221112\n",
"uniform_dist = 1.6094379124341005\n",
"random_dist = 1.415022588493559\n"
]
}
],
"source": [
"H = lambda p: -sum([p_x * np.log(p_x)for p_x in p])\n",
"\n",
"for name, P in distributions.items():\n",
" print(f\"{name} = {H(P)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Cross Entropy\n",
"\n",
"From an information theory perspective, cross entropy can be interpreted as the average amount of information we get about the true distribution $P$ when sampling from an estimated distribution $Q$. It looks very similar to the entropy, but we just replace $\\log(P(x))$ with $\\log(Q(x))$:\n",
"$$\n",
"H(P, Q) = -\\sum P(x) * (\\log(Q(x)))\n",
"$$\n",
"\n",
"Notice if $P = Q$, then cross entropy is equal to entropy ($H(P, Q) = H(P)$). One useful property of cross entropy is that as $P$ and $Q$ diverge (become less alike), cross entropy grows ($H(P, Q) > H(P)$). As such, cross entropy can _sort of_ be interpreted as a measure of similarity between two distributions. This becomes more clear when it's written as:\n",
"$$\n",
"H(P, Q) = D_{KL}(P || Q) + H(P)\n",
"$$\n",
"\n",
"where $D_{KL}(P || Q)$ is the _relative entropy_ (also called Kullback–Leibler (KL) divergence) between $P$ and $Q$."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"crossH(normal_dist, normal_dist) = 1.4708084763221112\n",
"crossH(normal_dist, uniform_dist) = 1.6094379124341003\n",
"crossH(normal_dist, random_dist) = 2.21095601980663\n",
"crossH(uniform_dist, normal_dist) = 1.7480673485460896\n",
"crossH(uniform_dist, uniform_dist) = 1.6094379124341005\n",
"crossH(uniform_dist, random_dist) = 1.8420680743952367\n",
"crossH(random_dist, normal_dist) = 2.0253262207700673\n",
"crossH(random_dist, uniform_dist) = 1.6094379124341005\n",
"crossH(random_dist, random_dist) = 1.415022588493559\n"
]
}
],
"source": [
"crossH = lambda p, q: -sum([p_x * np.log(q_x) for p_x, q_x in zip(p, q)])\n",
"\n",
"import itertools\n",
"for (name1, p), (name2, q) in itertools.product(distributions.items(), repeat=2):\n",
" print(f\"crossH({name1}, {name2}) = {crossH(p, q)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### KL Divergence\n",
"KL Divergence can be interpreted as a measure of similarity between two distributions (how much extra information do I need to represent Q given P):\n",
"- High if the two distributions are dissimilar\n",
"- Low if the two distributions are similar\n",
"$$\n",
"D_{KL}(P || Q) = \\sum P(x) * (\\log(P(x)) - \\log(Q(x)))\n",
"$$\n",
"Sometimes also written as:\n",
"$$\n",
"\\sum P(x) * \\log(\\frac{P(x)}{Q(x)})\n",
"$$\n",
"or\n",
"$$\n",
"-\\sum P(x) * \\log(\\frac{Q(x)}{P(x)})\n",
"$$\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"KLD(normal_dist, normal_dist) = 0.0\n",
"KLD(normal_dist, uniform_dist) = 0.13862943611198902\n",
"KLD(normal_dist, random_dist) = 0.7401475434845188\n",
"KLD(uniform_dist, normal_dist) = 0.13862943611198905\n",
"KLD(uniform_dist, uniform_dist) = 0.0\n",
"KLD(uniform_dist, random_dist) = 0.23263016196113617\n",
"KLD(random_dist, normal_dist) = 0.6103036322765085\n",
"KLD(random_dist, uniform_dist) = 0.19441532394054145\n",
"KLD(random_dist, random_dist) = 0.0\n"
]
}
],
"source": [
"KLD = lambda p, q: sum([p_x * (np.log(p_x) - np.log(q_x)) for p_x, q_x in zip(p, q)])\n",
"\n",
"import itertools\n",
"for (name1, p), (name2, q) in itertools.product(distributions.items(), repeat=2):\n",
" print(f\"KLD({name1}, {name2}) = {KLD(p, q)}\")"
]
},
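{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how KL divergence connects back to cross entropy, we can expand the definition above. This is just a short worked derivation using the formulas for $H(P)$, $H(P, Q)$ and $D_{KL}(P || Q)$ already given in this notebook:\n",
"$$\n",
"\\begin{align}\n",
"D_{KL}(P || Q)\n",
"& = \\sum P(x) * (\\log(P(x)) - \\log(Q(x))) \\\\\n",
"& = \\sum P(x) * \\log(P(x)) - \\sum P(x) * \\log(Q(x)) \\\\\n",
"& = -H(P) + H(P, Q)\n",
"\\end{align}\n",
"$$\n",
"\n",
"Rearranging gives $H(P, Q) = D_{KL}(P || Q) + H(P)$. As a numeric check against the printed outputs: `crossH(normal_dist, uniform_dist)` $\\approx 1.6094$ and `H(normal_dist)` $\\approx 1.4708$, whose difference $\\approx 0.1386$ matches `KLD(normal_dist, uniform_dist)`."
]
},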
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loss Functions\n",
"$$\n",
"\\text{CCELoss}(\\bm{p}, y) = -\\log(\\bm{p}_y)\n",
"$$\n",
"where $\\bm{p}$ is the predicted probability distribution vector and $y$ is the index of the correct class.\n",
"\n",
"$$\n",
"\\text{BCELoss}(p, y) = -\\log(1 - p)(1 - y) - \\log(p)y\n",
"$$\n",
"where $p$ is the predicted probability and $y$ is the label, either 1 or 0.\n",
"\n",
"$$\n",
"\\text{KLDivLoss}(\\bm{p}, \\bm{y}) = \\sum_{i} \\bm{y}_i * (\\log(\\bm{y}_i) - \\log(\\bm{p}_i))\n",
"$$\n",
"where $\\bm{p}$ is the predicted probability distribution vector and $\\bm{y}$ is the true probability distribution vector"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def log(x, eps=1e-12):\n",
" # safe version of log where log(0) -> very large negative number rather than -inf,\n",
" # which causes numerical instability\n",
" return np.log(x + eps)\n",
"\n",
"def cce_loss(p, y):\n",
" return -log(p[y])\n",
"\n",
"def bce_loss(p, y):\n",
" return -log(1 - p)*(1 - y) - log(p)*y\n",
"\n",
"def kld_loss(p, y):\n",
" return np.sum(y * (log(y) - log(p)))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"p = 0.0\n",
"bce_loss(p, 0) = -1.000088900581841e-12\n",
"bce_loss(p, 1) = 27.631021115928547\n",
"\n",
"p = 0.1\n",
"bce_loss(p, 0) = 0.10536051565671518\n",
"bce_loss(p, 1) = 2.3025850929840455\n",
"\n",
"p = 0.2\n",
"bce_loss(p, 0) = 0.22314355131295974\n",
"bce_loss(p, 1) = 1.6094379124291003\n",
"\n",
"p = 0.3\n",
"bce_loss(p, 0) = 0.3566749439373039\n",
"bce_loss(p, 1) = 1.2039728043226028\n",
"\n",
"p = 0.4\n",
"bce_loss(p, 0) = 0.510825623764324\n",
"bce_loss(p, 1) = 0.9162907318716551\n",
"\n",
"p = 0.5\n",
"bce_loss(p, 0) = 0.6931471805579453\n",
"bce_loss(p, 1) = 0.6931471805579453\n",
"\n",
"p = 0.6\n",
"bce_loss(p, 0) = 0.9162907318716551\n",
"bce_loss(p, 1) = 0.510825623764324\n",
"\n",
"p = 0.7\n",
"bce_loss(p, 0) = 1.2039728043226026\n",
"bce_loss(p, 1) = 0.3566749439373039\n",
"\n",
"p = 0.8\n",
"bce_loss(p, 0) = 1.6094379124291005\n",
"bce_loss(p, 1) = 0.22314355131295974\n",
"\n",
"p = 0.9\n",
"bce_loss(p, 0) = 2.302585092984046\n",
"bce_loss(p, 1) = 0.10536051565671518\n",
"\n",
"p = 1.0\n",
"bce_loss(p, 0) = 27.631021115928547\n",
"bce_loss(p, 1) = -1.000088900581841e-12\n",
"\n"
]
}
],
"source": [
"for p in np.arange(11) / 10:\n",
" print(f\"p = {p}\")\n",
" print(\"bce_loss(p, 0) =\", bce_loss(p, 0))\n",
" print(\"bce_loss(p, 1) =\", bce_loss(p, 1))\n",
" print()"
]
},
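{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short derivation of why the KL divergence loss and the categorical cross entropy loss coincide for classification. Assume $\\bm{y}$ is a one hot vector with $\\bm{y}_k = 1$ for the correct class $k$ (the name $k$ is introduced here just for illustration) and $0$ elsewhere, and use the convention $0 * \\log(0) = 0$:\n",
"$$\n",
"\\begin{align}\n",
"\\text{KLDivLoss}(\\bm{p}, \\bm{y})\n",
"& = \\sum_{i} \\bm{y}_i * (\\log(\\bm{y}_i) - \\log(\\bm{p}_i)) \\\\\n",
"& = 1 * (\\log(1) - \\log(\\bm{p}_k)) + \\sum_{i \\neq k} 0 * (\\log(0) - \\log(\\bm{p}_i)) \\\\\n",
"& = -\\log(\\bm{p}_k) \\\\\n",
"& = \\text{CCELoss}(\\bm{p}, k)\n",
"\\end{align}\n",
"$$\n",
"\n",
"With the `eps` used in the `log` helper above, the two values differ only by a tiny amount (on the order of `eps`)."
]
},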
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--- normal_dist ---\n",
"target = 0\n",
"cce_loss = 2.3025850929840455\n",
"kld_loss = 2.3025850929850455\n",
"\n",
"target = 1\n",
"cce_loss = 1.6094379124291003\n",
"kld_loss = 1.6094379124301004\n",
"\n",
"target = 2\n",
"cce_loss = 0.9162907318716551\n",
"kld_loss = 0.9162907318726552\n",
"\n",
"target = 3\n",
"cce_loss = 1.6094379124291003\n",
"kld_loss = 1.6094379124301004\n",
"\n",
"target = 4\n",
"cce_loss = 2.3025850929840455\n",
"kld_loss = 2.3025850929850455\n",
"\n",
"\n",
"--- uniform_dist ---\n",
"target = 0\n",
"cce_loss = 1.6094379124291003\n",
"kld_loss = 1.6094379124301004\n",
"\n",
"target = 1\n",
"cce_loss = 1.6094379124291003\n",
"kld_loss = 1.6094379124301004\n",
"\n",
"target = 2\n",
"cce_loss = 1.6094379124291003\n",
"kld_loss = 1.6094379124301004\n",
"\n",
"target = 3\n",
"cce_loss = 1.6094379124291003\n",
"kld_loss = 1.6094379124301004\n",
"\n",
"target = 4\n",
"cce_loss = 1.6094379124291003\n",
"kld_loss = 1.6094379124301004\n",
"\n",
"\n",
"--- random_dist ---\n",
"target = 0\n",
"cce_loss = 0.9162907318716551\n",
"kld_loss = 0.9162907318726552\n",
"\n",
"target = 1\n",
"cce_loss = 2.3025850929840455\n",
"kld_loss = 2.3025850929850455\n",
"\n",
"target = 2\n",
"cce_loss = 2.995732273533991\n",
"kld_loss = 2.995732273534991\n",
"\n",
"target = 3\n",
"cce_loss = 1.6094379124291003\n",
"kld_loss = 1.6094379124301004\n",
"\n",
"target = 4\n",
"cce_loss = 1.3862943611158907\n",
"kld_loss = 1.3862943611168907\n",
"\n",
"\n"
]
}
],
"source": [
"# notice, cce_loss and kld_loss are interchangeable for the categorical case\n",
"for name, dist in distributions.items():\n",
" print(f\"--- {name} ---\")\n",
" for target in range(len(dist)):\n",
" target_one_hot = np.zeros(len(dist))\n",
" target_one_hot[target] = 1\n",
" print(f\"target = {target}\")\n",
" print(\n",
" f\"cce_loss = {cce_loss(dist, target)}\",\n",
" f\"kld_loss = {kld_loss(dist, target_one_hot)}\",\n",
" sep=\"\\n\",\n",
" )\n",
" print()\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### In Practice\n",
"> What do we do about $0 * \\log(0)$?\n",
"\n",
"We treat the output as 0, since $\\lim_{x \\rightarrow 0} x * \\log(x) = 0$\n",
"\n",
"> What do we do about $\\log(0)$?\n",
"\n",
"We replace $0$ with $\\epsilon$, which we can choose to set to a very small number, say $10^{-12}$.\n",
"\n",
"> How do you calculate the loss for the entire dataset, not just one example?\n",
"\n",
"We just take the mean of the per loss examples. For example, our equation for BCELoss becomes:\n",
"$$\n",
"\\text{BCELoss}(\\bm{p}, \\bm{y}) = \\frac{1}{N}\\sum_i^N-\\log(1 - \\bm{p}_i)(1 - \\bm{y}_i) - \\log(\\bm{p}_i)\\bm{y}_i\n",
"$$\n",
"\n",
"> How do we account for numerical instability and floating point errors?\n",
"\n",
"Neural networks output unnormalized probabilities (which are often called logits). To convert the logits to probabilities, we use the softmax function.\n",
"\n",
"$$\n",
"\\text{softmax}(\\bm{x})_i = \\frac{e^{\\bm{x}_i}}{\\sum_j e^{\\bm{x}_j}}\n",
"$$\n",
"\n",
"The above formulation of softmax is very numerical instable due to the value of the exponents being really large (and therefore may cause floating point errors). To preven this, we can take advantage of some properties of exponents:\n",
"\n",
"$$\n",
"\\begin{align}\n",
"\\text{softmax}(\\bm{x})_i\n",
"& = \\frac{e^{\\bm{x}_i}}{\\sum_j e^{\\bm{x}_j}} \\\\\n",
"& = \\frac{C}{C}\\frac{e^{\\bm{x}_i}}{\\sum_j e^{\\bm{x}_j}} \\\\\n",
"& = \\frac{Ce^{\\bm{x}_i}}{\\sum_j Ce^{\\bm{x}_j}} \\\\\n",
"& = \\frac{e^{\\bm{x}_i + \\log(C)}}{\\sum_j e^{\\bm{x}_j + \\log(C)}} \\\\\n",
"\\end{align}\n",
"$$\n",
"\n",
"If we set $\\log(C) = -\\max(\\bm{x})$, we can control our exponentiated values to be between 0 and 1, reducing numerical instability. However, even with this trick, taking the log after softmax is still quite numerically instable. However, if we combine the two together directly and do some manipulation:\n",
"\n",
"$$\n",
"\\begin{align}\n",
"\\text{log\\_softmax}(\\bm{x})_i\n",
"= & \\log(\\frac{e^{\\bm{x}_i + \\log(C)}}{\\sum_j e^{\\bm{x}_j + \\log(C)}}) \\\\\n",
"= & \\log(e^{\\bm{x}_i + \\log(C)}) - \\log(\\sum_j e^{\\bm{x}_j + \\log(C)}) \\\\\n",
"= & (\\bm{x}_i + \\log(C))\\log(e) - \\log(\\sum_j e^{\\bm{x}_j + \\log(C)}) \\\\\n",
"= & \\bm{x}_i + \\log(C) - \\log(\\sum_j e^{\\bm{x}_j + \\log(C)}) \\\\\n",
"= & \\bm{x}_i - \\max(\\bm{x}) - \\log(\\sum_j e^{\\bm{x}_j - \\max(\\bm{x})})\n",
"\\end{align}\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax_unstable(x):\n",
" return np.exp(x) / np.sum(np.exp(x))\n",
"\n",
"def softmax(x):\n",
" x_max = np.max(x)\n",
" return np.exp(x - x_max) / np.sum(np.exp(x - x_max))\n",
"\n",
"def log_softmax(x):\n",
" x_max = np.max(x)\n",
" return x - x_max - np.log(np.sum(np.exp(x - x_max)))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"softmax_unstable(x): [0.61489897 0.23445245 0.09552892 0.05511966]\n",
"\n",
"softmax(x): [0.61489897 0.23445245 0.09552892 0.05511966]\n",
"\n",
"log(softmax(x)): [-0.48629729 -1.45050249 -2.34832625 -2.89824887]\n",
"\n",
"log_softmax(x): [-0.48629729 -1.45050249 -2.34832625 -2.89824887]\n"
]
}
],
"source": [
"x = np.random.normal(size=(4))\n",
"print(\n",
" f\"softmax_unstable(x): {softmax_unstable(x)}\",\n",
" f\"softmax(x): {softmax(x)}\",\n",
" f\"log(softmax(x)): {np.log(softmax(x))}\",\n",
" f\"log_softmax(x): {log_softmax(x)}\",\n",
" sep=\"\\n\\n\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"softmax_unstable(x): [nan 0. 0. nan]\n",
"\n",
"softmax(x): [1. 0. 0. 0.]\n",
"\n",
"log(softmax(x)): [ 0. -inf -inf -inf]\n",
"\n",
"log_softmax(x): [ 0. -5537.00322633 -7592.30679862 -1077.17353227]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/51/h90b7_jx6sz6vphspgd1921h0000gp/T/ipykernel_68202/809860734.py:4: RuntimeWarning: overflow encountered in exp\n",
" return np.exp(x) / np.sum(np.exp(x))\n",
"/var/folders/51/h90b7_jx6sz6vphspgd1921h0000gp/T/ipykernel_68202/809860734.py:4: RuntimeWarning: invalid value encountered in divide\n",
" return np.exp(x) / np.sum(np.exp(x))\n",
"/var/folders/51/h90b7_jx6sz6vphspgd1921h0000gp/T/ipykernel_68202/993447327.py:5: RuntimeWarning: divide by zero encountered in log\n",
" f\"log(softmax(x)): {np.log(softmax(x))}\",\n"
]
}
],
"source": [
"x = np.random.normal(size=(4)) * 10000\n",
"print(\n",
" f\"softmax_unstable(x): {softmax_unstable(x)}\",\n",
" f\"softmax(x): {softmax(x)}\",\n",
" f\"log(softmax(x)): {np.log(softmax(x))}\",\n",
" f\"log_softmax(x): {log_softmax(x)}\",\n",
" sep=\"\\n\\n\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Torch Loss Functions"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.nn.functional import cross_entropy, kl_div, binary_cross_entropy, nll_loss"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def apply_over_batch_and_cast_to_numpy(f):\n",
" return lambda X, Y: sum(f(np.array(x), np.array(y)) for x, y in zip(X, Y)) / len(X)\n",
"\n",
"batch_cce_loss = apply_over_batch_and_cast_to_numpy(cce_loss)\n",
"batch_bce_loss = apply_over_batch_and_cast_to_numpy(bce_loss)\n",
"batch_kld_loss = apply_over_batch_and_cast_to_numpy(kld_loss)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Binary Classification"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch: 1.1347017288208008\n",
"numpy: 1.1347016079713899\n"
]
}
],
"source": [
"# comparing binary cross entropy\n",
"batch_size = 64\n",
"\n",
"p = torch.rand(size=(batch_size,))\n",
"y = torch.randint(low=0, high=2, size=(batch_size,)).float()\n",
"\n",
"print(\n",
" f\"torch: {binary_cross_entropy(p, y).item()}\",\n",
" f\"numpy: {batch_bce_loss(p, y)}\",\n",
" sep=\"\\n\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Multi-class Classification\n",
"Notice:\n",
"* All three of `nll_loss`, `cross_entropy` and `kl_div` have the same result.\n",
"* This is because `cross_entropy(p, q) = kl_div(p, q) + entropy(p)`, but for multi-class classification `entropy(p) = 0` (since `p` is a one-hot vector)\n",
"* `torch` does not allow you to compute cross entropy loss using probabilities, instead:\n",
" * Use `cross_entropy` on the logits (raw unnormalized output)\n",
" * Use `nll_loss` on the log probabilities (i.e. on the output of `log_softmax(logits)`)\n",
" * Use `kl_div` on the log probabilities (same as above), however, you'll need to convert `y` to a one hot encoded vector\n",
"* The reason torch doesn't let you use probabilities directly is to prevent `log(softmax(logits))` and instead use `log_softmax(logits)` which is a lot more stable."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--- torch functions ---\n",
"cross_entropy: 2.743520498275757\n",
"nll_loss: 2.743520498275757\n",
"kl_div: 2.743520498275757\n",
"--- numpy functions ---\n",
"batch_cce_loss: 2.74352032315047\n",
"batch_kld_loss: 2.743520326912403\n"
]
}
],
"source": [
"# comparing nll_loss, cross_entropy, kl_div for categorical data\n",
"batch_size = 64\n",
"n_classes = 10\n",
"\n",
"logits = torch.randn(size=(batch_size, n_classes))\n",
"p = torch.softmax(logits, axis=-1)\n",
"log_p = torch.log_softmax(logits, axis=-1)\n",
"\n",
"y = torch.randint(low=0, high=n_classes, size=(batch_size,))\n",
"y_one_hot = torch.eye(n_classes)[y]\n",
"\n",
"print(\n",
" \"--- torch functions ---\",\n",
" f\"cross_entropy: {cross_entropy(logits, y).item()}\",\n",
" f\"nll_loss: {nll_loss(log_p, y).item()}\",\n",
" f\"kl_div: {kl_div(log_p, y_one_hot, reduction='batchmean').item()}\",\n",
" \"--- numpy functions ---\",\n",
" f\"batch_cce_loss: {batch_cce_loss(p, y)}\",\n",
" f\"batch_kld_loss: {batch_kld_loss(p, y_one_hot)}\",\n",
" sep=\"\\n\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Fitting on Probability Distributions\n",
"Cross entropy loss only allows you to fit on probability distributions where all the \"mass\" is on one event (i.e. a one hot vector). If your network needs to learn _any_ distribution, then you'll need to use `kl_div` loss. In this case, the output of your network and the labels are both prob distributions."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch: 0.8245439529418945\n",
"numpy: 0.8245439636521041\n"
]
}
],
"source": [
"# comparing kl_div on distribution targets\n",
"batch_size = 64\n",
"n_classes = 10\n",
"\n",
"logits = torch.randn(size=(batch_size, n_classes))\n",
"p = torch.softmax(logits, axis=-1)\n",
"log_p = torch.log_softmax(logits, axis=-1)\n",
"\n",
"y = torch.softmax(torch.randn(size=(batch_size, n_classes)), axis=-1)\n",
"\n",
"print(\n",
" f\"torch: {kl_div(log_p, y, reduction='batchmean').item()}\",\n",
" f\"numpy: {batch_kld_loss(p, y)}\",\n",
" sep=\"\\n\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.10 64-bit ('3.9.10')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fd09b19eb83f586d348350b5c89c7a987a0d039b02a538583d56ff9c88f80cb0"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}