Skip to content

Instantly share code, notes, and snippets.

@dilaragokay
Created July 4, 2022 15:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dilaragokay/75dbbdae23d92af4366fe96c95a018f8 to your computer and use it in GitHub Desktop.
Save dilaragokay/75dbbdae23d92af4366fe96c95a018f8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "ReparamGumbel.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JV7ddDYmvPvk",
"outputId": "f92ebe02-5b12-445a-b74d-5d2ebeb37718"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting pyro-ppl\n",
" Downloading pyro_ppl-1.8.1-py3-none-any.whl (718 kB)\n",
"\u001b[K |████████████████████████████████| 718 kB 35.3 MB/s \n",
"\u001b[?25hRequirement already satisfied: tqdm>=4.36 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (4.64.0)\n",
"Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (1.21.6)\n",
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (3.3.0)\n",
"Requirement already satisfied: torch>=1.11.0 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (1.11.0+cu113)\n",
"Collecting pyro-api>=0.1.1\n",
" Downloading pyro_api-0.1.2-py3-none-any.whl (11 kB)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.11.0->pyro-ppl) (4.1.1)\n",
"Installing collected packages: pyro-api, pyro-ppl\n",
"Successfully installed pyro-api-0.1.2 pyro-ppl-1.8.1\n"
]
}
],
"source": [
"!pip install pyro-ppl"
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"from pyro import poutine\n",
"from pyro.infer.reparam import GumbelSoftmaxReparam\n",
"from pyro.infer.autoguide import AutoNormal\n",
"from pyro.infer import SVI, Trace_ELBO\n",
"from pyro.optim import Adam"
],
"metadata": {
"id": "6jC00u8Sv-uH"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"shape = (4,)\n",
"dim = 2\n",
"\n",
"temperature = torch.tensor(0.1)\n",
"logits = torch.randn(shape + (dim,))"
],
"metadata": {
"id": "APO61rfPwFrC"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def model():\n",
" with pyro.plate_stack(\"plates\", shape):\n",
" with pyro.plate(\"particles\", 10000):\n",
" pyro.sample(\"x\", dist.RelaxedOneHotCategorical(temperature,\n",
" logits=logits))"
],
"metadata": {
"id": "LcFUDaqfwb7p"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"guide = AutoNormal(model)\n",
"reparam_model = poutine.reparam(model, {\"x\": GumbelSoftmaxReparam()})\n",
"\n",
"elbo = Trace_ELBO()\n",
"adam_params = {\"lr\": 0.001, \"betas\": (0.95, 0.999)}\n",
"optimizer = Adam(adam_params)\n",
"\n",
"svi = SVI(\n",
" reparam_model,\n",
" guide,\n",
" optimizer,\n",
" loss=elbo,\n",
")"
],
"metadata": {
"id": "azRK5ohZv_m6"
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for _ in range(100):\n",
" loss = svi.step()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5lZiUu31xJcm",
"outputId": "c5254304-a8fb-4ae8-e091-a9030e7578ef"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/pyro/util.py:291: UserWarning: Found non-auxiliary vars in guide but not model, consider marking these infer={'is_auxiliary': True}:\n",
"{'x'}\n",
" guide_vars - aux_vars - model_vars\n",
"/usr/local/lib/python3.7/dist-packages/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_uniform'}\n",
" warnings.warn(f\"Found vars in model but not guide: {bad_sites}\")\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment