Skip to content

Instantly share code, notes, and snippets.

@Mr4k
Created September 23, 2020 05:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Mr4k/fbb096baf20354b3fdcbd082a00e20d6 to your computer and use it in GitHub Desktop.
Save Mr4k/fbb096baf20354b3fdcbd082a00e20d6 to your computer and use it in GitHub Desktop.
MonteCarloGumbelSoftmax.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "MonteCarloGumbelSoftmax.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPN4EtCkSU8VC3MZppjVc0g",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Mr4k/fbb096baf20354b3fdcbd082a00e20d6/montecarlogumbelsoftmax.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "aNXUXI2hFQUN",
"colab_type": "code",
"colab": {}
},
"source": [
"import jax\n",
"import jax.numpy as np\n",
"# current convention is to import original numpy as \"onp\"\n",
"import numpy as onp\n",
"from jax import grad, jit, vmap, random\n",
"from jax.ops import index, index_add, index_update\n",
"import jax.experimental.optimizers as optimizers"
],
"execution_count": 134,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-H0zehkiLe9F",
"colab_type": "code",
"colab": {}
},
"source": [
"# Gumbel Softmax Trick\n",
"# this code is intended to be clear rather than fast\n",
"@jit\n",
"def sample_from_gumbel_softmax(key, catagorical_probs, temp):\n",
" sample = jax.random.gumbel(key, catagorical_probs.shape)\n",
" return jax.nn.softmax((np.log(catagorical_probs) + sample) / temp)"
],
"execution_count": 135,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sQwnSxdnO95c",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"outputId": "8e138843-db96-4569-f2fe-67c67868a7d8"
},
"source": [
"# keep the seed reproduceable\n",
"RNG_KEY = random.PRNGKey(0)\n",
"onp.random.seed(3)\n",
"\n",
"# define catagorical distribution\n",
"catagorical_probs = onp.random.uniform(size=4)\n",
"catagorical_probs /= catagorical_probs.sum()\n",
"\n",
"# define weights for catagorical distribution\n",
"weights = onp.random.uniform(size=4)\n",
"\n",
"# f is a really simple sum of weights in a catagorical distribution\n",
"# note that we must normalize the catagorical distrubution to get the correct derivative through backprop\n",
"def expected_f(catagorical_probs):\n",
" return weights.T @ catagorical_probs / catagorical_probs.sum()\n",
"\n",
"samples = 10000\n",
"random_keys = random.split(RNG_KEY, samples)\n",
"\n",
"def sample_from_f(key, catagorical_probs):\n",
" return weights.T @ sample_from_gumbel_softmax(key, catagorical_probs, 0.1)\n",
"\n",
"# we can also calculate f with a Monte Carlo Simulation to check our Gumbel Softmax Implementation\n",
"def monte_carlo_f(catagorical_probs, random_keys):\n",
" sampler = vmap(lambda key: sample_from_f(key, catagorical_probs))\n",
" samples = sampler(random_keys)\n",
" num_samples = random_keys.shape[0]\n",
" return samples.sum()/num_samples\n",
"\n",
"print(\"Monte Carlo Approximation of E_x[f]:\", monte_carlo_f(catagorical_probs, random_keys))\n",
"print(\"True E_x[f]:\", expected_f(catagorical_probs))\n",
"\n",
"estimated_grad = grad(monte_carlo_f)(catagorical_probs, random_keys)\n",
"true_grad = grad(expected_f)(catagorical_probs)\n",
"\n",
"print(\"Monte Carlo Approximation of d/dx E_x[f]:\", estimated_grad)\n",
"print(\"True d/dx E_x[f]:\", true_grad)\n"
],
"execution_count": 136,
"outputs": [
{
"output_type": "stream",
"text": [
"Monte Carlo Approximation of E_x[f]: 0.6162573\n",
"True E_x[f]: 0.615787624912864\n",
"Monte Carlo Approximation of d/dx E_x[f]: [ 0.26888195 0.28707793 -0.48471233 -0.41185737]\n",
"True d/dx E_x[f]: [ 0.27715933 0.28050548 -0.4902023 -0.40854475]\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment