-
-
Save dilaragokay/75dbbdae23d92af4366fe96c95a018f8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "ReparamGumbel.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "JV7ddDYmvPvk", | |
"outputId": "f92ebe02-5b12-445a-b74d-5d2ebeb37718" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting pyro-ppl\n", | |
" Downloading pyro_ppl-1.8.1-py3-none-any.whl (718 kB)\n", | |
"\u001b[K |████████████████████████████████| 718 kB 35.3 MB/s \n", | |
"\u001b[?25hRequirement already satisfied: tqdm>=4.36 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (4.64.0)\n", | |
"Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (1.21.6)\n", | |
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (3.3.0)\n", | |
"Requirement already satisfied: torch>=1.11.0 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (1.11.0+cu113)\n", | |
"Collecting pyro-api>=0.1.1\n", | |
" Downloading pyro_api-0.1.2-py3-none-any.whl (11 kB)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.11.0->pyro-ppl) (4.1.1)\n", | |
"Installing collected packages: pyro-api, pyro-ppl\n", | |
"Successfully installed pyro-api-0.1.2 pyro-ppl-1.8.1\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install pyro-ppl" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import torch\n", | |
"import pyro\n", | |
"import pyro.distributions as dist\n", | |
"from pyro import poutine\n", | |
"from pyro.infer.reparam import GumbelSoftmaxReparam\n", | |
"from pyro.infer.autoguide import AutoNormal\n", | |
"from pyro.infer import SVI, Trace_ELBO\n", | |
"from pyro.optim import Adam" | |
], | |
"metadata": { | |
"id": "6jC00u8Sv-uH" | |
}, | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"shape = (4,)\n", | |
"dim = 2\n", | |
"\n", | |
"temperature = torch.tensor(0.1)\n", | |
"logits = torch.randn(shape + (dim,))" | |
], | |
"metadata": { | |
"id": "APO61rfPwFrC" | |
}, | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def model():\n", | |
" with pyro.plate_stack(\"plates\", shape):\n", | |
" with pyro.plate(\"particles\", 10000):\n", | |
" pyro.sample(\"x\", dist.RelaxedOneHotCategorical(temperature,\n", | |
" logits=logits))" | |
], | |
"metadata": { | |
"id": "LcFUDaqfwb7p" | |
}, | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"guide = AutoNormal(model)\n", | |
"reparam_model = poutine.reparam(model, {\"x\": GumbelSoftmaxReparam()})\n", | |
"\n", | |
"elbo = Trace_ELBO()\n", | |
"adam_params = {\"lr\": 0.001, \"betas\": (0.95, 0.999)}\n", | |
"optimizer = Adam(adam_params)\n", | |
"\n", | |
"svi = SVI(\n", | |
" reparam_model,\n", | |
" guide,\n", | |
" optimizer,\n", | |
" loss=elbo,\n", | |
")" | |
], | |
"metadata": { | |
"id": "azRK5ohZv_m6" | |
}, | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for _ in range(100):\n", | |
" loss = svi.step()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "5lZiUu31xJcm", | |
"outputId": "c5254304-a8fb-4ae8-e091-a9030e7578ef" | |
}, | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/pyro/util.py:291: UserWarning: Found non-auxiliary vars in guide but not model, consider marking these infer={'is_auxiliary': True}:\n", | |
"{'x'}\n", | |
" guide_vars - aux_vars - model_vars\n", | |
"/usr/local/lib/python3.7/dist-packages/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_uniform'}\n", | |
" warnings.warn(f\"Found vars in model but not guide: {bad_sites}\")\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment