Skip to content

Instantly share code, notes, and snippets.

@manuelmazzuola
Created March 8, 2021 15:59
Show Gist options
  • Save manuelmazzuola/d8dc5c346b790aaab915be6cd10057ef to your computer and use it in GitHub Desktop.
Save manuelmazzuola/d8dc5c346b790aaab915be6cd10057ef to your computer and use it in GitHub Desktop.
softmax-with-temperature.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "softmax-with-temperature.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyPGDSyZmSAEY/pciQMg3Vwp",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/manuelmazzuola/d8dc5c346b790aaab915be6cd10057ef/softmax-with-temperature.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "v5kt4hbKb80o"
},
"source": [
"import torch"
],
"execution_count": 76,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "aee3_Y50a0-j"
},
"source": [
"def softmax(input, t=1.0):\n",
" print(\"input\", input)\n",
" ex = torch.exp(input/t)\n",
" print(\"exp\", ex)\n",
" sum = torch.sum(ex, axis=0)\n",
" return ex / sum\n",
"\n",
"def cross_entropy(distribution):\n",
" target = torch.tensor([0, 0, 1, 0, 0])\n",
" print(\"loss\", -torch.sum(target * torch.log(distribution)))\n"
],
"execution_count": 122,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rxBP12EAbhkk",
"outputId": "bce850c7-7d88-477e-9c85-b8817fe3aac8"
},
"source": [
"input = torch.tensor([55.8906, -114.5621, 6.3440, -30.2473, -44.1440])\n",
"cross_entropy(softmax(input))"
],
"execution_count": 123,
"outputs": [
{
"output_type": "stream",
"text": [
"input tensor([ 55.8906, -114.5621, 6.3440, -30.2473, -44.1440])\n",
"exp tensor([1.8749e+24, 0.0000e+00, 5.6907e+02, 7.3074e-14, 6.7376e-20])\n",
"loss tensor(nan)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8K3ZnHFRe_JP",
"outputId": "f74b9559-5cd5-4639-82ea-ed89471d48ad"
},
"source": [
"input = torch.tensor([55.8906, -114.5621, 6.3440, -30.2473, -44.1440])\n",
"cross_entropy(softmax(input, t=10))"
],
"execution_count": 124,
"outputs": [
{
"output_type": "stream",
"text": [
"input tensor([ 55.8906, -114.5621, 6.3440, -30.2473, -44.1440])\n",
"exp tensor([2.6748e+02, 1.0584e-05, 1.8859e+00, 4.8571e-02, 1.2102e-02])\n",
"loss tensor(4.9619)\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment