@AseiSugiyama
Created December 6, 2020 13:28
numpyro_vae.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "numpyro_vae.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPKDuXuOiukiElV+kW/eTNk",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/AseiSugiyama/2e0211035bd14ebbbe60fdb3b48e438f/numpyro_vae.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "aXDRxifjo9Y8"
},
"source": [
"!pip install --upgrade pip\n",
"!pip install --upgrade numpyro\n",
"!pip install --upgrade jax jaxlib==0.1.56+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Z6MUOZNsikub"
},
"source": [
"import inspect\n",
"import os\n",
"import time\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from jax import jit, lax, random\n",
"from jax.experimental import stax\n",
"import jax.numpy as jnp\n",
"from jax.random import PRNGKey\n",
"\n",
"import numpyro\n",
"from numpyro import optim\n",
"import numpyro.distributions as dist\n",
"from numpyro.examples.datasets import MNIST, load_dataset\n",
"from numpyro.infer import SVI, Trace_ELBO"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qr1AgIJbAuT5"
},
"source": [
"from numpyro.util import set_platform\n",
"set_platform(\"gpu\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zYtpjxzKiqfN"
},
"source": [
"RESULTS_DIR = os.path.abspath(os.path.join(os.path.dirname(inspect.getfile(lambda: None)),\n",
" '.results'))\n",
"os.makedirs(RESULTS_DIR, exist_ok=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1BSYjejNisdN"
},
"source": [
"def encoder(hidden_dim, z_dim):\n",
" return stax.serial(\n",
" stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus,\n",
" stax.FanOut(2),\n",
" stax.parallel(stax.Dense(z_dim, W_init=stax.randn()),\n",
" stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),\n",
" )\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "INdM_w5liu_R"
},
"source": [
"def decoder(hidden_dim, out_dim):\n",
" return stax.serial(\n",
" stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus,\n",
" stax.Dense(out_dim, W_init=stax.randn()), stax.Sigmoid,\n",
" )\n"
],
"execution_count": null,
"outputs": []
},
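{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check (a sketch, not needed for training; the batch size and seed below are arbitrary): `stax.serial` returns an `(init_fun, apply_fun)` pair, so the encoder should map a batch of flattened images to a `(loc, scale)` pair of shape `(batch, z_dim)`, and the decoder should map latent codes back to shape `(batch, 784)`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sketch: sanity-check encoder/decoder shapes on a dummy batch (values are arbitrary).\n",
"demo_hidden_dim, demo_z_dim, demo_batch = 400, 50, 4\n",
"enc_init, enc_apply = encoder(demo_hidden_dim, demo_z_dim)\n",
"dec_init, dec_apply = decoder(demo_hidden_dim, 28 * 28)\n",
"\n",
"k1, k2, k3 = random.split(random.PRNGKey(0), 3)\n",
"_, enc_params = enc_init(k1, (demo_batch, 28 * 28))     # init on the flattened-image shape\n",
"_, dec_params = dec_init(k2, (demo_batch, demo_z_dim))  # init on the latent-code shape\n",
"\n",
"x = random.uniform(k3, (demo_batch, 28 * 28))           # stand-in for pixel intensities\n",
"z_loc, z_scale = enc_apply(enc_params, x)\n",
"x_probs = dec_apply(dec_params, z_loc)\n",
"print(z_loc.shape, z_scale.shape, x_probs.shape)        # expect (4, 50) (4, 50) (4, 784)"
],
"execution_count": null,
"outputs": []
},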
{
"cell_type": "code",
"metadata": {
"id": "V7SXnrkQixAY"
},
"source": [
"def model(batch, hidden_dim=400, z_dim=100):\n",
" batch = jnp.reshape(batch, (batch.shape[0], -1))\n",
" batch_dim, out_dim = jnp.shape(batch)\n",
" decode = numpyro.module('decoder', decoder(hidden_dim, out_dim), (batch_dim, z_dim))\n",
" z = numpyro.sample('z', dist.Normal(jnp.zeros((z_dim,)), jnp.ones((z_dim,))))\n",
" img_loc = decode(z)\n",
" return numpyro.sample('obs', dist.Bernoulli(img_loc), obs=batch)"
],
"execution_count": null,
"outputs": []
},
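{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model is the standard VAE generative story: a standard normal prior over the latent code and a Bernoulli likelihood over the binarized pixels,\n",
"\n",
"$$p_\\theta(x, z) = \\mathcal{N}(z \\mid 0, I)\\,\\prod_{d=1}^{784} \\mathrm{Bernoulli}\\big(x_d \\mid \\mathrm{decoder}_\\theta(z)_d\\big),$$\n",
"\n",
"with the decoder weights $\\theta$ registered as learnable parameters by `numpyro.module`. The optional sketch below (the dummy batch is arbitrary) traces the model under a seed handler just to list its sample sites."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sketch: trace the model once to inspect its sample sites (not needed for training).\n",
"from numpyro import handlers\n",
"\n",
"dummy_batch = jnp.zeros((4, 28 * 28))\n",
"seeded_model = handlers.seed(model, random.PRNGKey(0))\n",
"exec_trace = handlers.trace(seeded_model).get_trace(dummy_batch, hidden_dim=400, z_dim=50)\n",
"print({name: site['value'].shape for name, site in exec_trace.items() if site['type'] == 'sample'})"
],
"execution_count": null,
"outputs": []
},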
{
"cell_type": "code",
"metadata": {
"id": "_s_ZNrpoizkB"
},
"source": [
"def guide(batch, hidden_dim=400, z_dim=100):\n",
" batch = jnp.reshape(batch, (batch.shape[0], -1))\n",
" batch_dim, out_dim = jnp.shape(batch)\n",
" encode = numpyro.module('encoder', encoder(hidden_dim, z_dim), (batch_dim, out_dim))\n",
" z_loc, z_std = encode(batch)\n",
" z = numpyro.sample('z', dist.Normal(z_loc, z_std))\n",
" return z\n"
],
"execution_count": null,
"outputs": []
},
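{
"cell_type": "markdown",
"metadata": {},
"source": [
"The guide is the amortized posterior $q_\\phi(z \\mid x) = \\mathcal{N}\\big(z \\mid \\mu_\\phi(x), \\sigma_\\phi(x)\\big)$, with $\\mu_\\phi$ and $\\sigma_\\phi$ given by the two heads of the encoder. `SVI` with `Trace_ELBO` (set up below) maximizes the usual evidence lower bound\n",
"\n",
"$$\\mathrm{ELBO}(\\theta, \\phi) = \\mathbb{E}_{q_\\phi(z \\mid x)}\\big[\\log p_\\theta(x \\mid z)\\big] - \\mathrm{KL}\\big(q_\\phi(z \\mid x)\\,\\|\\,p(z)\\big),$$\n",
"\n",
"estimated with a single reparameterized sample of $z$ per step; the loss reported during training is its negative."
]
},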
{
"cell_type": "code",
"metadata": {
"id": "8A0G_4Avi1cV"
},
"source": [
"@jit\n",
"def binarize(rng_key, batch):\n",
" return random.bernoulli(rng_key, batch).astype(batch.dtype)"
],
"execution_count": null,
"outputs": []
},
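{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny usage example (the toy values below are arbitrary): `binarize` treats each pixel intensity in `[0, 1]` as the success probability of an independent Bernoulli draw, so every epoch trains on a freshly resampled binary version of MNIST."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sketch: binarize a toy 'image' whose entries are probabilities in [0, 1].\n",
"demo_key = random.PRNGKey(1)\n",
"demo_pixels = jnp.linspace(0.0, 1.0, 8).reshape(2, 4)\n",
"print(demo_pixels)\n",
"print(binarize(demo_key, demo_pixels))  # entries are 0.0 or 1.0, same dtype as the input"
],
"execution_count": null,
"outputs": []
},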
{
"cell_type": "code",
"metadata": {
"id": "S9RYrODii35S"
},
"source": [
"hidden_dim = 400\n",
"z_dim = 50\n",
"learning_rate = 1.0e-3\n",
"batch_size = 256\n",
"num_epochs = 100\n",
"\n",
"\n",
"encoder_nn = encoder(hidden_dim, z_dim)\n",
"decoder_nn = decoder(hidden_dim, 28 * 28)\n",
"adam = optim.Adam(learning_rate)\n",
"svi = SVI(model, guide, adam, Trace_ELBO(), hidden_dim=hidden_dim, z_dim=z_dim)\n",
"rng_key = PRNGKey(0)\n",
"train_init, train_fetch = load_dataset(MNIST, batch_size=batch_size, split='train')\n",
"test_init, test_fetch = load_dataset(MNIST, batch_size=batch_size, split='test')\n",
"num_train, train_idx = train_init()\n",
"rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)\n",
"sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])\n",
"svi_state = svi.init(rng_key_init, sample_batch)\n"
],
"execution_count": null,
"outputs": []
},
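{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally (a sketch; `jax.tree_util.tree_map` is used only for pretty-printing), the freshly initialized parameters can be inspected: `svi.init` has already built the encoder and decoder weights, stored under the keys `encoder$params` and `decoder$params` that the reconstruction cell below reads back."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sketch: list the initialized parameter groups and the decoder's layer shapes.\n",
"import jax\n",
"\n",
"init_params = svi.get_params(svi_state)\n",
"print(sorted(init_params.keys()))  # expect ['decoder$params', 'encoder$params']\n",
"print(jax.tree_util.tree_map(jnp.shape, init_params['decoder$params']))"
],
"execution_count": null,
"outputs": []
},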
{
"cell_type": "code",
"metadata": {
"id": "hoG5Cs7MldrJ"
},
"source": [
"@jit\n",
"def epoch_train(svi_state, rng_key):\n",
" def body_fn(i, val):\n",
" loss_sum, svi_state = val\n",
" rng_key_binarize = random.fold_in(rng_key, i)\n",
" batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])\n",
" svi_state, loss = svi.update(svi_state, batch)\n",
" loss_sum += loss\n",
" return loss_sum, svi_state\n",
"\n",
" return lax.fori_loop(0, num_train, body_fn, (0., svi_state))\n",
"\n",
"@jit\n",
"def eval_test(svi_state, rng_key):\n",
" def body_fun(i, loss_sum):\n",
" rng_key_binarize = random.fold_in(rng_key, i)\n",
" batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])\n",
" # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?\n",
" loss = svi.evaluate(svi_state, batch) / len(batch)\n",
" loss_sum += loss\n",
" return loss_sum\n",
"\n",
" loss = lax.fori_loop(0, num_test, body_fun, 0.)\n",
" loss = loss / num_test\n",
" return loss"
],
"execution_count": null,
"outputs": []
},
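{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the jitted `lax.fori_loop` in `epoch_train` is equivalent to the plain Python loop sketched below (same globals assumed; shown only for readability, it would run far slower): `fori_loop` lets XLA compile the whole epoch into one program instead of dispatching every step from Python."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sketch: an un-jitted, plain-Python equivalent of epoch_train (for reading, not for speed).\n",
"def epoch_train_python(svi_state, rng_key):\n",
"    loss_sum = 0.0\n",
"    for i in range(num_train):\n",
"        rng_key_binarize = random.fold_in(rng_key, i)  # one derived key per mini-batch\n",
"        batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])\n",
"        svi_state, loss = svi.update(svi_state, batch)\n",
"        loss_sum += loss\n",
"    return loss_sum, svi_state"
],
"execution_count": null,
"outputs": []
},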
{
"cell_type": "code",
"metadata": {
"id": "RFkmEgRMloxR"
},
"source": [
"%%time\n",
"def reconstruct_img(epoch, rng_key):\n",
" img = test_fetch(0, test_idx)[0][0]\n",
" plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)), img, cmap='gray')\n",
" rng_key_binarize, rng_key_sample = random.split(rng_key)\n",
" test_sample = binarize(rng_key_binarize, img)\n",
" params = svi.get_params(svi_state)\n",
" z_mean, z_var = encoder_nn[1](params['encoder$params'], test_sample.reshape([1, -1]))\n",
" z = dist.Normal(z_mean, z_var).sample(rng_key_sample)\n",
" img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])\n",
" plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)), img_loc, cmap='gray')\n",
"\n",
"for i in range(num_epochs):\n",
" rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(rng_key, 4)\n",
" t_start = time.time()\n",
" num_train, train_idx = train_init()\n",
" _, svi_state = epoch_train(svi_state, rng_key_train)\n",
" rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)\n",
" num_test, test_idx = test_init()\n",
" test_loss = eval_test(svi_state, rng_key_test)\n",
" reconstruct_img(i, rng_key_reconstruct)\n",
" print(\"Epoch {}: loss = {} ({:.2f} s.)\".format(i, test_loss, time.time() - t_start))"
],
"execution_count": null,
"outputs": []
},
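{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training, new digits can be sampled from the generative model alone: draw $z \\sim \\mathcal{N}(0, I)$ and push it through the trained decoder, just as `reconstruct_img` does but without conditioning on a test image (a sketch; the 4×4 grid and the number of samples are arbitrary)."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sketch: sample fresh digits from the prior using the trained decoder parameters.\n",
"params = svi.get_params(svi_state)\n",
"rng_key, rng_key_prior = random.split(rng_key)\n",
"z_prior = dist.Normal(jnp.zeros((z_dim,)), jnp.ones((z_dim,))).sample(rng_key_prior, (16,))\n",
"generated = decoder_nn[1](params['decoder$params'], z_prior).reshape([-1, 28, 28])\n",
"\n",
"fig, axes = plt.subplots(4, 4, figsize=(4, 4))\n",
"for ax, digit in zip(axes.ravel(), generated):\n",
"    ax.imshow(digit, cmap='gray')\n",
"    ax.axis('off')\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},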
{
"cell_type": "code",
"metadata": {
"id": "ZZJKEV2ll9so"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}