Skip to content

Instantly share code, notes, and snippets.

@sadatnfs
Created August 28, 2018 21:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sadatnfs/f0ffa19b3de1bfd8b9ee93bce9786f7b to your computer and use it in GitHub Desktop.
Save sadatnfs/f0ffa19b3de1bfd8b9ee93bce9786f7b to your computer and use it in GitHub Desktop.
Simple Bayesian Regression example using Pyro
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import pyro\n",
"from pyro.distributions import Normal\n",
"from pyro.infer import SVI, Trace_ELBO\n",
"from pyro.optim import Adam\n",
"import pyro.distributions as dist\n",
"\n",
"# for CI testing\n",
"smoke_test = ('CI' in os.environ)\n",
"pyro.enable_validation(True)\n",
"\n",
"# Seed the RNGs this notebook draws from: numpy generates the toy dataset and\n",
"# torch generates the guide's initial values and the samples drawn during SVI.\n",
"# Without this, Restart & Run All gives different data, fits and printed values.\n",
"SEED = 0\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bayesian Regression\n",
"\n",
"Learning a linear function of the $p$ input features $X$:\n",
"$$y = Xw + b + \\epsilon$$"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"N = 500  # number of toy observations\n",
"p = 2    # number of input features\n",
"\n",
"\n",
"def build_linear_dataset(N, p=1, noise_std=0.05, w = 3, b = 1):\n",
"    \"\"\"Simulate y = X w + b + noise and pack it into one float tensor.\n",
"\n",
"    X is drawn uniformly on [0, 1) with shape (N, p); `w` may be a scalar\n",
"    (applied to every feature) or a length-p array. The result has shape\n",
"    (N, p + 1): the first p columns are X and the last column is y.\n",
"    \"\"\"\n",
"    features = np.random.rand(N, p)\n",
"    weights = np.ones(p) * w\n",
"    noise = np.random.normal(0, noise_std, size=N)\n",
"    target = (np.matmul(features, weights) + np.repeat(b, N) + noise).reshape(N, 1)\n",
"    features_t = torch.tensor(features).type(torch.Tensor)\n",
"    target_t = torch.tensor(target).type(torch.Tensor)\n",
"    data = torch.cat((features_t, target_t), 1)\n",
"    assert data.shape == (N, p + 1)\n",
"    return data\n",
"\n",
"\n",
"class RegressionModel(nn.Module):\n",
"    \"\"\"Plain linear regressor: p input features mapped to a single output.\"\"\"\n",
"\n",
"    def __init__(self, p):\n",
"        super(RegressionModel, self).__init__()\n",
"        # nn.Linear(p, 1) holds a (1, p) weight and a (1,) bias\n",
"        self.linear = nn.Linear(p, 1)\n",
"\n",
"    def forward(self, x):\n",
"        return self.linear(x)\n",
"\n",
"\n",
"regression_model = RegressionModel(p)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Illustration: lift the module's parameters to random variables and draw one\n",
"# regressor from a unit normal prior. nn.Linear(p, 1) has a (1, p) weight and a (1,)\n",
"# bias, so each parameter gets its own correctly shaped prior via a dict; a single\n",
"# (non-dict) prior would be reused for every parameter, bias included.\n",
"loc = torch.zeros(1, p)\n",
"scale = torch.ones(1, p)\n",
"prior = {\n",
"    'linear.weight': Normal(loc, scale).independent(2),\n",
"    'linear.bias': Normal(torch.zeros(1), torch.ones(1)).independent(1),\n",
"}\n",
"# overload the parameters in the regression module with samples from the prior\n",
"lifted_module = pyro.random_module(\"regression_module\", regression_model, prior)\n",
"# sample a regressor from the prior\n",
"sampled_reg_model = lifted_module()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def model(data):\n",
"    # Generative model for Bayesian linear regression.\n",
"    # data: tensor whose last column is y and whose other columns are the p features,\n",
"    # i.e. the (n, p + 1) layout produced by build_linear_dataset.\n",
"    # Priors: Normal(0, 10) on every weight and on the bias (wide, not unit normal).\n",
"    loc, scale = torch.zeros(1, p), 10 * torch.ones(1, p)\n",
"    bias_loc, bias_scale = torch.zeros(1), 10 * torch.ones(1)\n",
"    # The weight of nn.Linear(p, 1) is a (1, p) matrix, i.e. rank 2 for any p:\n",
"    # reinterpret both dims as one event. The count is the rank (2), not p.\n",
"    w_prior = Normal(loc, scale).independent(2)\n",
"    b_prior = Normal(bias_loc, bias_scale).independent(1)\n",
"    priors = {'linear.weight': w_prior, 'linear.bias': b_prior}\n",
"    # lift module parameters to random variables sampled from the priors\n",
"    lifted_module = pyro.random_module(\"module\", regression_model, priors)\n",
"    # sample a regressor (which also samples w and b)\n",
"    lifted_reg_model = lifted_module()\n",
"    # size the plate from the data itself instead of the global N\n",
"    num_obs = data.size(0)\n",
"    with pyro.iarange(\"map\", num_obs):\n",
"        x_data = data[:, :-1]\n",
"        y_data = data[:, -1]\n",
"\n",
"        # run the regressor forward conditioned on data\n",
"        prediction_mean = lifted_reg_model(x_data).squeeze(-1)\n",
"        # condition on the observed data (fixed observation noise scale 0.1)\n",
"        pyro.sample(\"obs\",\n",
"                    Normal(prediction_mean, 0.1 * torch.ones(num_obs)),\n",
"                    obs=y_data)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"softplus = torch.nn.Softplus()\n",
"\n",
"def guide(data):\n",
"    # Mean-field Normal variational distribution over the regressor's weight and bias.\n",
"    # `data` is unused here; SVI calls the guide with the same arguments as the model.\n",
"    # Returns a regressor whose weight and bias are sampled from the guide.\n",
"    #\n",
"    # Initial values: pyro.param only uses them when the name is not yet in the\n",
"    # param store, so on later calls the learned values are returned instead.\n",
"    w_loc = torch.randn(1, p)\n",
"    # Scales start narrow. Despite the *_log_sig names, these unconstrained values are\n",
"    # mapped to positive scales with softplus (not exp); softplus(-3) is about 0.049.\n",
"    # Each weight gets its own jitter: randn(1, p) rather than one broadcast draw.\n",
"    w_log_sig = -3.0 * torch.ones(1, p) + 0.05 * torch.randn(1, p)\n",
"    b_loc = torch.randn(1)\n",
"    b_log_sig = -3.0 * torch.ones(1) + 0.05 * torch.randn(1)\n",
"    # register learnable params in the param store\n",
"    mw_param = pyro.param(\"guide_mean_weight\", w_loc)\n",
"    sw_param = softplus(pyro.param(\"guide_log_scale_weight\", w_log_sig))\n",
"    mb_param = pyro.param(\"guide_mean_bias\", b_loc)\n",
"    sb_param = softplus(pyro.param(\"guide_log_scale_bias\", b_log_sig))\n",
"    # guide distributions for w and b; the (1, p) weight is one event, so both\n",
"    # of its dims are reinterpreted (the count is the rank, 2, not p)\n",
"    w_dist = Normal(mw_param, sw_param).independent(2)\n",
"    b_dist = Normal(mb_param, sb_param).independent(1)\n",
"    dists = {'linear.weight': w_dist, 'linear.bias': b_dist}\n",
"    # overload the parameters in the module with random samples\n",
"    # from the guide distributions\n",
"    lifted_module = pyro.random_module(\"module\", regression_model, dists)\n",
"    # sample a regressor (which also samples w and b)\n",
"    return lifted_module()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# Stochastic variational inference: fit `guide` to the posterior of `model` by\n",
"# maximising the ELBO, using Adam with a learning rate of 0.05.\n",
"adam_params = {\"lr\": 0.05}\n",
"optim = Adam(adam_params)\n",
"svi = SVI(model, guide, optim, loss=Trace_ELBO())"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"# Run only a couple of iterations under CI (smoke_test is set in the first cell).\n",
"num_iterations = 2 if smoke_test else 1500\n",
"\n",
"def main():\n",
"    # Simulate the toy dataset (true w = [1, 0.5], b = 4) and fit the guide by SVI,\n",
"    # printing the ELBO loss per observation every 100 iterations.\n",
"    pyro.clear_param_store()\n",
"    data = build_linear_dataset(N, p=p, w=np.array([1., 0.5]), b=4.)\n",
"    num_obs = data.size(0)\n",
"    for j in range(num_iterations):\n",
"        # calculate the loss and take a gradient step\n",
"        loss = svi.step(data)\n",
"        if j % 100 == 0:\n",
"            print(\"[iteration %04d] loss: %.4f\" % (j + 1, loss / float(num_obs)))\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[iteration 0001] loss: 1236.4488\n",
"[iteration 0101] loss: 37.6479\n",
"[iteration 0201] loss: 22.9069\n",
"[iteration 0301] loss: 10.0312\n",
"[iteration 0401] loss: 4.0211\n",
"[iteration 0501] loss: 0.7615\n",
"[iteration 0601] loss: -0.5694\n",
"[iteration 0701] loss: -1.0169\n",
"[iteration 0801] loss: -1.1444\n",
"[iteration 0901] loss: -1.1337\n",
"CPU times: user 3min 59s, sys: 927 ms, total: 4min\n",
"Wall time: 15.7 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# Fit the guide; main() prints the per-observation loss every 100 iterations.\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('guide_mean_weight', array([[1.0170887, 0.516914 ]], dtype=float32))\n",
"('guide_log_scale_weight', array([[-3.5383508, -3.70919 ]], dtype=float32))\n",
"('guide_mean_bias', array([3.9868238], dtype=float32))\n",
"('guide_log_scale_bias', array([-4.094978], dtype=float32))\n"
]
}
],
"source": [
"# List every learned variational parameter. The *_log_scale_* entries are the raw\n",
"# values the guide passes through softplus, not the standard deviations themselves.\n",
"param_store = pyro.get_param_store()\n",
"for param_name in param_store.get_all_param_names():\n",
"    print((param_name, pyro.param(param_name).data.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Draw a few regressors from the fitted guide to inspect posterior uncertainty.\n",
"# (pyro.infer.abstract_infer.TracePosterior is an abstract base class and cannot be\n",
"# instantiated directly; calling it here raised a TypeError.)\n",
"num_posterior_samples = 2\n",
"for _ in range(num_posterior_samples):\n",
"    # guide ignores its data argument, so None is sufficient\n",
"    posterior_reg_model = guide(None)\n",
"    print(posterior_reg_model.linear.weight.data.numpy(),\n",
"          posterior_reg_model.linear.bias.data.numpy())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment