-
-
Save Mr4k/fbb096baf20354b3fdcbd082a00e20d6 to your computer and use it in GitHub Desktop.
MonteCarloGumbelSoftmax.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "MonteCarloGumbelSoftmax.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyPN4EtCkSU8VC3MZppjVc0g", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Mr4k/fbb096baf20354b3fdcbd082a00e20d6/montecarlogumbelsoftmax.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aNXUXI2hFQUN", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
import jax
import jax.numpy as np
# current convention is to import original numpy as "onp"
import numpy as onp
from jax import grad, jit, vmap, random

# NOTE: the old `from jax.ops import index, index_add, index_update` import was
# removed here — jax.ops indexed-update helpers were deleted from JAX (use
# `x.at[idx].set(v)` / `x.at[idx].add(v)` instead), and none of those names are
# used anywhere in this notebook, so the import only caused an ImportError on
# modern JAX.

# optimizers moved from jax.experimental to jax.example_libraries in JAX 0.2.25;
# try the new location first and fall back to the legacy path for old installs.
try:
    from jax.example_libraries import optimizers
except ImportError:
    import jax.experimental.optimizers as optimizers
], | |
"execution_count": 134, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-H0zehkiLe9F", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
# Gumbel-Softmax trick — written for clarity rather than speed.
@jit
def sample_from_gumbel_softmax(key, catagorical_probs, temp):
    """Draw one differentiable (relaxed one-hot) sample from a categorical.

    Perturbs the log-probabilities with Gumbel noise and pushes the result
    through a temperature-scaled softmax; as `temp` -> 0 the output approaches
    a hard one-hot sample, while staying differentiable w.r.t. the probs.

    key: a jax.random PRNG key used to draw the Gumbel noise.
    catagorical_probs: array of (unnormalized is fine, positive) class probs.
    temp: softmax temperature controlling how "hard" the sample is.
    """
    gumbel_noise = jax.random.gumbel(key, catagorical_probs.shape)
    perturbed_logits = np.log(catagorical_probs) + gumbel_noise
    return jax.nn.softmax(perturbed_logits / temp)
], | |
"execution_count": 135, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "sQwnSxdnO95c", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 85 | |
}, | |
"outputId": "8e138843-db96-4569-f2fe-67c67868a7d8" | |
}, | |
"source": [ | |
# Fix both the JAX and classic-NumPy seeds so every run is reproducible.
RNG_KEY = random.PRNGKey(0)
onp.random.seed(3)

# Build a random categorical distribution over four classes.
catagorical_probs = onp.random.uniform(size=4)
catagorical_probs /= catagorical_probs.sum()

# Per-category payoffs for the objective f below.
weights = onp.random.uniform(size=4)


def expected_f(catagorical_probs):
    """Closed-form E_x[f]: probability-weighted sum of the category payoffs.

    The division by the total mass normalizes the distribution, which is
    required to get the correct derivative through backprop.
    """
    return weights.T @ catagorical_probs / catagorical_probs.sum()


samples = 10000
random_keys = random.split(RNG_KEY, samples)


def sample_from_f(key, catagorical_probs):
    """One stochastic estimate of f from a single Gumbel-Softmax draw (temp 0.1)."""
    return weights.T @ sample_from_gumbel_softmax(key, catagorical_probs, 0.1)


def monte_carlo_f(catagorical_probs, random_keys):
    """Monte Carlo estimate of E_x[f]: average one draw per PRNG key.

    Used to sanity-check the Gumbel-Softmax implementation against the
    closed-form expectation above.
    """
    sampler = vmap(lambda key: sample_from_f(key, catagorical_probs))
    draws = sampler(random_keys)
    return draws.sum() / random_keys.shape[0]


print("Monte Carlo Approximation of E_x[f]:", monte_carlo_f(catagorical_probs, random_keys))
print("True E_x[f]:", expected_f(catagorical_probs))

# Because the Gumbel-Softmax relaxation is differentiable, we can also compare
# the Monte Carlo gradient estimate against the exact gradient.
estimated_grad = grad(monte_carlo_f)(catagorical_probs, random_keys)
true_grad = grad(expected_f)(catagorical_probs)

print("Monte Carlo Approximation of d/dx E_x[f]:", estimated_grad)
print("True d/dx E_x[f]:", true_grad)
], | |
"execution_count": 136, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Monte Carlo Approximation of E_x[f]: 0.6162573\n", | |
"True E_x[f]: 0.615787624912864\n", | |
"Monte Carlo Approximation of d/dx E_x[f]: [ 0.26888195 0.28707793 -0.48471233 -0.41185737]\n", | |
"True d/dx E_x[f]: [ 0.27715933 0.28050548 -0.4902023 -0.40854475]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment