Skip to content

Instantly share code, notes, and snippets.

@jamestwebber
Created October 23, 2018 23:41
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save jamestwebber/b93e486e32a33da5bd99a330608a2e9d to your computer and use it in GitHub Desktop.
Save jamestwebber/b93e486e32a33da5bd99a330608a2e9d to your computer and use it in GitHub Desktop.
An IPython (Jupyter) notebook showing how to use stochastic variational inference (SVI) for Bayesian logistic regression in Pyro
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import scipy.special as ssp\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.distributions.constraints as constraints\n",
"\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data.sampler import SubsetRandomSampler\n",
"\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"\n",
"from pyro.infer import SVI, Trace_ELBO\n",
"from pyro.optim import Adam, SGD\n",
"\n",
"pyro.enable_validation(True)\n",
"torch.set_default_dtype(torch.double) # this was necessary on the CPU"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def build_logistic_dataset(N, p=1, noise_std=0.01):\n",
" X = np.random.randn(N, p)\n",
" \n",
" w = np.random.randn(p)\n",
" w += 2 * np.sign(w)\n",
"\n",
" y = np.round(ssp.expit(np.matmul(X, w) \n",
" + np.repeat(1, N) \n",
" + np.random.normal(0, noise_std, size=N)))\n",
" y = y.reshape(N, 1)\n",
" return X, y, w\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# these were adapted from the Pyro VAE tutorial\n",
"\n",
"def train(svi, train_loader, n_train):\n",
" # initialize loss accumulator\n",
" epoch_loss = 0.\n",
" # do a training epoch over each mini-batch x returned\n",
" # by the data loader\n",
" for _, xs in enumerate(train_loader):\n",
" # do ELBO gradient and accumulate loss\n",
" epoch_loss += svi.step(*xs)\n",
"\n",
" # return epoch loss\n",
" total_epoch_loss_train = epoch_loss / n_train\n",
" return total_epoch_loss_train\n",
"\n",
"\n",
"def evaluate(svi, test_loader, n_test):\n",
" # initialize loss accumulator\n",
" test_loss = 0.\n",
" # compute the loss over the entire test set\n",
" for _, xs in enumerate(test_loader):\n",
" # compute ELBO estimate and accumulate loss\n",
" test_loss += svi.evaluate_loss(*xs)\n",
"\n",
" total_epoch_loss_test = test_loss / n_test\n",
" return total_epoch_loss_test\n",
"\n",
"\n",
"def plot_llk(train_elbo, test_elbo, test_int):\n",
" plt.figure(figsize=(8, 6))\n",
"\n",
" x = np.arange(len(train_elbo))\n",
"\n",
" plt.plot(x, train_elbo, marker='o', label='Train ELBO')\n",
" plt.plot(x[::test_int], test_elbo, marker='o', label='Test ELBO')\n",
" plt.xlabel('Training Epoch')\n",
" plt.legend()\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class LogRegressionModel(nn.Module):\n",
" def __init__(self, p):\n",
" super(LogRegressionModel, self).__init__()\n",
" \n",
" self.p = p\n",
"\n",
" # hyperparameters for normal priors\n",
" self.alpha_h_loc = torch.zeros(1, p)\n",
" self.alpha_h_scale = 10.0 * torch.ones(1, p)\n",
" self.beta_h_loc = torch.zeros(1)\n",
" self.beta_h_scale = 10.0 * torch.ones(1)\n",
" \n",
" # initial values of variational parameters\n",
" self.alpha_0 = np.zeros((1, p))\n",
" self.alpha_0_scale = np.ones((1, p))\n",
" self.beta_0 = np.zeros((1,))\n",
" self.beta_0_scale = np.ones((1,))\n",
"\n",
" def model(self, x, y):\n",
" # sample from prior\n",
" a = pyro.sample(\n",
" \"weight\", dist.Normal(self.alpha_h_loc, self.alpha_h_scale, validate_args=True).independent(1)\n",
" )\n",
" b = pyro.sample(\n",
" \"bias\", dist.Normal(self.beta_h_loc, self.beta_h_scale, validate_args=True).independent(1)\n",
" )\n",
"\n",
" with pyro.iarange(\"data\", x.size(0)):\n",
" model_logits = (torch.matmul(x, a.permute(1, 0)) + b).squeeze()\n",
" \n",
" pyro.sample(\n",
" \"obs\", \n",
" dist.Bernoulli(logits=model_logits, validate_args=True),\n",
" obs=y.squeeze()\n",
" )\n",
" \n",
" def guide(self, x, y):\n",
" # register variational parameters with pyro\n",
" alpha_loc = pyro.param(\"alpha_loc\", torch.tensor(self.alpha_0))\n",
" alpha_scale = pyro.param(\"alpha_scale\", torch.tensor(self.alpha_0_scale),\n",
" constraint=constraints.positive)\n",
" beta_loc = pyro.param(\"beta_loc\", torch.tensor(self.beta_0))\n",
" beta_scale = pyro.param(\"beta_scale\", torch.tensor(self.beta_0_scale),\n",
" constraint=constraints.positive)\n",
"\n",
" pyro.sample(\n",
" \"weight\", dist.Normal(alpha_loc, alpha_scale, validate_args=True).independent(1)\n",
" )\n",
" pyro.sample(\n",
" \"bias\", dist.Normal(beta_loc, beta_scale, validate_args=True).independent(1)\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"pyro.clear_param_store()\n",
"\n",
"optim = Adam({'lr': 0.01})\n",
"\n",
"num_epochs = 1000\n",
"batch_size = 50\n",
"\n",
"N = 1000\n",
"p = 3\n",
"\n",
"X, y, w = build_logistic_dataset(N, p)\n",
"\n",
"example_indices = np.random.permutation(N)\n",
"n_train = int(0.9 * N) # 90%/10% train/test split\n",
"n_test = N - n_train\n",
"test_iter = 50\n",
"\n",
"X = torch.from_numpy(X)\n",
"y = torch.from_numpy(y)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[epoch 000] average training loss: 1.2700\n",
"[epoch 050] average training loss: 0.3476\n",
"[epoch 100] average training loss: 0.2844\n",
"[epoch 150] average training loss: 0.2586\n",
"[epoch 200] average training loss: 0.2638\n",
"[epoch 250] average training loss: 0.2504\n",
"[epoch 300] average training loss: 0.2299\n",
"[epoch 350] average training loss: 0.2368\n",
"[epoch 400] average training loss: 0.2376\n",
"[epoch 450] average training loss: 0.2421\n",
"[epoch 500] average training loss: 0.2412\n",
"[epoch 550] average training loss: 0.2241\n",
"[epoch 600] average training loss: 0.2325\n",
"[epoch 650] average training loss: 0.2162\n",
"[epoch 700] average training loss: 0.2396\n",
"[epoch 750] average training loss: 0.2344\n",
"[epoch 800] average training loss: 0.2515\n",
"[epoch 850] average training loss: 0.2385\n",
"[epoch 900] average training loss: 0.2319\n",
"[epoch 950] average training loss: 0.2378\n"
]
}
],
"source": [
"lr_model = LogRegressionModel(p=p)\n",
"\n",
"svi = SVI(\n",
" lr_model.model, lr_model.guide, optim,\n",
" loss=Trace_ELBO(), use_cuda=False\n",
")\n",
"\n",
"\n",
"lr_dataset = torch.utils.data.TensorDataset(X, y)\n",
"\n",
"data_loader_train = DataLoader(\n",
" dataset=lr_dataset, batch_size=batch_size, pin_memory=False,\n",
" sampler=SubsetRandomSampler(example_indices[:n_train]),\n",
")\n",
" \n",
"data_loader_test = DataLoader(\n",
" dataset=lr_dataset, batch_size=batch_size, pin_memory=False,\n",
" sampler=SubsetRandomSampler(example_indices[n_train:]),\n",
")\n",
"\n",
"train_elbo = []\n",
"test_elbo = []\n",
"for epoch in range(num_epochs):\n",
" total_epoch_loss_train = train(svi, data_loader_train, n_train)\n",
" train_elbo.append(-total_epoch_loss_train)\n",
"\n",
" if epoch % test_iter == 0:\n",
" print(\"[epoch %03d] average training loss: %.4f\" % (epoch, total_epoch_loss_train))\n",
" # report test diagnostics\n",
" total_epoch_loss_test = evaluate(svi, data_loader_test, n_test)\n",
" test_elbo.append(-total_epoch_loss_test)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_llk(train_elbo, test_elbo, test_iter)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:dose_ensemble]",
"language": "python",
"name": "conda-env-dose_ensemble-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment