@jerinphilip
Last active July 4, 2018 12:06
Temperature and its effects.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"import torch\n",
"from torch import nn, optim\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Activations\n",
"\n",
"Our activations typically include both positive and negative values; we generate them here with `torch.randn`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([-1.0198, -0.0233, -0.6869, 1.4886, 0.6051, -0.2313, 0.7285,\n",
" 0.7995, -0.7978, 0.5780])\n"
]
}
],
"source": [
"n = 10\n",
"acts = torch.randn(n)\n",
"print(acts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Probabilities\n",
"\n",
"I generate probabilities by applying softmax to the activations."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"probs = F.softmax(acts, dim=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helpers\n",
"\n",
"To see how the distribution changes, I plot it as a bar chart with matplotlib using a small helper."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def plot_probs_bar(vals):\n",
"    # Plot a 1-D tensor as a bar chart, normalised so the bars sum to 1\n",
"    H, = vals.size()\n",
"    xs = list(range(H))\n",
"    vals = vals / vals.sum()\n",
"    ys = vals.tolist()\n",
"    plt.ylim(0, 1)\n",
"    plt.bar(xs, ys)\n",
"\n",
"plot_probs_bar(probs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Draw\n",
"\n",
"We draw index `i` with probability `probs[i]`, rather than always taking the argmax.\n",
"\n",
"The draw uses inverse-CDF sampling: sample `u` uniformly from $[0, 1)$, accumulate a running sum of `probs` (the CDF), and return the first index whose cumulative probability exceeds `u`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
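"import random\n",
"\n",
"def draw(probs):\n",
"    # Inverse-CDF sampling: return index i with probability probs[i]\n",
"    val = random.random()  # uniform in [0, 1)\n",
"    csum = 0.0\n",
"    for i, p in enumerate(probs.tolist()):\n",
"        csum += p\n",
"        if csum > val:\n",
"            return i\n",
"    # Rounding can leave the total slightly below 1.0; fall back to the last index\n",
"    return len(probs) - 1"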
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Verifying\n",
"\n",
"I make 10000 random draws from `probs` and count how often each index comes up. If the sampler works, the resulting bar chart should look similar to the one above.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def check_sampling(probs, max_samples=10000):\n",
"    # Count how often draw() returns each index\n",
"    counts = torch.zeros(len(probs))\n",
"    for _ in range(max_samples):\n",
"        counts[draw(probs)] += 1\n",
"\n",
"    # Indexing by position keeps the bars in index order and keeps\n",
"    # never-drawn indices as zero (dict insertion order would not)\n",
"    plot_probs_bar(counts)\n",
"\n",
"check_sampling(probs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
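"## Temperature\n",
"\n",
"Temperature $T$ divides the activations before the softmax: $p_i = \\frac{\\exp(a_i / T)}{\\sum_j \\exp(a_j / T)}$. With $T = 1$ this is the plain softmax."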
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class TSoftmax(nn.Module):\n",
"    def __init__(self, temperature):\n",
"        super().__init__()\n",
"        self.T = temperature\n",
"\n",
"    def forward(self, tensor):\n",
"        # Flatten, divide every activation by T, then normalise with softmax.\n",
"        # F.softmax is numerically stable, so a small T does not overflow exp().\n",
"        return F.softmax(tensor.view(-1) / self.T, dim=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### What temperature does to probabilities\n",
"We try several temperatures, from $1/8$ to $8$, and see how the probabilities change."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x2160 with 7 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"m = 4\n",
"less_than_one = [2**(-b) for b in range(1, m)]\n",
"greater_than_one = [2**(b) for b in range(1, m)]\n",
"\n",
"tvals = list(reversed(less_than_one)) + [1] + greater_than_one\n",
"n_tvals = len(tvals)\n",
"\n",
"plt.figure(figsize=(10, 30))\n",
"\n",
"for i, T in enumerate(tvals):\n",
"    ax1 = plt.subplot(n_tvals, 1, i + 1)\n",
"    ax1.set_title(\"t = {}\".format(T))\n",
"    transform = TSoftmax(temperature=T)\n",
"    tprobs = transform(acts)\n",
"    plot_probs_bar(tprobs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To summarize, as $T \\rightarrow \\infty$ the probabilities approach a uniform distribution: every index becomes equally likely. That is not good; a char-RNN sampled this way would predict garbage.\n",
"\n",
"What we want is the opposite: at each step, choose one of the more likely characters. The sampling adds the randomness, while a temperature $T < 1$ sharpens the distribution so that unlikely characters are rarely chosen (as $T \\rightarrow 0$, sampling approaches argmax). Hence we go for $T < 1$."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
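For readers without PyTorch at hand, here is a minimal, dependency-free sketch of the two ideas in this notebook: dividing the activations by a temperature T before the softmax (what `TSoftmax` does) and inverse-CDF sampling (what `draw` does), plus an empirical check in the spirit of `check_sampling`. It is plain Python rather than the notebook's PyTorch code. The function names `softmax_with_temperature`, `sample_index` and `empirical_distribution` are hypothetical, and subtracting the maximum before `exp` is an added assumption to keep small temperatures from overflowing. The activation values are the ones printed by the notebook's `torch.randn` cell.

```python
import math
import random
from collections import Counter


def softmax_with_temperature(logits, T=1.0):
    # Hypothetical helper: softmax of logits / T.
    # Subtracting the max before exp() avoids overflow when T is small.
    if T <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / T for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(probs, rng=random):
    # Inverse-CDF sampling: draw u ~ U[0, 1) and return the first index
    # whose running (cumulative) probability exceeds u.
    u = rng.random()
    csum = 0.0
    for i, p in enumerate(probs):
        csum += p
        if csum > u:
            return i
    # Rounding can leave csum slightly below 1.0; fall back to the last index.
    return len(probs) - 1


def empirical_distribution(probs, n_samples=10000, seed=0):
    # Fraction of n_samples draws that landed on each index.
    rng = random.Random(seed)
    counts = Counter(sample_index(probs, rng) for _ in range(n_samples))
    return [counts[i] / n_samples for i in range(len(probs))]


if __name__ == "__main__":
    # Activations printed by the notebook's torch.randn(10) cell.
    acts = [-1.0198, -0.0233, -0.6869, 1.4886, 0.6051,
            -0.2313, 0.7285, 0.7995, -0.7978, 0.5780]
    for T in (0.125, 1.0, 8.0):
        probs = softmax_with_temperature(acts, T)
        emp = empirical_distribution(probs)
        print("T = {}: max p = {:.3f}".format(T, max(probs)))
        print("  softmax:  ", [round(p, 3) for p in probs])
        print("  empirical:", [round(q, 3) for q in emp])
```

Running it should show the largest probability shrinking as T grows, and the empirical frequencies tracking the softmax probabilities at every temperature. The fixed seed in `empirical_distribution` only makes the check reproducible; it is not needed for sampling itself.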
@Deepayan137

Great work!! You can write a blog on this :D
