@yang-zhang
Created October 16, 2018 20:46
binary cross entropy implementation in pytorch
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "This notebook breaks down how PyTorch's `binary_cross_entropy_with_logits` function (the functional form of `nn.BCEWithLogitsLoss`, commonly used for multilabel classification) works, and how it relates to applying `sigmoid` followed by `binary_cross_entropy`."
},
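{
"metadata": {},
"cell_type": "markdown",
"source": "The correspondence with `nn.BCEWithLogitsLoss` can be checked directly. A minimal sketch, assuming a reasonably recent PyTorch; `logits`, `targets` and `criterion` are illustrative names that are not part of the original notebook. `BCEWithLogitsLoss` expects float targets with the same shape as the logits, one 0/1 entry per label, which is what makes it suitable for multilabel problems. In a training loop the module is typically created once and then called on every batch."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n# illustrative data: 8 samples, 3 labels\nlogits = torch.randn(8, 3)                     # raw scores, one column per label\ntargets = torch.randint(0, 2, (8, 3)).float()  # a 0/1 float target for every label\n\ncriterion = nn.BCEWithLogitsLoss()  # module form, as typically used in a training loop\nmodule_loss = criterion(logits, targets)\nfunctional_loss = F.binary_cross_entropy_with_logits(logits, targets)\nprint(torch.allclose(module_loss, functional_loss))",
"execution_count": null,
"outputs": []
},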
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F",
"execution_count": 82,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "batch_size, n_classes = 10, 4\nx = torch.randn(batch_size, n_classes)\nx.shape",
"execution_count": 83,
"outputs": [
{
"data": {
"text/plain": "torch.Size([10, 4])"
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "x",
"execution_count": 84,
"outputs": [
{
"data": {
"text/plain": "tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],\n [ 0.0419, 0.0763, -1.0457, -1.6692],\n [-1.0494, 0.8111, 1.5723, 1.2315],\n [ 1.3081, 0.6641, 1.1802, -0.2547],\n [ 0.5292, 0.7636, 0.3692, -0.8318],\n [ 0.5100, 0.9849, -1.2905, 0.2821],\n [ 1.4662, 0.4550, 0.9875, 0.3143],\n [-1.2121, 0.1262, 0.0598, -1.6363],\n [ 0.3214, -0.8689, 0.0689, -2.5094],\n [ 1.1320, -0.6824, 0.1657, -0.0687]])"
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "target = torch.randint(n_classes, size=(batch_size,), dtype=torch.long)\ntarget",
"execution_count": 85,
"outputs": [
{
"data": {
"text/plain": "tensor([1, 1, 3, 0, 2, 0, 2, 2, 1, 2])"
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "y = torch.zeros(batch_size, n_classes)\n# one-hot encode the targets: a single 1 per row, at the target class\ny[range(y.shape[0]), target] = 1\ny",
"execution_count": 86,
"outputs": [
{
"data": {
"text/plain": "tensor([[0., 1., 0., 0.],\n [0., 1., 0., 0.],\n [0., 0., 0., 1.],\n [1., 0., 0., 0.],\n [0., 0., 1., 0.],\n [1., 0., 0., 0.],\n [0., 0., 1., 0.],\n [0., 0., 1., 0.],\n [0., 1., 0., 0.],\n [0., 0., 1., 0.]])"
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
]
},
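{
"metadata": {},
"cell_type": "markdown",
"source": "The indexing above builds one-hot rows, so every sample in this example has exactly one positive label. A sketch of two alternatives, assuming a PyTorch version recent enough to provide `F.one_hot`: the same kind of one-hot matrix via `F.one_hot`, and multi-hot rows for genuinely multilabel data, where a row may contain several 1s or none at all. `example_target` and `labels_per_row` are hypothetical example data."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nimport torch.nn.functional as F\n\nnum_classes = 4\nexample_target = torch.tensor([1, 1, 3, 0])  # hypothetical class indices\n# F.one_hot returns an int64 tensor; the loss functions need float targets\ny_onehot = F.one_hot(example_target, num_classes=num_classes).float()\n\n# hypothetical multilabel data: the positive label indices of each row\nlabels_per_row = [[0, 2], [], [1, 2, 3]]\ny_multi = torch.zeros(len(labels_per_row), num_classes)\nfor i, labels in enumerate(labels_per_row):\n    for j in labels:\n        y_multi[i, j] = 1.\nprint(y_onehot)\nprint(y_multi)",
"execution_count": null,
"outputs": []
},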
{
"metadata": {},
"cell_type": "markdown",
"source": "### `sigmoid` + `binary_cross_entropy` from scratch"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# sigmoid(x) = 1 / (1 + exp(-x))\ndef sigmoid(x): return (1 + (-x).exp()).reciprocal()\n# mean over all entries of -[y*log(pred) + (1-y)*log(1-pred)]\ndef binary_cross_entropy(pred, y): return -(pred.log()*y + (1-y)*(1-pred).log()).mean()\n\npred = sigmoid(x)\nloss = binary_cross_entropy(pred, y)\nloss",
"execution_count": 87,
"outputs": [
{
"data": {
"text/plain": "tensor(0.7739)"
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
]
},
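{
"metadata": {},
"cell_type": "markdown",
"source": "The from-scratch version works for moderate logits but breaks down for large ones. A minimal sketch with one hypothetical, confidently wrong prediction (logit 30, target 0): in float32, `sigmoid(30)` rounds to exactly 1.0, so `log(1 - pred)` becomes `log(0) = -inf` and the loss is `inf`. `F.binary_cross_entropy` does not return `inf` because, according to the `BCELoss` documentation, it clamps its log outputs to be greater than or equal to -100, so it reports a capped loss of 100. `F.binary_cross_entropy_with_logits` works from the logit itself and returns approximately 30, the actual loss `log(1 + exp(30))`. `big_x` and `big_y` are illustrative names."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nimport torch.nn.functional as F\n\ndef sigmoid(x): return (1 + (-x).exp()).reciprocal()\ndef binary_cross_entropy(pred, y): return -(pred.log()*y + (1-y)*(1-pred).log()).mean()\n\nbig_x = torch.tensor([30.])  # hypothetical large logit\nbig_y = torch.tensor([0.])   # but the target is 0\n\nnaive = binary_cross_entropy(sigmoid(big_x), big_y)            # sigmoid gives 1.0, so log(1 - 1.0) = -inf\nclamped = F.binary_cross_entropy(torch.sigmoid(big_x), big_y)  # log outputs clamped at -100\nfused = F.binary_cross_entropy_with_logits(big_x, big_y)       # computed from the logit directly\nprint(naive, clamped, fused)",
"execution_count": null,
"outputs": []
},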
{
"metadata": {},
"cell_type": "markdown",
"source": "### `torch.sigmoid` + `F.binary_cross_entropy`"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The same computation using PyTorch's built-in functions."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "pred = torch.sigmoid(x)\nloss = F.binary_cross_entropy(pred, y)\nloss",
"execution_count": 88,
"outputs": [
{
"data": {
"text/plain": "tensor(0.7739)"
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
]
},
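{
"metadata": {},
"cell_type": "markdown",
"source": "Both versions average the loss over every (sample, label) entry, so each label is scored as an independent binary problem. A short sketch on random, hypothetical `logits` and `targets`: passing `reduction='none'` returns one loss per entry, which shows that the two-step and the fused computations agree element by element, and the default `reduction='mean'` is just the mean of those entries."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nimport torch.nn.functional as F\n\nlogits = torch.randn(5, 3)\ntargets = torch.randint(0, 2, (5, 3)).float()\n\n# reduction='none' keeps one loss value per (sample, label) entry\nper_entry = F.binary_cross_entropy(torch.sigmoid(logits), targets, reduction='none')\nper_entry_fused = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')\nprint(per_entry.shape)  # same shape as the inputs\nprint(torch.allclose(per_entry, per_entry_fused, atol=1e-5))\n\n# the default reduction='mean' averages over all entries\nprint(torch.allclose(per_entry.mean(), F.binary_cross_entropy(torch.sigmoid(logits), targets)))",
"execution_count": null,
"outputs": []
},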
{
"metadata": {},
"cell_type": "markdown",
"source": "### `F.binary_cross_entropy_with_logits`"
},
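{
"metadata": {},
"cell_type": "markdown",
"source": "PyTorch's fused `binary_cross_entropy_with_logits` takes the raw logits `x` directly and combines the sigmoid and the binary cross entropy into one function, which is more numerically stable than applying them separately. It returns the same loss here."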
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "F.binary_cross_entropy_with_logits(x, y)",
"execution_count": 89,
"outputs": [
{
"data": {
"text/plain": "tensor(0.7739)"
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
]
},
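{
"metadata": {},
"cell_type": "markdown",
"source": "Why the fused version is more robust: writing `p = sigmoid(x)`, the per-entry loss `-[y*log(p) + (1-y)*log(1-p)]` simplifies algebraically to `(1-y)*x + log(1 + exp(-x))`. Rewriting `x + log(1 + exp(-x))` as `max(x, 0) + log(1 + exp(-|x|))` gives `max(x, 0) - x*y + log(1 + exp(-|x|))`, which never exponentiates a large positive number. The function `bce_with_logits_stable` below is a hypothetical re-implementation of that formula, checked against PyTorch on deliberately large logits; it illustrates the log-sum-exp idea and is not necessarily how PyTorch's own kernel arranges the arithmetic."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nimport torch.nn.functional as F\n\ndef bce_with_logits_stable(x, y):\n    # max(x, 0) - x*y + log(1 + exp(-|x|)), averaged over all entries;\n    # exp is only applied to non-positive numbers, so it cannot overflow\n    return (x.clamp(min=0) - x * y + torch.log1p(torch.exp(-x.abs()))).mean()\n\nlogits = torch.randn(10, 4) * 50  # deliberately large logits\ntargets = torch.randint(0, 2, (10, 4)).float()\n\nours = bce_with_logits_stable(logits, targets)\nref = F.binary_cross_entropy_with_logits(logits, targets)\nprint(ours, ref, torch.allclose(ours, ref, atol=1e-5))",
"execution_count": null,
"outputs": []
},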
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/e7f3ef44c16c3cdef2cb59c008d3e86c"
},
"gist": {
"id": "e7f3ef44c16c3cdef2cb59c008d3e86c",
"data": {
"description": "binary cross entropy implementation in pytorch",
"public": true
}
},
"kernelspec": {
"name": "conda-env-fastaiv1-py",
"display_name": "Python [conda env:fastaiv1]",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.7.0",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"base_numbering": 1,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}