{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "This notebook breaks down how `cross_entropy` function (corresponding to `CrossEntropyLoss` used for classification) is implemented in pytorch, and how it is related to softmax, log_softmax, and nll (negative log-likelihood)."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F",
"execution_count": 82,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "batch_size, n_classes = 5, 3\nx = torch.randn(batch_size, n_classes)\nx.shape",
"execution_count": 83,
"outputs": [
{
"data": {
"text/plain": "torch.Size([5, 3])"
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "x",
"execution_count": 84,
"outputs": [
{
"data": {
"text/plain": "tensor([[ 0.9826, 1.0630, -0.4096],\n [-0.6213, 0.2511, 0.5659],\n [ 0.5662, 0.7360, -0.6783],\n [-0.4638, -1.4961, -1.0877],\n [ 1.8186, -0.2998, 0.1128]])"
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "target = torch.randint(n_classes, size=(batch_size,), dtype=torch.long)\ntarget",
"execution_count": 85,
"outputs": [
{
"data": {
"text/plain": "tensor([1, 0, 1, 1, 1])"
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### `softmax` + `nl` (negative likelihood)"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "This version is most similar to the math formula, but not numerically stable."
},
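{
"metadata": {},
"cell_type": "markdown",
"source": "For reference, the formula the next cell implements, for logits $x$ and integer targets $y$ over a batch of size $N$:\n\n$$\\text{loss} = -\\frac{1}{N}\\sum_{i=1}^{N} \\log \\frac{e^{x_{i,y_i}}}{\\sum_{j} e^{x_{i,j}}}$$"
},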
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def softmax(x): return x.exp() / (x.exp().sum(-1)).unsqueeze(-1)\ndef nl(input, target): return -input[range(target.shape[0]), target].log().mean()\n\npred = softmax(x)\nloss=nl(pred, target)\nloss",
"execution_count": 86,
"outputs": [
{
"data": {
"text/plain": "tensor(1.4904)"
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
]
},
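{
"metadata": {},
"cell_type": "markdown",
"source": "A quick sketch of the instability (the logits here are made up for illustration): with large logits, `x.exp()` overflows to `inf`, and `inf / inf` comes out as `nan`."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# exp(1000) overflows to inf, so the normalization computes inf / inf = nan.\nx_big = torch.tensor([[1000., -1000., 0.]])\nsoftmax(x_big)  # the first entry comes out nan",
"execution_count": null,
"outputs": []
},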
{
"metadata": {},
"cell_type": "markdown",
"source": "### `log_softmax` + `nll` (negative log-likelihood)"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "https://pytorch.org/docs/stable/nn.html?highlight=logsoftmax#torch-nn-functional\n>While mathematically equivalent to `log(softmax(x))`, doing these two operations separately is slower, and numerically unstable. This function uses an alternative formulation to compute the output and gradient correctly."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def log_softmax(x): return x - x.exp().sum(-1).log().unsqueeze(-1)\ndef nll(input, target): return -input[range(target.shape[0]), target].mean()\n\npred = log_softmax(x)\nloss = nll(pred, target)\nloss",
"execution_count": 88,
"outputs": [
{
"data": {
"text/plain": "tensor(1.4904)"
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
]
},
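{
"metadata": {},
"cell_type": "markdown",
"source": "One such alternative formulation is the standard log-sum-exp trick, sketched below (this is an illustration of the idea, not necessarily the exact kernel PyTorch uses): subtracting the per-row max before exponentiating leaves the result mathematically unchanged but keeps `exp` from overflowing."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def log_softmax_stable(x):\n    # After subtracting the row max, every value fed to exp() is <= 0,\n    # so exp() cannot overflow; the result is mathematically identical.\n    m = x.max(-1, keepdim=True)[0]\n    return x - m - (x - m).exp().sum(-1, keepdim=True).log()\n\nnll(log_softmax_stable(x), target)",
"execution_count": null,
"outputs": []
},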
{
"metadata": {},
"cell_type": "markdown",
"source": "### `F.log_softmax` + `F.nll_loss`"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The above but in pytorch."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "pred = F.log_softmax(x, dim=-1)\nloss = F.nll_loss(pred, target)\nloss",
"execution_count": 89,
"outputs": [
{
"data": {
"text/plain": "tensor(1.4904)"
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
]
},
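{
"metadata": {},
"cell_type": "markdown",
"source": "As an aside, `F.nll_loss` averages over the batch by default; passing `reduction='none'` (sketched below) returns the per-sample losses instead, whose mean is the scalar above."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# One negative log-likelihood per sample; the .mean() of these is the scalar loss.\nF.nll_loss(F.log_softmax(x, dim=-1), target, reduction='none')",
"execution_count": null,
"outputs": []
},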
{
"metadata": {},
"cell_type": "markdown",
"source": "### `F.cross_entropy`"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Pytorch's single cross_entropy function."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "F.cross_entropy(x, target)",
"execution_count": 90,
"outputs": [
{
"data": {
"text/plain": "tensor(1.4904)"
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
]
},
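{
"metadata": {},
"cell_type": "markdown",
"source": "As a sanity check, all the variants above compute the same quantity, as does the module form `nn.CrossEntropyLoss`:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# The manual version, the functional version, and the module should all agree.\nmanual = nl(softmax(x), target)\nbuiltin = F.cross_entropy(x, target)\nmodule = nn.CrossEntropyLoss()(x, target)\ntorch.allclose(manual, builtin), torch.allclose(builtin, module)",
"execution_count": null,
"outputs": []
},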
{
"metadata": {},
"cell_type": "markdown",
"source": "Reference:\n- https://github.com/fastai/fastai_old"
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/217dcc6ae9171d7a46ce42e215c1fee0"
},
"gist": {
"id": "217dcc6ae9171d7a46ce42e215c1fee0",
"data": {
"description": "Cross entropy implementation in pytorch",
"public": true
}
},
"kernelspec": {
"name": "conda-env-fastaiv1-py",
"display_name": "Python [conda env:fastaiv1]",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.7.0",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"base_numbering": 1,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}