Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save myurasov/d96f932091ac0660982356770fd81ae6 to your computer and use it in GitHub Desktop.
Save myurasov/d96f932091ac0660982356770fd81ae6 to your computer and use it in GitHub Desktop.
Training with categorical_crossentropy vs sparse_categorical_crossentropy
Display the source blob
Display the rendered blob
Raw
{
"cells": [
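{
"metadata": {},
"cell_type": "markdown",
"source": "### Training with `categorical_crossentropy` vs `sparse_categorical_crossentropy`\n\nA two-output Keras model is trained on a single batch twice: first with `categorical_crossentropy`, whose targets carry one value per class (shape `(batch, 10)` here), then with `sparse_categorical_crossentropy`, whose targets are integer class indices (shape `(batch,)`)."
},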
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import numpy as np\nfrom keras.layers import Input, Dense, Flatten\nfrom keras.models import Model",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": "Using TensorFlow backend.\n",
"name": "stderr"
}
]
},
{
"metadata": {
"trusted": true,
"collapsed": true
},
"cell_type": "code",
"source": "# One 28x28x1 image feeds two heads that share the flattened features.\ni = Input(shape=(28, 28, 1), name='input_image')\nx = Flatten()(i)\n\n# Binary head: a single sigmoid unit.\no1 = Dense(1, activation='sigmoid', name='output_1')(x)\n\n# 10-class head: softmax, so the ten outputs form a probability distribution,\n# which is what the (sparse_)categorical_crossentropy losses expect.\no2 = Dense(10, activation='softmax', name='output_2')(x)",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true,
"collapsed": true
},
"cell_type": "code",
"source": "# Wire the shared input to both heads; outputs are ordered [output_1, output_2].\nm = Model(\n    inputs=[i],\n    outputs=[o1, o2]\n)",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "def model_as_svg(model):\n    \"\"\"Return an inline SVG drawing of `model`'s layer graph, including tensor shapes.\n\n    The graph is laid out with the `dot` program and wrapped in an IPython\n    `SVG` object so the notebook displays it as an image.\n    \"\"\"\n    # Imported here because only this helper needs the visualisation stack.\n    from IPython.display import SVG\n    from keras.utils.vis_utils import model_to_dot\n\n    graph = model_to_dot(model, show_shapes=True)\n    svg_markup = graph.create(prog='dot', format='svg')\n    return SVG(svg_markup)",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "model_as_svg(m)",
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 5,
"data": {
"text/plain": "<IPython.core.display.SVG object>",
"image/svg+xml": "<svg height=\"221pt\" viewBox=\"0.00 0.00 514.00 221.00\" width=\"514pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 217)\">\n<title>G</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-217 510,-217 510,4 -4,4\" stroke=\"none\"/>\n<!-- 139690491401608 -->\n<g class=\"node\" id=\"node1\"><title>139690491401608</title>\n<polygon fill=\"none\" points=\"94,-166.5 94,-212.5 412,-212.5 412,-166.5 94,-166.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"170\" y=\"-185.8\">input_image: InputLayer</text>\n<polyline fill=\"none\" points=\"246,-166.5 246,-212.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"273.5\" y=\"-197.3\">input:</text>\n<polyline fill=\"none\" points=\"246,-189.5 301,-189.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"273.5\" y=\"-174.3\">output:</text>\n<polyline fill=\"none\" points=\"301,-166.5 301,-212.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"356.5\" y=\"-197.3\">(None, 28, 28, 1)</text>\n<polyline fill=\"none\" points=\"301,-189.5 412,-189.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"356.5\" y=\"-174.3\">(None, 28, 28, 1)</text>\n</g>\n<!-- 139690491401552 -->\n<g class=\"node\" id=\"node2\"><title>139690491401552</title>\n<polygon fill=\"none\" points=\"115,-83.5 115,-129.5 391,-129.5 391,-83.5 115,-83.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"170\" y=\"-102.8\">flatten_1: Flatten</text>\n<polyline fill=\"none\" points=\"225,-83.5 225,-129.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"252.5\" 
y=\"-114.3\">input:</text>\n<polyline fill=\"none\" points=\"225,-106.5 280,-106.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"252.5\" y=\"-91.3\">output:</text>\n<polyline fill=\"none\" points=\"280,-83.5 280,-129.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"335.5\" y=\"-114.3\">(None, 28, 28, 1)</text>\n<polyline fill=\"none\" points=\"280,-106.5 391,-106.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"335.5\" y=\"-91.3\">(None, 784)</text>\n</g>\n<!-- 139690491401608&#45;&gt;139690491401552 -->\n<g class=\"edge\" id=\"edge1\"><title>139690491401608-&gt;139690491401552</title>\n<path d=\"M253,-166.366C253,-158.152 253,-148.658 253,-139.725\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"256.5,-139.607 253,-129.607 249.5,-139.607 256.5,-139.607\" stroke=\"black\"/>\n</g>\n<!-- 139690491401664 -->\n<g class=\"node\" id=\"node3\"><title>139690491401664</title>\n<polygon fill=\"none\" points=\"0,-0.5 0,-46.5 244,-46.5 244,-0.5 0,-0.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"53\" y=\"-19.8\">output_1: Dense</text>\n<polyline fill=\"none\" points=\"106,-0.5 106,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"133.5\" y=\"-31.3\">input:</text>\n<polyline fill=\"none\" points=\"106,-23.5 161,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"133.5\" y=\"-8.3\">output:</text>\n<polyline fill=\"none\" points=\"161,-0.5 161,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"202.5\" y=\"-31.3\">(None, 784)</text>\n<polyline fill=\"none\" points=\"161,-23.5 244,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" 
text-anchor=\"middle\" x=\"202.5\" y=\"-8.3\">(None, 1)</text>\n</g>\n<!-- 139690491401552&#45;&gt;139690491401664 -->\n<g class=\"edge\" id=\"edge2\"><title>139690491401552-&gt;139690491401664</title>\n<path d=\"M217.204,-83.3664C201.608,-73.723 183.161,-62.3171 166.678,-52.1252\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"168.099,-48.8891 157.753,-46.6068 164.418,-54.8429 168.099,-48.8891\" stroke=\"black\"/>\n</g>\n<!-- 139690491402392 -->\n<g class=\"node\" id=\"node4\"><title>139690491402392</title>\n<polygon fill=\"none\" points=\"262,-0.5 262,-46.5 506,-46.5 506,-0.5 262,-0.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"315\" y=\"-19.8\">output_2: Dense</text>\n<polyline fill=\"none\" points=\"368,-0.5 368,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"395.5\" y=\"-31.3\">input:</text>\n<polyline fill=\"none\" points=\"368,-23.5 423,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"395.5\" y=\"-8.3\">output:</text>\n<polyline fill=\"none\" points=\"423,-0.5 423,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"464.5\" y=\"-31.3\">(None, 784)</text>\n<polyline fill=\"none\" points=\"423,-23.5 506,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"464.5\" y=\"-8.3\">(None, 10)</text>\n</g>\n<!-- 139690491401552&#45;&gt;139690491402392 -->\n<g class=\"edge\" id=\"edge3\"><title>139690491401552-&gt;139690491402392</title>\n<path d=\"M288.796,-83.3664C304.392,-73.723 322.839,-62.3171 339.322,-52.1252\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"341.582,-54.8429 348.247,-46.6068 337.901,-48.8891 341.582,-54.8429\" stroke=\"black\"/>\n</g>\n</g>\n</svg>"
},
"metadata": {}
}
]
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "# Losses are paired with the model outputs by position.\nm.compile(\n    optimizer='sgd',\n    loss=[\n        'binary_crossentropy',       # output_1\n        'categorical_crossentropy'   # output_2\n    ]\n)",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "# Dummy batch holding one 28x28x1 image with values in [0, 1).\n# A locally seeded generator makes this input identical on every run without\n# touching the global NumPy RNG (layer weights are still randomly initialised).\nimg = np.random.RandomState(0).random_sample(size=(1, 28, 28, 1))",
"execution_count": 7,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Forward pass: one array per head, with shapes (1, 1) and (1, 10).\np = m.predict(x=img)\np",
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 8,
"data": {
"text/plain": "[array([[ 0.52987033]], dtype=float32),\n array([[ 0.3056376 , 0.38724217, 0.55166078, 0.33891007, 0.2556529 ,\n 0.2712743 , 0.67876822, 0.61177981, 0.76250398, 0.4633559 ]], dtype=float32)]"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# The model's own predictions double as shape-compatible targets:\n# p[0] is (1, 1) for output_1 and p[1] is (1, 10) for output_2.\n# The result is [total loss, output_1 loss, output_2 loss].\nr = m.train_on_batch(x=img, y=p)\nr",
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 9,
"data": {
"text/plain": "[11.032898, 0.69136167, 10.341537]"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true,
"collapsed": true
},
"cell_type": "code",
"source": "# Expect three scalars back: the total loss plus one loss per head.\nassert len(r) == 3",
"execution_count": 10,
"outputs": []
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "# Recompile with the sparse loss on output_2; output_1 keeps binary_crossentropy.\nm.compile(\n    optimizer='sgd',\n    loss=[\n        'binary_crossentropy',              # output_1\n        'sparse_categorical_crossentropy'   # output_2\n    ]\n)",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# The sparse loss takes integer class indices, so collapse p[1] from (1, 10)\n# to a (1,) vector holding the arg-max class of each sample.\nr = m.train_on_batch(\n    x=img,\n    y=[p[0], np.argmax(p[1], axis=1)]\n)\nr",
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 12,
"data": {
"text/plain": "[2.4943717, 0.69136167, 1.80301]"
},
"metadata": {}
}
]
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "# Same result layout as before: total loss plus one loss per head.\nassert len(r) == 3",
"execution_count": 13,
"outputs": []
},
{
"metadata": {
"collapsed": true,
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"file_extension": ".py",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"name": "python",
"version": "3.5.2",
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python"
},
"gist": {
"id": "4cb605271fcb16195189305888ee3311",
"data": {
"description": "Training with categorical_crossentropy vs sparse_categorical_crossentropy",
"public": true
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/4cb605271fcb16195189305888ee3311"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment