@mlaves
Last active March 8, 2024 05:37
Simple Classification Example in Pyro with SVI and MCMC
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple Classification Example in Pyro with SVI and MCMC"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
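"In this example, we use Pyro to train a simple Bayesian neural network, consisting of a single linear layer (i.e. Bayesian multinomial logistic regression), on a binary classification task.\n",
"Each sample of our toy data consists of two features $ X_{i} = [x_{i}, y_{i}]^{T} $ and a class label $ c_{i} \\in \\lbrace 0, 1 \\rbrace $ (in the code, the class labels are stored in the tensor `y`).\n",
"\n",
"We construct a linear model $ z = WX + b $ with weights $ W \\in \\mathbb{R}^{2 \\times 2} $ and bias $ b \\in \\mathbb{R}^{2} $, all with independent standard normal priors. The model produces two logits $ z $, which parameterize a categorical distribution over the class label:\n",
"\n",
"$$ p(c \\mid z) = \\mathrm{Categorical}(c \\mid \\mathrm{softmax}(z)) . $$\n",
"\n",
"We then approximate the posterior over $ W $ and $ b $ in two ways, with stochastic variational inference (SVI) and with Markov chain Monte Carlo (MCMC) using the NUTS sampler, and compare the results."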
]
},
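{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the softmax that turns the two logits $ z $ into class probabilities (an addition to the original notebook; the logit values are hypothetical). It also checks that adding the same constant to both logits leaves the probabilities unchanged:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def softmax(z):\n",
"    # subtracting the maximum improves numerical stability without changing the result\n",
"    e = np.exp(z - np.max(z, axis=-1, keepdims=True))\n",
"    return e / e.sum(axis=-1, keepdims=True)\n",
"\n",
"\n",
"z = np.array([0.5, -1.0])  # hypothetical logits z = W x + b for one input\n",
"p = softmax(z)\n",
"p_shifted = softmax(z + 3.0)  # shift both logits by the same constant\n",
"\n",
"print(p)  # two class probabilities that sum to 1\n",
"print(np.allclose(p, p_shifted))  # True\n",
"```\n",
"\n",
"`dist.Categorical(logits=...)` applies this normalization internally, so the model only has to produce unnormalized logits. The shift invariance also means that the data only determine differences between the two logits."
]
},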
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import seaborn as sns\n",
"\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"import pyro.optim as optim\n",
"\n",
"pyro.set_rng_seed(1)  # seeds PyTorch, Python's random module and NumPy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# generate toy data: two 2-D Gaussian clusters with num_samples points each\n",
"num_samples = 200\n",
"data_1 = np.random.multivariate_normal((1, 1), np.array([[1, 0.5], [0.5, 1]]), size=num_samples)\n",
"data_2 = np.random.multivariate_normal((5, 1), np.array([[2, 0.5], [0.5, 2]]), size=num_samples)\n",
"data = torch.FloatTensor(np.concatenate([data_1, data_2], axis=0))\n",
"\n",
"# class labels: 0 for the first cluster, 1 for the second; stored as integers because\n",
"# they are observations of a Categorical distribution\n",
"y_1 = np.repeat([0], num_samples, axis=0)\n",
"y_2 = np.repeat([1], num_samples, axis=0)\n",
"y = torch.LongTensor(np.concatenate([y_1, y_2], axis=0))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f890ccf9d90>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEGCAYAAABsLkJ6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO2de5Qc1X3nv7/p6UE9wtZIoMSrGRHJgRVGGFA0ZHHw8S4oQTYPWYY9wmadhOBj1hv8wOYIpHWCZU6yFtYuEK+dzRJwcs6alwLyIENAECB/mBwwkkePCFBwjEEz4BMhMbKNWkzPzN0/amqmuvreqluP7qru+n7O4YiprseveqTf797fU5RSIIQQUjy6shaAEEJINtAAEEJIQaEBIISQgkIDQAghBYUGgBBCCkp31gJE4eSTT1ZLlizJWgxCCGkrdu3a9ZZSaqH/eFsZgCVLlmDnzp1Zi0EIIW2FiLymO04XECGEFBQaAEIIKSg0AIQQUlDaKgZACCFZUKvVMDIyguPHj2ctSiBz5szBwMAAyuWy1fk0AIQQEsLIyAje8573YMmSJRCRrMXRopTC4cOHMTIygqVLl1pdQxcQIYSEcPz4cZx00km5Vf4AICI46aSTIu1SaAAIIcSCPCt/l6gy0gAQYsPercDtZwKb+pw/927NWiJCEkMDQEgYe7cCP/gicPQgAOX8+YMv0giQlvP4449j2bJlOPXUU7F58+bE96MBICSMp24BatX6Y7Wqc5yQFjE5OYnrrrsOjz32GF588UXcd999ePHFFxPdk1lAhIRxdCTacVJ4hoZHsWXHAbwxVsWivgrWr16GtSv6E93zRz/6EU499VS8//3vBwB88pOfxMMPP4wzzjgj9j25AyAkjHkD0Y6TQjM0PIqN2/ZhdKwKBWB0rIqN2/ZhaHg00X1HR0exePHimZ8HBgYwOprsnpkaABHpE5EHReRlEXlJRD6UpTyEaFl1M1Cu1B8rV5zjhPjYsuMAqrXJumPV2iS27DiQ6L66+e1JM5OydgH9BYDHlVL/WUR6APRmLA8hjZy1zvnzqVsct8+8AUf5u8cJ8fDGWDXScVsGBgZw8ODBmZ9HRkawaNGiRPfMzACIyHsBfATA1QCglBoHMJ6VPIQEctY6KnxixaK+CkY1yn5RX0Vztj3nnnsuXnnlFbz66qvo7+/H/fffj3vvvTfRPbN0Ab0fwCEAfyMiwyJyl4jM9Z8kIteKyE4R2Xno0KHWS0kIIRFYv3oZKuVS3bFKuYT1q5clum93dze+/e1vY/Xq1fjABz6AdevWYfny5cnumejqZHQD+C0AX1BKPS8ifwFgA4A/9Z6klLoTwJ0AMDg42OgEI4SQHOFm+6SdBQQAF198MS6++OLE93HJ0gCMABhRSj0//fODcAwAIYS0NWtX9Kei8JtNZi4gpdTPARwUEXdftApAsqoGQggh1mSdBfQFAPdMZwD9FMAfZSwPIYQUhkwNgFJqN4DBLGUghJCiwkpgQggpKDQAhBBSUGgACCGkTbjmmmvwa7/2azjzzDNTuR8NACGEtAlXX301Hn/88dTuRwNACCFp06QJch/5yEewYMGCVO4FZJ8GSgghnYU7Qc4dIuROkANy10+KOwBCCEmTNpogRwNACCFp0kYT5GgACCEkTdpoghwNACGEpEkTJ8h96lOfwoc+9CEcOHAAAwMDuPvuuxPdj0FgQghJkyZOkLvvvvsS38MLDQAhhKRNm0yQowuIEEIKCg0AIYRYoFT+BxJGlZEGgBBCQpgzZw4OHz6cayOglMLhw4cxZ84c62sYAyCEkBAGBgYwMjKCQ4cOZS1KIHPmzMHAgH26KQ0AIYSEUC6XsXTp0qzFSB26gAghpKDQABBCSEGhASCEkIJCA0AIIQWFBoAQQgoKDQAhWdGkqVGE2MI0UEKyoI2mRpHOhQaAtD1Dw6PYsuMA3hirYlFfBetXL8PaFf1Zix
VM0NQoGgDSImgASFszNDyKjdv2oVqbBACMjlWxcds+AMi3EWijqVGkc2EMgLQ1W3YcmFH+LtXaJLbsOJCRRJa00dQo0rnQAJC25o2xaqTjuaGJU6MIsYUGgLQ1i/oqkY7nhrPWAZd9C5i3GIA4f172Lfr/SUthDIC0NetXL6uLAQBApVzC+tXLMpTKkjaZGpU6e7c2ZVwiiQ4NAGlr3EBv22UBFRWmv+YKyfOAAz+Dg4Nq586dWYtBCInL7Wc6St/PvMXAl/+59fIUBBHZpZQa9B/PPAYgIiURGRaRR7KWhRDSZJj+misyNwAAvgTgpayFIIS0AKa/5opMDYCIDAC4BMBdWcpBCGkRTH/NFVnvAO4AcCOAKdMJInKtiOwUkZ15n8dJCAmB6a+5IrMsIBG5FMC/KaV2ich/Mp2nlLoTwJ2AEwRukXiEkGZR1PTXHJLlDuB8AGtE5GcA7gdwoYh8L0N5COls2H6a+MjMACilNiqlBpRSSwB8EsDTSqlPZyUPIR2Nm39/9CAANZt/TyNQaLKOARBC4hB1NR/UfpoUllwYAKXUPyqlLs1aDkLagjireWP+/cHoLqGsXEl0YaVOLgwAISQCcVbzgXn2EVxCWbmS6MJqCjQAhLQbcappdfn3fmxcQlm5kujCago0AIS0G1Grad3um7UqIKXge4e1ZMiqlQNbSDQFGgBC2o0o1bR1rhMAatI5t7JAf++wlgxZtXJgC4mmQANASLsRpZrW5DoB4rVkyKqVA1tINAXOAyCkHbGtpjW5SKpvA5ffGX0wi/t52gNdwobENOu5BYfzAEgDQ8OjHLDSKbRD/33/kBjAWd2zR1Bq5HYeAMkXQ8Oj2LhtH0bHqlAARseq2LhtH4aGR1v2/PM3P42lGx7F+ZufbtlzO5Z2cJ0wwyczaABIHVt2HKibrwsA1doktuw40PRnZ218OpJ26L4ZJcOHxWCpwhgAqeONsWqk416Suo6CjA9dUAnIe/fNeQMGN5Uvw4fzhFOHOwBSx6I+fbGQ6bhLGqt3k5EZHavmzy3ElWh62Lqp0nQV8fcHgAaA+Fi/ehkq5fpioUq5hPWrlwVel4bryGRkBMiXW4htCdLF1k2VVjEYf38z0ACQOtau6Mc3Lv8g+vsqEAD9fRV84/IPhrpgkriOXHTGRwD489RaFZMwwqBl+py1zslK2jTm/HnWusZVemW+/tqoxWB5/P1ltCNhDIA0sHZFf4PCD/PvL+qrYFSj7MNcR/7nAqh7ju6eQDTDkjp5aksQlj+f1b2SovP3d5WBUg8wOT57XpyMpjz9/oBMYxs0ACQU17/vunhcNwwwq7TXr15Wdw7Q6DqyCRL7jc/5m59ObFhSxzZo2WzSVBx5C7DqVulTNaeFRc/cZEYqL78/l6AdCQ0AyRqb7Bzd6t2r4G2MiBfXWJh2ABecvlB7fkuK11bd3Fi4BADj7ziKtFUK06Q4tn3W+ey0i4BXnrBTlqZ7ff9zzv+32ggEVTDf9Gqye+t+f1nWRmS4I6EBIKHY+vd1riOXKCmefmOh45mXDxnPDzMuiXGV4WM3AdUjs8erR1q7ag5SEEcPAjvvrv85SDbTvdRkNjuBZq7S89ZWIsMdCYPAJJS4qaFeogSJdcYi6LpMitfOWue4IvwkCSZGDQRGVRBBspkCrGHXNYtmVzDrgs5ZkWG1Ng0ACSVuaqiXKEbEJsDrvS6NDKRYJN26exX+rUuBh6+LlppoM+QlrmxpXReXPFUwNztDJ8N3pQuIhBLm37fBJkjsEpT9o7sujQykWIRt3YOyavxBV68rySUsEFjnytDIYZJZR/Xt4Osq86cby1m4TMKyiWyzjfJQwdyq4HhG78puoKRl2AZqdTEAtx6gX3Od7vxKuWRVv5CIoC6WQHCHS1OXzgbEcVPEkcVPUIfNIHlKPYBSThZO2L3COnu2W+fPduimaoGpGygNAEmVtLJxot4nsxbWptVskEKdtzjCit2jaKKurE1ZQLr7AHoD4k
4O0+1QdEowSGGuutnJKlKa+E5eFeqmPjSWIgLWhjkn0AAQa+Iq07CVeKHmDBgVh4uuxtlHM1bOYbsWnYGJogSD3rtcCdih5FShdvgOgDGAAhKkiJOkVAZl4+x87Qjuee71GdXQ9FTNrDHFB2YIUf7uitnr50+jWCjoPqZsGOO7KEdBeuU0nSulYPdUXmf75q1mIGWYBVQwwrp2JkmpDOrm6VX+Ue+bKXEzQOJk6MwgjcrYJuPIRlaTUTp60HxN0Lv4s5VMKY06t4/38zwpVO/3+NQtwNlX5SMbqQnQABSMMAWfJKXSlHVTEjGud0fHqto2z7mYDJaka2Rdap8BKemP63LyTStkb8aRjaymZwLma8LexVsnYEppNF0rpXwpVN33uOdex0DloWYgZWgACkaYgg/K1w9TyqZunpMhcSb/TiQ3k8Hido10V5DbrnV+HvyMflW88monw8bPu79sVMJhxUK2sgatxP3XzKyE5znB2yCXlncnoiuyMsn/ib/Kl0LNY6fQJkIDUDDCCrJMRV8XnL4wVCl7W0kDVmHOOqq1SdywdQ+uf2B3ZmMp64hT6GVaQercCJfeBvSc2HiPqVqjwgkrFrKVNWhH4r2m7j0QbjjCfPh5KuwKIm+dQpsMg8AFI6wgy1T0tWn7fqtePm4/IFMXzzCCdgupV/aGpVVG6dEycy/N+bWqk46pyxoxFWDpFE5QsZCtrKZGdv5rdCthEyYf/t6t9f2SKguAj92aXOmb0ljT6O2Tt06hTYY7gIJhM/Bl7Yp+PLvhQry6+RI8u+FCAMBYtaa9X5yYQX9fBfN7y5FlT7Wy18Znbtujxb9a1mFaQYb59m0Jk9XrluquzOb3Q/TXWK14A1bye7cCQ3/c2Czv4euStVLQ/d6G/jh6Gw0T2oC3BAfJm0ULhsRwB9CmJMmpD+raqSPI9aJTykPDo+gS0a7m5/eW8eyGC606fnqx7T1k/b3YpFXado20WS0bFPq/9p2P9x+936+GnSKuKATJqm07IU5s4pTz9NeEtZcIy4N/6pb6ymGXyfFkfe5NcwL8xO2n39Bew+PIbOWMhBa1oKABaEPi5urHNRpBq3m/UnZlM7lyfnV8AkPDozPP/foP9uPtY/rdhYuu/YOOSN+Lra/XpkdL2GrZ4CIZGh7FuT/7B0iD9ofjMvIT5rLyGwE3jqA1UArY+V3HAOgUeZCryCZtM7BVdQJ/epRr3VhGVNeQ+zvXFYG1aFBLq4bEZOYCEpHFIvKMiLwkIvtF5EtZydJuxMnVD8uscTN8lmx4FL+58e+xxJPpE+R62bLjQF0gOKyVc21Kzci5dkU/ghKEyl2CO648B89uuNBK+d+wdU/o9+K+58jUSfobxfH1Bl0TEOzcsuMAFuEt/XV+xWPjsjKdY1zJK3N2iz/1000ftQ3eBn4nCfzpUa6tzE82/L3ZAeEgF0+LgtFZxgAmANyglPoAgPMAXCciZ2QoT9sQx+8eZDS8xgGYDcS6RuKC0xei3KVbpjYaEptArfccU2wBQIN72kTYrsN9nvc9n5o6R298TK6XoH+sJv/75X8dmDceKahtk55oOicw9z9Aocykcx4FvnbE+XPVzc5zwvzSq252Zvj6KfUkK/rSfdfurGAv7jlJUjpt4jNx/fRhBj2t2FAImRkApdSbSqkfT///LwG8BKADewIkx59/32cIoAat1IOqdL/+g8YMH5dqbRKP7HkzUBm76ZthuwUbOb3UJpVV6mfYrsN9nve8S7uei+Z6CfrHGjPFMVJQ22ZFGDTVy/QLjKJQohTGnbUOWPuXnmAznP//+HeC20iHKVPdd732L537+r//KBlWOmwC63F3GGEGvUVDYnIRAxCRJQBWAHhe89m1AK4FgFNOOaWlcuUBnV+73CUolwS1ydklbFiQNKjHfpgPPnCVPs2kUti4bR+uWNmPh3aNGhWyX875veXA50fdUQQ9zz1vTdcPsUB+pb9ApxxsA8YRfbPrVy/D5FAXujHV+K
F/1W6TnhjUf6jcC9TeqT8WdTVu65f2+91tUj+jBD1N37UuOJ8kpTMsCSCJnz7MoLdobGXmaaAiciKAhwBcr5T6hf9zpdSdSqlBpdTgwoULG2/Q4ehWt7Uphbk93YGpnH50BV5pU61N4pmXD9WlmfZVypjfWzbK+bXLlqNcMm8vbCqQg1pQeJ/nnndj91b96h/QKwdLf6xV+wrPKnftP67G4ZN/W18st/Lq+p9tVoRBmUN+5Q8gMACjw7YfUZxVcTMqcNNYRQeNjkzip7dx8bRgbGWmOwARKcNR/vcopbZlKUteMa1uj1Zr2P01+1RBVwle/8DuSM8vdznpnFOWumJ0rBop02jtin7sfO0Ivvfc69pnuxXIQZk9puI2v7Fxz1skhsAroFcOFqtv/05t5S+exLlDn4V6+DDE1HP/6EH8evktYOl/BH72Q8dVIyVH+V96W/2zbFaEOvdVEG7Fsa1isdmFxF0VNyPo2exVdJKisZx0Gc3MAIiIALgbwEtKqdvCzi8qaY47XLuiH1t2HAit0BVxFod9lTLeGZ/AlMZDYbwWmLm/bXrqMy8f0h4/cU43nnn5kDZ4vWn7/jpDc8XKfjzz8qFAw+P+/G8PL8T7oHlmZYFeOVj8Y/Xu1NZ0/RCby3ehV8adD91VcLemH36tChz5qRNkDSPMzRRHWfqvCUqbtFFacRV5sypwmzlqMYkSb5GLJ4wsXUDnA/h9ABeKyO7p/y7OUJ5cksZA9rD7+Vk0r4I7rjwHvzw+URdncCmJ4NPnnaLNDIrT8tm0yxk7VjN/Vq3VpbQ+tGsUd5zxCl799Zvw7PHLsfYfV2vdDmtX9ON9l/8PvWvgY7fqBdQFHs++qi4bZvAXT86cfmP31lnl71Kr6qdqAXol7A+G2gRI46awepu+bbs2WbBb18nUPR4kf4uCnqmStL9RC1w8YXAiWBuQ9iQt935hg9dNgVwBcPuV52D9g3u0BkJ3/qubLzF+buob5DaVs+kptKbrh7i1525U8O7swaCJWXEKhLzX+lZ+VZyAm8Y/g+1TH8ZPT7gKhqxZPf6xj/5VpWkm79lX1Y98PO0iYPj/OdW2Nrj32HNvSG8gy+lXbvsHf2VuVwlAV/3xUo/TCK/6dvr9fEgDHAlJGojbsC2KYnbPX796mfUUMsAxQFes7Meje98MzVICgB/2fBEDXRrffjNG9xnGBI6qk3H+u98yy1JZAExUg8c6Wg+LBxr6rZZ6gMkJQJdV5DJvcb2C9TZrC3qOzbhGo+wRx1+S1DEZgMyzgEh2xMkMEgAXnL7QuojJ1Er6yw/snqk2BtDQoM5NJ/UrfzeryI8xsGvjF49azGO45yI5jKtP/BF65XijunNdTGEug0h+fN9TJscRqvy9LgfAQvnD7NbxY5TdYpHZwT3380wu6gBIcxgaHq3rtdNXKWPTmuUzK29/62dTAzcvCsBDu0bRF5C/XxLBlFIzK31dKqt/NvAVK+tdWo/ufVPrgpp7Qrc26+dNnIx+XVuFML94nKZbhoClVOZj08T/BcRnHP1tkINWuaGzhOPi6WjpbfiWJkHzgMPmCQDOtZv66P5pIaE7ABH5vIhYLgFIXhgaHsX6B/fUKemxag3r/25PwxAXt/XzlKU7sFqbDEwhn1JqppX02hX9obuFam0S33vu9bodgsm4vDFW1ba0fmPljfGCiHHyz0359pPv6n3pPXPtlZkuGFrq0bRViBJkABo6Wu7dar/bqB5p3B3pdk2mQO7KqyPMR07YzplEwsYF9D4AL4jIVhH56HT6Jsk5W3Yc0AZoa1MK1z+wO1JBlY6j1Rr6KnYtKeYZzouDe2//zIJz1/zXeBkZcdIWTfn245piq7B7+dFllnz8O067A++xwWviGwrXwEXKGvIo5ke+oi/2AvS/g0tvqz9eWaDvE6STkTQVqyDwtNK/CMAfARgEsBXA3Uqpf22uePV0chA47UyfpRseDfW8+oulovTodwO7/vMFwH857xT82doPzt
zTNlsoDDeUaNse2gpT4DIoeLypD5GGXTYjEA3YTcYyupMEuPzO8Olg2ksNLp0o7+mV3fhdWgafSSimILBVDEAppUTk5wB+DqeL53wAD4rIk0qpG9MVtXjE7e8fRFDvHxdvjr5rfOZVyphT7grMvHHrENwq3nuee33mn7AbIxj8jQUzhWdxlX9fpYy5J3TPvIc/bgDE/35cXvjNL+DMH/9pY/pokOsoTLHWKTSJPtzFNkXVpieO0cANNBYjVeYDE+/q20Z4Mfnzo+50wjKfOnQMY56wiQF8UUR2AfgmgGcBfFAp9d8ArARwRZPlKwRR+/vb9JxZv3pZYI8dl9GxKq5/YPeM/32sWsPx2hQ+fd4p2gyhvkq5btfwyJ43A4u/4s7xFQCb1izH+tXLtN7uNIbEDw2P4g9e+A3cNP4ZjEydjCklGFUn44UPfj3YdRQYV1Cod7soJ88+SZvgbdc6bpc4hBVYeYuRbnoV+OobTtvnTUfNA+RN7aXjKux2LALrEEJdQCJyCxx3z2uazz6glHqpWcL5aQcXUBxXjsldoyug0rlpyl2CE+d0Y+xYre6ZfzK0D/c8/3rknl9AeO6+K0tQbyEBAjOLLLLDAwkrMAsjqADNnYVs5Nal+hTKpO6RoFz6y++MlxkTtKMI+0zX6kBXPObP449aaOc930079RaJMSMoEbFdQEopoxlupfJvB+K6cqL0+zF1B3VdNu4zd752BA/tGo2l/IHZbJsguTdt3x94DwUEppUqOCmjYamngFPpe2P3ViySt/CGOhnfnFiHXe/9vdDrgog0WMev0JZ/Qq8ETf50W/dIUC69TeM2k+I1VUMHpcAG9asxzRK2ua8O93ktmoVLHFgHkCJBrpywVs26Slhdvx8bl0q1Non7nj9opVhNzKuUcc7Xn5iZBTC/t4yvXbYcAKwaytny3ko3fnV8ArWAdqP+5moD8hY2l+/C/jOWAAhZqQdgbXh1SmnPvY2tGNxpWUn82UHxhTAjElV5mlJgv/85x+3kvpN/5xK2uk/SJ79Fs3CJAyuBUyTOqEbA2R1csbIfpekM25IIrlipX33bpmomUf4A8IvjtbpBMG8fq+ErW3dj/YN7UlP+7n0hTmzBFLHQNVfrlXGc+6//O9GzrRvtmZTSK080NvMyjSwcf8eu0njVzYg9vStqTUPgBLGIc4fTmmfbolm4xIEGIEVMyjlMaQ8Nj+KhXaMzSntSKTy0a9QY3G32YJcTuru0/f+nFFJJ5/RTm1SYe0I3Xt18yUyfIS+J2jwEoCso0w7WiaKU/Hn8lQVOf+3qEVgVOZ21zsnx9xsBm6BokJy6wi2bXUmtCmz77Ow1NkYmyTzbFs3CJQ40ACkSt3VzlCwgr9IC9GvFpJV6705EGACQEm+MVTE0PIpj4xONn6mT9RdFnWer6ffjLyjTuuqiKiVvZk3P3MbunGFFTpfe5gR8oxa1meSpzNev2k+7yL5C173Gxj2VJKuHGUEthQYgJdzsn2ptcsaVYzOqEYjuOlq7on/G2PjX4/N7y4kya7Kir7eMjdv2aesPvjmxDuNyQv3BKEohyfBuIJlSiuvSiNIr3jVuRw9Cu3MA9Kv2XX/rxDFcQ2NK7/ReY5MCmqRPftIe+yQSDAKngD/7Z1KpumKpMMKCkbrUUt2uAQB6e7rR29Md6qcviaCnW1Ct2a/2uwQQEUxazIfsEliPkayUS1AKxgrk7VMfxoJSDzbNfSher3iT2+Kxm+wLrtz7RH1+syZduTSkarp1CMpRnqtudgK6OtSkMz/g49/RZ+CYrvFnO7nGMMmMBS/NnOJF6uA8gBSwzSc31QiY+uF/43KnnYKu3UJA8Txuv/Kc0JYOAgR29PQTRaFHuYebXfTlB3YH7lwS5fzbtm5oRk96Uy59Ws+xaWURNmOgssApAnPlNWUyufedyXbyKPrXnwN2fhd13zN7/OcGzgNoIjYuHFfJeztebty2D0PDo4HByKBWyjoW9VUa4gSm88YslT
+QXPn77zG3p4Q7rjwHwzdfhLUr+kMD5XFmIM9gu9puRgOyZrs0bFxMOheWF29Bm+t6uvyvzW4vv3sKaFT+ABu6tQF0AaWATT55WI2AqegqSisFU8DZv2Nwz/POCmg174xPYudrR2beWVcL4ZJkBjIA/fBuE81IN2ymS8PGxeQ+e9tn7e8bxe311C0wLkvc7zMt9xBJFe4AUsAm+ydujYDtyte7a/DuNoDGf5pzyl3Y+doR/Op4Y8ZNK7nv+VnF5a+FcLENpGtxg6PbrgW6K46rw12FVxbor8ki3TDKRDL/ubpMHl2A+qx15nc2HbcNRAcZTXfofJIgPGka3AGkgH+ylq5vTpR2D4Dd4HaX+b3luliDKUDs8vaxGr733Ouh92023mI1fy0EgEiB9Ab8vvfqEUcxuv10TL75VqcbBlXvAvWr5tMuqm8/EVSRrFPWH7tVP7S9eqR+UlhUjNXLMhsvYHVvLqEBSImwvjlR2j1E6csPoKHfT9wOnFGx7eMTdL1L3DYaRsKUjo2LY+/W+qHp/tGOaRCUoeQdIH/0oNnP7lYkh1H3zm7KqG9SmPc8L0EuHK2LTYClHwkeOs/q3syhC6hFWFedInwF72esWqurGk4UMLXkZ5svwf9ad3aie0wqNdPOOq6LzIhNcDTIxbF3K/DwdfXKq3rEWUGn6bowyVk9oolZmPzsB+0H2rvvPG9x4/1MQdswF44u0D14DfDaPwUPnWd1b+ZwB9BCwnYJLnGUnrfraFBANU3WrujHpu3763oGRcXNhppXKWvvE9uYJc2/f+qWxgpewHGfpOm6SG0IvApfxXuJUqBm48LxB7pvP7PR1eSF1b25gDuAHGJSevN7y8Y+QN7WEe5uwzSzNymVctfMQBoRZx5BEqq1SYggVhsNI0lbCgS5J9JyXezdqp8jXOoBJOY/TdvUyyjtLeJUM4d9R6wPyAU0ADnElFX0tcuWzxSH6fDuHNau6MemNcubYgQmptRMPcPbx2qoTSlUys5fJW9HU++fYYwdq1m7yKxImn8ftFNIw3XhulX8LpKeuU5QR+kqtC0NrY2BimIg4zRoC/xsMZV/TqALKIeEZRWZsoO8O4c/GdpXN6s3DqaKY11H0OO1Kdxx5TkNCsp9aewAABNESURBVHvphketnuUWsBkVfpw88iT596tuBoY+B0xp3GhRZ/y6eN9BuvSTw2rH9celBKy8Gth5d/hzbAxUlDx/XZA3bDe16mZ9xlGph66fHMFWEG2Ct41EX29ZO0TFO7QlrLVCM9CNUjS1yTBdr037bHY7BROmsY+24x292PTZCUScYLVJJpdmfS9xDHArsqiIFaZWEDQAbcDQ8CjW/92eBoWvW6FXyiXMKXdlUuFrO8M4CLcHUp0RsOl34yWtqlNjD6FpZRyEX4bxd4IV98ytQ2YKaw2Jr/kbFSzxEXsmMMmeTdv3a0cm6lRTtTbZ9OwfE329jfEGvzsrbLmhzf2PEoRMc6Zs3EwinQw2BA1cd90mSTqTEuKDQeA2IEmaZSv51fEJ7RQzd+jK7VeeYxXGbEiDjRKEjDoWMYi4mUQ6GUxICXVB6ktvCw9eR5kV0E5EaYlBUoEGoAPpqzSmiwqA3nI6v+7+voo2u6g2pXDD1j1YuuHRmQIvL1t2HLCKSzSkwVoo4qHhUZy/+WlMjcUcqK4jbiaR7bPKFeATf+W0pwCcnkW3n+n8fycq+CDYLygTMjUAIvJRETkgIj8RkQ1ZypJn5mtcKyYq5RI2rVnekFJ5+5Xn4FiE4S9B97/g9IXGXcmkUg3trl1sCty0uf8hitjb/C6V8ZH+Z0dVxsbRjAsa3+H15xzF3yrFl8Uq2+aZae7ciDWZxQBEpATgOwB+D8AIgBdEZLtS6sWsZGoWpkEwtnztsuVY/+Ae40B2163iv7f/GTds3RO5d8/83jJ6e7pnZL/g9IV4aFejm0eH359vaohXEsGUUsHfTUBKp7d1xjcn1mFz+S70iqeKt9VVp6ddpE
/XXP4Jx8UDNGbIePFX2aYV1E4zPpL2M+OOziSJyDII/NsAfqKU+ikAiMj9AD4OoG0MgI1i92fBuCtjoFFBm3DPu/6B3cZzbKZlxWncphTq3uv8zU9HCjJ7V/2mhniJCr58z9g+9WGgBtzYvRWL5DC6+nwK06tMK/OdY9W30w2mvvJE8HGblNCjI3ojEVVpP/IVZ/avLrMIaH5XTttxnJX5hpRb9gtqJlkagH4AXoftCID/4D9JRK4FcC0AnHLKKa2RzAJbxW7qcrlp+34r4+E9py9hv5x+wwo8iLFqre69ovYp8spm0zY7Dv6dxfapD2P7+IeduoQve+oSdC2iXdJaDe/das76cVezNkHiynyzkbBV2o98xa5wrJmr7KBmd+73f/Qg0FV2isQmM9y5FZAsYwC6hJCGJapS6k6l1KBSanDhwoUtEMuOoPbFXkwKc6xa046HdNGNkHxnfKKh706Ufjm6FhM2eN8rSnM2nWxuRtCrmy/BsxsuTKz8AbuBPADCFW9Sn7NrYEy4q9lQhSuz8pg4ejDch7/rb0Oe45MrjDjxA9t7T9WAnhObNzqTaMlyBzACYLHn5wEAb2QkS2Rs2xeb/N5+qrVJ3LB1DwAYZwHXJlWDTz5sBe3fRVyxsj/WMBj3vYI6jZZLgrk93TharaW2urfBemdhs9JNshoOMjDe1WxYB9DBa6Z7/4fg37X4YwUmt49JriDixg+ijOOsvj07nJ60hMwqgUWkG8C/AFgFYBTACwCuUkrtN12Tp0pgU4sDfzuEqD15yl2CE+d0Gyt5vdW2YTEIXRVupVzCCd1dkWsLvO/lnVbmDoUxtnEwkDQwHgtTRbGXOG0eXIyVwwAGPzM7tcvk7wacTKGbXrWT1cWtAI7UakJmp4zZTBOLWo3txbYqOsl3TwLJXSWwUmpCRD4PYAeAEoDvBin/vGEz4csdc+hVCQKnnbIpJbM2pQLbOLguGJsYhMlNNafchUq5FKk9g/e9bOcamEgjMB6LMCWZ1Ofc06tv71xZUF/dWz0CdJUaG82VK06vHBtZvRwdiVZ8NvgZJxspyqo+SZaOP4Mr63GcHFA/Q6Z1AEqpv1dK/Xul1G8qpf48S1miYjPhS6eAFYATyqVYvnivIraJQRjjD77Wy32VMsql+tiC+1NYW2a3AMtU/KXDNn6SOv56gsqC+kHxSXzOj3xFr/wBYPLdRuU8NamvC/AOWDn7qulKYTh/lufq7x/mUvLew1X+QLTc+zgtoU0kbdWdBBac1cFeQAkIWwkHKeDbrzwnUl6+38ViE4MIGkTvlz2OSybuSj718Y9RSNIiOoiggKvJMAT5vPdudXYNrh9fTQI1w31MhgEIdqtEWdXHaQkdRLN+D2FwQH0dNAApoVOgYQoYgFWnTF2b5aB7u0QZRB/HrRN3kLuN7G2HTcDVT9DqOYpL562XzZ8FKegoze46pQkdC87qYC+gFNClbG7ctg8XnL4wMD3R70bSuWJMCluX+ijTz3ZdMUFuqjiuGz9xV/LWaZvthAS49CoLojeVS0shhWXolHrqjwUNbOmEJnRpurI6AO4AUsC0En7m5UP4xuUfDHStxHXFuMe8Q9ldZ5LfFdOM6mQg/kq+WQVhmWKc1tU1G9iNsnpObVh8CH4XZBvNB4lF2q6sNocGIAWCVsJRXStRz393Qp9NFOSKieu68RPFxeQnaSZR7nADq7v+Znaeb3kucNkd9YFdWyKndWqoLAj+/KlbGkc2TtU62x/eKa6slKABSIGsfNo6Re4lqosmahC2I1fySbj0tllDkBS/oqrMB8Z/Vd8qwaXU42QVeeMQpZ7ZnYeJovrDswpA5xAagBRIshKOitdFFLZZNxmgNA1W267kbXLBs84X1+XPP3WL4xpyR0e6RWBA9PdhA7bCQwOQAq1aCUeZrxtkgFppsHKJTQFUFq2TwwhbuQZ9pnufUo/ThM3rBiqwP7yIcCh8G2FqP+Gnr1LGpjXLI/UIyoPrpmUy2bQ1SN
L6wE+UnUSzdh2m96ksAHrm0h/e4eSuFQSJTpCPXtA4ECaIZrpuWllUFgsb33da/vEoO4lm7jqMbZnZgK3IsA6gjZinmcMLOCv+NNsrJ8FUExFWZ9DS9hA2ueBp5YtHabfQzLGIzH8nGmgA2gjRTVAIOJ4FcRV5S9tDWAyZtzrHBuNOQtPPP2zXkWSeb1rv4yWL+cIkVWgA2ogxQ5dQ0/EsiKvIgzKWUkfXjOzsq5yVtqvMgHQalgWtsP2NyIzD5OcDty4Ftn02fhOztBuwmZqqPfIVGoU2gjGANqIdeujElbHlmUnejBqT7/2yb8XvT+9N2YTAOCfA24hMV/xV6gHe/WVjwZb/WhvSzH83uat2fhcz75qHzCkSCHcAbUQ79NCJK6NNe+2mkbbvvW51DDgKMcBP57p4dKv0nhP1yt9/bVKiunOMz/UZurRiGKQpcAfQRrRD5W0SGTMrKku7IlbbyVPNFm/58bp+/Kv0TX3Bz0ojiBsn+yhKryL/95h1gR2ZoeMNQB7z3U3YyGpq7pand2y76uAobZFtMBkONekEXqM0IgtStGkVbcXpka/tVWRwdXm/xzwW2BWYjnYBxU1JzIK4sqb9jmm0ic4FUVwaaWfIGFMuF0cPxK662anW9VNZMHtt0mycODsgnbtq8Jrw77GZqa4mmK1kpKN3AGl1vbQh6So8rqxpvmNms3rTJuoqM+0OkUEth+MEYv15vm6jN1f5J11Rp7EDGn8H2P99Rw5/nyKvHEncbXFcR9xxBNLRO4BW5ZansQqPK2ua75jZrN60ibPKTHPYSZopl0/d0tgBdHJ89l3SWFHH2QH500CrR2Yby7muLp2CjluQFneWbxY7jjaiow1Aq3LL01CccWVN8x0zndWbJnloc5yWQQl7F5t3DXOBxDFYYSMrTUo2rrstriLPw9+FHNPRBqBVaZNpKM64sqb5ji0txmomndT2IOxdwj63XTlHNVg2ClR3TtzdUVxF3kl/F5pARxuAVuWWp6E4o8rqBmu//MBuzCl3oa9STvyO7VBnYEUz2h7EIY3g42kXBR8Pe9dmuUBsFKjXCHm/ByD67iiuIs/L34Wc0tFBYKA1KYlpVbEGyeoNMs+rlPHO+ARqk07K3dvHaqiUS7j9ynMSvWs71BlYkYexf2kFH195Ivh42Ls2ywUSNrLSVbJpfQ9xZ/nm4e9CjuE8gJRoZi6+7SCY/r4Knt1wYSrPJDHwZqlIl6HoK+I8gU190LeREGcFHUaacw38eN+3Mt85Vn27XslmNVeB1MF5AE2mmTuNsNm/Lm0XrO0k/CtdnfIHoq+8k6Zoxl0522CT0prmDoSzfFOno2MAnYKtYm+7YG0nEZYV42JS3KZ4QVIfdtpdQKPCIGyu4Q6gDTB12PTSlsHaTsJmRWtS3DZ+8iSujyxXzs3cgZDE0AC0Abogc7lLcOKcbowdq7VvsLaTMLlqpASoqWDFHdaLp51dHwzC5hoGgdsEU5A5b43gCot/FQ84K10bd0vSQC8hITAI3OaYuoB2RO+eTiDJSjftbqSEWEID0Ma0stkdsSCuq4Z+cpIRmWQBicgWEXlZRPaKyPdFJGTqBdHRMb17ik7WmTqksGS1A3gSwEal1ISI3ApgI4CbMpKlbWmHGcHEklYHellURZDRDkAp9YRSamL6x+cA0NkZg47p3UNaS9zWyqTjyEMh2DUAHjN9KCLXishOEdl56NChFoqVfzIdpE7aF/bIJ9M0LQ1URP4BwPs0H31VKfXw9DlfBTAI4HJlIUiR00AJSQ2mnRaOlqeBKqV+N0SgPwRwKYBVNsqfEJISTDsl02SVBfRROEHfNUqpY1nIQEhhYY98Mk1WMYBvA3gPgCdFZLeI/FVGchDSWtIYEpMUpp2SaTJJA1VKnZrFcwnJlLSGo6RBO/cXIqmRhywgQooBs29IzqABIKRVNGs8IyExoQEgpFVwOArJGTQAhLQKZt+QnEEDQEirYPYNyRlsB01IK2H2DckR3A
EQQkhBoQEghJCCQgNACCEFhQaAEEIKCg0AIYQUFBoAQggpKDQAhBBSUGgACCGkoNAAEEJIQaEBIISQgkIDQLIlDxOyCCko7AVEsiNPE7IIKSDcAZDs4IQsQjKFBoBkBydkEZIpNAAkOzghi5BMoQEg2cEJWYRkCg0AyQ5OyCIkU5gFRLKFE7IIyQzuAAghpKDQABBCSEGhASCEkIJCA0AIIQWFBoAQQgoKDQAhhBQUGgBCCCkoopTKWgZrROQQgNcS3OJkAG+lJE6e4Hu1D534TgDfK+/8hlJqof9gWxmApIjITqXUYNZypA3fq33oxHcC+F7tCl1AhBBSUGgACCGkoBTNANyZtQBNgu/VPnTiOwF8r7akUDEAQgghsxRtB0AIIWQaGgBCCCkohTMAIrJFRF4Wkb0i8n0R6ctapiSIyEdF5ICI/ERENmQtT1JEZLGIPCMiL4nIfhH5UtYypYmIlERkWEQeyVqWtBCRPhF5cPrf1Usi8qGsZUqKiHx5+u/fP4vIfSIyJ2uZmkHhDACAJwGcqZQ6C8C/ANiYsTyxEZESgO8A+BiAMwB8SkTOyFaqxEwAuEEp9QEA5wG4rgPeycuXALyUtRAp8xcAHldKnQ7gbLT5+4lIP4AvAhhUSp0JoATgk9lK1RwKZwCUUk8opSamf3wOQDtPIP9tAD9RSv1UKTUO4H4AH89YpkQopd5USv14+v9/CUeZ9GcrVTqIyACASwDclbUsaSEi7wXwEQB3A4BSalwpNZatVKnQDaAiIt0AegG8kbE8TaFwBsDHNQAey1qIBPQDOOj5eQQdoiwBQESWAFgB4PlsJUmNOwDcCGAqa0FS5P0ADgH4m2nX1l0iMjdroZKglBoF8D8BvA7gTQBHlVJPZCtVc+hIAyAi/zDtu/P/93HPOV+F4264JztJEyOaYx2R1ysiJwJ4CMD1SqlfZC1PUkTkUgD/ppTalbUsKdMN4LcA/B+l1AoA7wBo61iUiMyHs5NeCmARgLki8ulspWoOHTkUXin1u0Gfi8gfArgUwCrV3oUQIwAWe34eQAdsVUWkDEf536OU2pa1PClxPoA1InIxgDkA3isi31NKtbtiGQEwopRyd2kPos0NAIDfBfCqUuoQAIjINgC/A+B7mUrVBDpyBxCEiHwUwE0A1iiljmUtT0JeAHCaiCwVkR44gartGcuUCBEROP7kl5RSt2UtT1oopTYqpQaUUkvg/J6e7gDlD6XUzwEcFJFl04dWAXgxQ5HS4HUA54lI7/Tfx1Vo88C2iY7cAYTwbQAnAHjS+d3iOaXU57IVKR5KqQkR+TyAHXAyFb6rlNqfsVhJOR/A7wPYJyK7p4/9d6XU32coEwnmCwDumV6E/BTAH2UsTyKUUs+LyIMAfgzHTTyMDm0JwVYQhBBSUArnAiKEEOJAA0AIIQWFBoAQQgoKDQAhhBQUGgBCCCkoNACEEFJQaAAIIaSg0AAQkgAROXd6tsQcEZk73UP+zKzlIsQGFoIRkhAR+TM4/X0qcPrifCNjkQixggaAkIRMt0B4AcBxAL+jlJrMWCRCrKALiJDkLABwIoD3wNkJENIWcAdASEJEZDucaWxLAfw7pdTnMxaJECuK2A2UkNQQkT8AMKGUund6RvM/iciFSqmns5aNkDC4AyCEkILCGAAhhBQUGgBCCCkoNACEEFJQaAAIIaSg0AAQQkhBoQEghJCCQgNACCEF5f8DMrVDWwbpZJAAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(data[:num_samples,0], data[:num_samples,1], label=\"0\")\n",
"plt.scatter(data[num_samples:,0], data[num_samples:,1], label=\"1\")\n",
"plt.xlabel('x')\n",
"plt.ylabel('y')\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pyro.nn import PyroModule, PyroSample\n",
"\n",
"\n",
"class BayesianNet(PyroModule):\n",
"    def __init__(self, in_features, out_features):\n",
"        super().__init__()\n",
"        self.linear = PyroModule[nn.Linear](in_features, out_features)\n",
"        # replace the deterministic parameters by sample sites with standard normal priors\n",
"        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))\n",
"        self.linear.bias = PyroSample(dist.Normal(0., 1.).expand([out_features]).to_event(1))\n",
"\n",
"    def forward(self, x, y=None):\n",
"        logits = self.linear(x)\n",
"        # the labels are conditionally independent given the parameters\n",
"        with pyro.plate(\"data\", x.shape[0]):\n",
"            pyro.sample(\"obs\", dist.Categorical(logits=logits), obs=y)\n",
"        return logits"
]
},
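{
"cell_type": "markdown",
"metadata": {},
"source": [
"Together, the priors and the likelihood above define the log joint density $ \\log p(W, b, c \\mid X) $ that both SVI and MCMC work with below. As an illustration only, the sketch below writes that quantity out in plain NumPy; `log_normal` and `log_joint` are hypothetical helpers (not Pyro functions) and the two inputs are made up:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def log_normal(v, scale=1.0):\n",
"    # elementwise log density of Normal(0, scale)\n",
"    return -0.5 * np.log(2 * np.pi * scale ** 2) - 0.5 * (v / scale) ** 2\n",
"\n",
"\n",
"def log_joint(W, b, X, c):\n",
"    # log prior: independent Normal(0, 1) on every weight and bias\n",
"    log_prior = log_normal(W).sum() + log_normal(b).sum()\n",
"    # logits z_i = W x_i + b for all inputs at once, shape (N, 2), as in nn.Linear\n",
"    z = X @ W.T + b\n",
"    # numerically stable log-softmax over the class dimension\n",
"    z_max = z.max(axis=1, keepdims=True)\n",
"    log_probs = z - z_max - np.log(np.exp(z - z_max).sum(axis=1, keepdims=True))\n",
"    # log likelihood: log-probability of each observed label\n",
"    log_lik = log_probs[np.arange(len(c)), c].sum()\n",
"    return log_prior + log_lik\n",
"\n",
"\n",
"X = np.array([[1.0, 1.0], [5.0, 1.0]])  # two hypothetical inputs\n",
"c = np.array([0, 1])  # their labels\n",
"W0, b0 = np.zeros((2, 2)), np.zeros(2)\n",
"print(log_joint(W0, b0, X, c))  # equals -3 log(2 pi) + 2 log(0.5)\n",
"```\n",
"\n",
"NUTS explores this log density directly, while SVI fits a Gaussian approximation to the corresponding posterior by maximizing the ELBO."
]
},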
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from pyro.infer.autoguide import AutoDiagonalNormal, AutoMultivariateNormal\n",
"\n",
"\n",
"model = BayesianNet(2, 2)  # 2 input features, 2 classes\n",
"guide = AutoDiagonalNormal(model)  # restrict the variational distribution to a diagonal Gaussian\n",
"# guide = AutoMultivariateNormal(model)  # alternative: full-covariance Gaussian"
]
},
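{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two guides differ in the Gaussian family they fit. A small NumPy sketch (separate from Pyro, with made-up parameter values) draws reparameterized samples $ \\mu + L\\epsilon $, $ \\epsilon \\sim \\mathcal{N}(0, I) $, from both families: a diagonal $ L $ yields uncorrelated coordinates, while a lower-triangular (Cholesky) factor $ L $ can express correlations:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"mu = np.array([1.0, -1.0])  # shared, made-up mean\n",
"\n",
"# diagonal family: one independent scale per dimension\n",
"L_diag = np.diag([0.5, 0.5])\n",
"# full-covariance family: lower-triangular Cholesky factor with an off-diagonal entry\n",
"L_full = np.array([[0.5, 0.0],\n",
"                   [0.45, 0.2]])\n",
"\n",
"eps = rng.standard_normal((100_000, 2))\n",
"z_diag = mu + eps @ L_diag.T  # each row is mu + L_diag @ eps_row\n",
"z_full = mu + eps @ L_full.T\n",
"\n",
"# empirical correlation between the two coordinates\n",
"print(np.corrcoef(z_diag.T)[0, 1])  # close to 0\n",
"print(np.corrcoef(z_full.T)[0, 1])  # close to the correlation implied by L_full @ L_full.T\n",
"```\n",
"\n",
"A full-covariance guide such as `AutoMultivariateNormal` can therefore capture parameter correlations, at the cost of a number of variational parameters that grows quadratically rather than linearly with the number of latent dimensions."
]
},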
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from pyro.infer import SVI, Trace_ELBO\n",
"\n",
"\n",
"adam = pyro.optim.Adam({\"lr\": 0.03})\n",
"svi = SVI(model, guide, adam, loss=Trace_ELBO())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[iteration 0001] loss: 1.1963\n",
"[iteration 0101] loss: 0.1787\n",
"[iteration 0201] loss: 0.1413\n",
"[iteration 0301] loss: 0.1327\n",
"[iteration 0401] loss: 0.1425\n",
"[iteration 0501] loss: 0.1311\n",
"[iteration 0601] loss: 0.1340\n",
"[iteration 0701] loss: 0.1330\n",
"[iteration 0801] loss: 0.1259\n",
"[iteration 0901] loss: 0.1299\n"
]
}
],
"source": [
"pyro.clear_param_store()\n",
"\n",
"for j in range(1000):\n",
"    # estimate the loss (the negative ELBO) and take one gradient step\n",
"    loss = svi.step(data, y)\n",
"    if j % 100 == 0:\n",
"        # report the loss per data point\n",
"        print(\"[iteration %04d] loss: %.4f\" % (j + 1, loss / len(data)))"
]
},
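{
"cell_type": "markdown",
"metadata": {},
"source": [
"`svi.step` returns a Monte Carlo estimate of the loss, the negative evidence lower bound (ELBO). As a hedged illustration outside Pyro, the sketch below estimates the ELBO $ \\mathbb{E}_{q}[\\log p(x, z) - \\log q(z)] $ for a toy conjugate model $ z \\sim \\mathcal{N}(0, 1) $, $ x \\mid z \\sim \\mathcal{N}(z, 1) $, whose evidence $ p(x) = \\mathcal{N}(x; 0, 2) $ is known in closed form. The bound is tight when $ q $ is the exact posterior $ \\mathcal{N}(x/2, 1/2) $:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"\n",
"def log_normal_pdf(v, mean, var):\n",
"    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (v - mean) ** 2 / var\n",
"\n",
"\n",
"def elbo(x, q_mean, q_var, num_samples=100_000):\n",
"    # reparameterized samples from q(z) = Normal(q_mean, q_var)\n",
"    z = q_mean + np.sqrt(q_var) * rng.standard_normal(num_samples)\n",
"    log_p = log_normal_pdf(z, 0.0, 1.0) + log_normal_pdf(x, z, 1.0)  # log p(x, z)\n",
"    log_q = log_normal_pdf(z, q_mean, q_var)\n",
"    return np.mean(log_p - log_q)\n",
"\n",
"\n",
"x = 1.0\n",
"log_evidence = log_normal_pdf(x, 0.0, 2.0)\n",
"print(log_evidence, elbo(x, x / 2, 0.5))  # exact posterior: the bound is tight\n",
"print(elbo(x, 0.0, 1.0))  # mismatched q: below log p(x)\n",
"```\n",
"\n",
"Maximizing the ELBO over the guide parameters therefore pushes $ q $ towards the posterior; the training loop above prints the negative ELBO divided by the number of data points."
]
},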
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
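"from pyro.infer import Predictive\n",
"\n",
"\n",
"guide.requires_grad_(False)  # freeze the trained variational parameters\n",
"# draw 1000 parameter samples from the guide; record the weights, the sampled labels\n",
"# and the model's return value (the logits)\n",
"predictive = Predictive(model, guide=guide, num_samples=1000,\n",
"                        return_sites=(\"linear.weight\", \"obs\", \"_RETURN\"))"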
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
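"# slice through the feature space along x, with the second feature fixed at y = 1\n",
"\n",
"x_ = [[x, 1] for x in np.linspace(0, 7)]  # np.linspace returns 50 points by default\n",
"x_svi = torch.FloatTensor(x_)\n",
"\n",
"samples_svi = predictive(x_svi)\n",
"\n",
"# Monte Carlo estimate of p(c | x): fraction of sampled labels equal to each class\n",
"y_new_svi = []\n",
"\n",
"for i in range(x_svi.shape[0]):\n",
"    y_ = torch.bincount(samples_svi['obs'][:, i], minlength=2).float() / samples_svi['obs'].shape[0]\n",
"    y_new_svi.append(y_.numpy())"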
]
},
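{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above estimates $ p(c \\mid x) $ by counting how often each label was sampled across the 1000 parameter draws. A lower-variance alternative would be to average the softmax probabilities of the sampled logits, which `Predictive` also returns through the `_RETURN` site. The NumPy sketch below compares both estimators at a single point; the parameter draws are hypothetical stand-ins (standard normal samples), not the notebook's posterior samples:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# hypothetical stand-ins for S posterior draws of W (S x 2 x 2) and b (S x 2)\n",
"S = 1000\n",
"W_s = rng.standard_normal((S, 2, 2))\n",
"b_s = rng.standard_normal((S, 2))\n",
"\n",
"x = np.array([3.0, 1.0])  # one query point on the slice y = 1\n",
"\n",
"# logits z_s = W_s x + b_s for every draw, shape (S, 2)\n",
"z = W_s @ x + b_s\n",
"probs = np.exp(z - z.max(axis=1, keepdims=True))\n",
"probs /= probs.sum(axis=1, keepdims=True)\n",
"\n",
"# estimator 1 (as in the notebook): sample one label per draw, then count label frequencies\n",
"labels = (rng.random(S) < probs[:, 1]).astype(int)\n",
"p_counts = np.bincount(labels, minlength=2) / S\n",
"\n",
"# estimator 2: average the class probabilities over the draws\n",
"p_avg = probs.mean(axis=0)\n",
"\n",
"print(p_counts, p_avg)  # similar values; p_avg has no label-sampling noise\n",
"```\n",
"\n",
"Both target the same posterior predictive probabilities; averaging probabilities only removes the extra noise introduced by sampling the labels."
]
},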
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"y_new_svi_np = np.array(y_new_svi)\n",
"x_svi_np = x_svi.numpy()\n",
"\n",
"plt.plot(x_svi_np[:, 0], y_new_svi_np[:, 0], label='p(c = 0 | x)')\n",
"plt.plot(x_svi_np[:, 0], y_new_svi_np[:, 1], label='p(c = 1 | x)')\n",
"plt.xlabel('x')\n",
"plt.ylabel('p')\n",
"plt.legend()\n",
"plt.title('SVI posterior predictive');"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sample: 100%|██████████| 1100/1100 [01:04, 17.06it/s, step size=1.22e-01, acc. prob=0.895]\n"
]
}
],
"source": [
"from pyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"\n",
"nuts_kernel = NUTS(model)\n",
"\n",
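"# 100 warm-up iterations (used to adapt the sampler), then 1000 posterior samples from one chain\n",
"mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=100)\n",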
"mcmc.run(data, y)\n",
"\n",
"hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"                        mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
"linear.weight[0,0]     -1.27      0.77     -1.31     -2.60     -0.08    317.19      1.00\n",
"linear.weight[0,1]      0.34      0.74      0.36     -0.95      1.48    533.85      1.00\n",
"linear.weight[1,0]      1.35      0.76      1.31      0.03      2.45    302.17      1.00\n",
"linear.weight[1,1]     -0.34      0.73     -0.35     -1.49      0.85    530.33      1.00\n",
"    linear.bias[0]      3.01      0.81      2.98      1.50      4.24    680.84      1.00\n",
"    linear.bias[1]     -2.98      0.78     -2.95     -4.16     -1.62    794.84      1.01\n",
"\n",
"Number of divergences: 0\n"
]
}
],
"source": [
"mcmc.summary()\n",
"posterior_samples = mcmc.get_samples()\n",
"posterior_predictive = Predictive(model, posterior_samples, return_sites=(\"linear.weight\", \"obs\", \"_RETURN\"))"
]
},
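{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the summary above, `n_eff` estimates the effective number of independent samples and `r_hat` is the Gelman-Rubin convergence diagnostic, where values close to 1 indicate that the chain looks stationary. The NumPy sketch below computes a simplified split-$ \\hat{R} $ for a single chain by comparing its two halves; it follows the textbook formula and may differ in details from Pyro's own implementation:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def split_rhat(chain):\n",
"    # split one chain into two halves and treat them as separate chains\n",
"    n = len(chain) // 2\n",
"    halves = np.stack([chain[:n], chain[n:2 * n]])  # shape (2, n)\n",
"    within = halves.var(axis=1, ddof=1).mean()  # W: mean within-half variance\n",
"    between = n * halves.mean(axis=1).var(ddof=1)  # B: between-half variance\n",
"    var_plus = (n - 1) / n * within + between / n  # pooled variance estimate\n",
"    return np.sqrt(var_plus / within)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"mixed = rng.standard_normal(1000)  # well-mixed draws\n",
"stuck = np.concatenate([rng.standard_normal(500), 5 + rng.standard_normal(500)])  # jumps halfway\n",
"\n",
"print(split_rhat(mixed))  # close to 1\n",
"print(split_rhat(stuck))  # clearly above 1\n",
"```\n",
"\n",
"The `r_hat` values of 1.00 and 1.01 reported above are consistent with a chain that has mixed."
]
},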
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
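"# slice through the feature space along x, with the second feature fixed at y = 1\n",
"\n",
"x_ = [[x, 1] for x in np.linspace(0, 7)]  # np.linspace returns 50 points by default\n",
"x_mcmc = torch.FloatTensor(x_)\n",
"\n",
"samples_mcmc = posterior_predictive(x_mcmc)\n",
"\n",
"# Monte Carlo estimate of p(c | x): fraction of sampled labels equal to each class\n",
"y_new_mcmc = []\n",
"\n",
"for i in range(x_mcmc.shape[0]):\n",
"    y_ = torch.bincount(samples_mcmc['obs'][:, i], minlength=2).float() / samples_mcmc['obs'].shape[0]\n",
"    y_new_mcmc.append(y_.numpy())"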
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"y_new_mcmc_np = np.array(y_new_mcmc)\n",
"x_mcmc_np = x_mcmc.numpy()\n",
"\n",
"plt.plot(x_mcmc_np[:, 0], y_new_mcmc_np[:, 0], label='p(c = 0 | x)')\n",
"plt.plot(x_mcmc_np[:, 0], y_new_mcmc_np[:, 1], label='p(c = 1 | x)')\n",
"plt.xlabel('x')\n",
"plt.ylabel('p')\n",
"plt.legend()\n",
"plt.title('MCMC posterior predictive');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparing Posterior Distributions"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# visualize the marginal posteriors of the four weights\n",
"# (the sampled weights carry an extra singleton dimension, so [:, 0, i, j] selects W[i, j])\n",
"\n",
"fig, ax = plt.subplots(1, 4, figsize=(12, 3))\n",
"sns.distplot(samples_mcmc['linear.weight'][:, 0, 0, 0], kde_kws={\"label\": \"MCMC\"}, ax=ax[0])\n",
"sns.distplot(samples_svi['linear.weight'][:, 0, 0, 0], kde_kws={\"label\": \"SVI\"}, ax=ax[0])\n",
"\n",
"sns.distplot(samples_mcmc['linear.weight'][:, 0, 0, 1], ax=ax[1])\n",
"sns.distplot(samples_svi['linear.weight'][:, 0, 0, 1], ax=ax[1])\n",
"\n",
"sns.distplot(samples_mcmc['linear.weight'][:, 0, 1, 0], ax=ax[2])\n",
"sns.distplot(samples_svi['linear.weight'][:, 0, 1, 0], ax=ax[2])\n",
"\n",
"sns.distplot(samples_mcmc['linear.weight'][:, 0, 1, 1], ax=ax[3])\n",
"sns.distplot(samples_svi['linear.weight'][:, 0, 1, 1], ax=ax[3])\n",
"\n",
"for k, name in enumerate(['W[0,0]', 'W[0,1]', 'W[1,0]', 'W[1,1]']):\n",
"    ax[k].set_title(name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
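"Note that our variational family (`AutoDiagonalNormal`) is a Gaussian with diagonal covariance, so it cannot represent correlations between the model parameters.\n",
"Such correlations are strong in this model: because the softmax is unchanged when the same constant is added to both logits, the data only constrain the differences between the two rows of $ W $ and between the two entries of $ b $. This is also why the MCMC posterior means of the two rows mirror each other (e.g. `linear.bias[0]` ≈ 3.01 and `linear.bias[1]` ≈ -2.98).\n",
"As a result, the SVI approximation of the posterior of the model parameters is over-confident (its marginals are too narrow) compared with the posterior obtained by MCMC, which we treat as the reference."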
]
},
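{
"cell_type": "markdown",
"metadata": {},
"source": [
"The over-confidence of a diagonal Gaussian can be reproduced on a two-dimensional example. For a correlated Gaussian target $ p = \\mathcal{N}(0, \\Sigma) $, the diagonal Gaussian $ q $ that minimizes $ \\mathrm{KL}(q \\Vert p) $ (the divergence that maximizing the ELBO minimizes) has variances $ 1 / (\\Sigma^{-1})_{ii} $, which are smaller than the true marginal variances $ \\Sigma_{ii} $. The NumPy sketch below (with a hypothetical correlation of 0.9) checks this with a grid search over a shared variance $ v $, which suffices here because the target is symmetric:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rho = 0.9\n",
"Sigma = np.array([[1.0, rho], [rho, 1.0]])  # hypothetical correlated target covariance\n",
"Lambda = np.linalg.inv(Sigma)  # precision matrix\n",
"\n",
"true_var = np.diag(Sigma)  # marginal variances, what MCMC would recover\n",
"mean_field_var = 1.0 / np.diag(Lambda)  # optimal variances of a diagonal Gaussian q\n",
"\n",
"\n",
"def kl_diag_to_full(v):\n",
"    # KL(N(0, v I) || N(0, Sigma)) for a 2-D Gaussian with shared variance v\n",
"    k = 2\n",
"    S = v * np.eye(k)\n",
"    return 0.5 * (np.trace(Lambda @ S) - k + np.log(np.linalg.det(Sigma) / np.linalg.det(S)))\n",
"\n",
"\n",
"grid = np.linspace(0.01, 2.0, 2000)\n",
"best_v = grid[np.argmin([kl_diag_to_full(v) for v in grid])]\n",
"\n",
"print(true_var, mean_field_var, best_v)  # mean-field variance 1 - rho**2 = 0.19 < 1\n",
"```\n",
"\n",
"With $ \\rho = 0.9 $ the diagonal approximation reports a marginal variance of 0.19 instead of 1, the same kind of under-dispersion described for the SVI weight marginals above."
]
},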
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}