Skip to content

Instantly share code, notes, and snippets.

@Sayam753
Last active May 1, 2021 07:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Sayam753/e449f44c0d4e1d9070805cf6728a1b1d to your computer and use it in GitHub Desktop.
Save Sayam753/e449f44c0d4e1d9070805cf6728a1b1d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d22371e8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.6.0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import pyro\n",
"import torch\n",
"import torch.nn as nn\n",
"import pyro.distributions as dist\n",
"from pyro.contrib.easyguide import easy_guide, EasyGuide\n",
"from pyro.nn import PyroModule, PyroSample, PyroParam\n",
"from pyro.distributions import constraints\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"# Seed the RNGs used below so every run of this notebook is reproducible\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"pyro.set_rng_seed(SEED)\n",
"pyro.__version__"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "88d4cef4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([100, 1]) torch.Size([100])\n"
]
}
],
"source": [
"# Setup data: y = 5 x + 0.1 + noise, with noise ~ N(0, 0.5^2)\n",
"N_POINTS = 100\n",
"x_np = np.linspace(0, 10, N_POINTS)\n",
"noise = 0.5 * np.random.randn(N_POINTS)\n",
"y_np = 5 * x_np + 0.1 + noise\n",
"# nn.Linear needs x as (n, in_features); y stays a flat (n,) vector\n",
"x_data = torch.tensor(x_np[:, None], dtype=torch.float32)\n",
"y_data = torch.tensor(y_np, dtype=torch.float32)\n",
"print(x_data.shape, y_data.shape)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e5ee1da1",
"metadata": {},
"outputs": [],
"source": [
"class BayesianRegression(PyroModule):\n",
"    \"\"\"Bayesian linear regression: y ~ Normal(linear(x), sigma) with priors on weight, bias and sigma.\"\"\"\n",
"\n",
"    def __init__(self, in_features, out_features):\n",
"        super().__init__()\n",
"        self.linear = PyroModule[nn.Linear](in_features, out_features)\n",
"        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))\n",
"        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))\n",
"\n",
"    def forward(self, x, full_size, y=None):\n",
"        sigma = pyro.sample(\"sigma\", dist.Uniform(0., 10.))\n",
"        mean = self.linear(x).squeeze(-1)\n",
"        # x / y are already one minibatch chosen by the DataLoader. Declaring the plate with\n",
"        # size=full_size and subsample_size=len(x) makes Pyro scale the batch log-likelihood\n",
"        # by full_size / len(x); the plate's own subsample indices are not used because the\n",
"        # batch data is passed in directly.\n",
"        with pyro.plate(\"data\", size=full_size, subsample_size=x.shape[0]):\n",
"            pyro.sample(\"obs\", dist.Normal(mean, sigma), obs=y)\n",
"        return mean"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a0700b08",
"metadata": {},
"outputs": [],
"source": [
"# Model dimensions and optimizer settings shared by every run below\n",
"in_features, out_features = 1, 1\n",
"adam_params = dict(lr=0.0005, betas=(0.90, 0.999))"
]
},
{
"cell_type": "markdown",
"id": "09f69cd9",
"metadata": {},
"source": [
"## Base setup for the custom training loop"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dd1a72d1",
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
"    \"\"\"Map-style dataset returning the (x, y) pair stored at a given row index.\"\"\"\n",
"\n",
"    def __init__(self, x, y):\n",
"        super().__init__()\n",
"        self.x, self.y = x, y\n",
"\n",
"    def __len__(self):\n",
"        # number of rows along the first dimension of x\n",
"        return self.x.shape[0]\n",
"\n",
"    def __getitem__(self, index):\n",
"        x_item, y_item = self.x[index], self.y[index]\n",
"        return x_item, y_item\n",
"\n",
"\n",
"dataset = Dataset(x_data, y_data)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f5a27f60",
"metadata": {},
"outputs": [],
"source": [
"def train(model, guide, X, Y, dataset, adam_params, n_epochs=5000, batch_size=32, seed=42):\n",
"    \"\"\"Fit `guide` to `model` with a plain torch.optim.Adam loop over minibatches of `dataset`.\n",
"\n",
"    X, Y are used only for one guide call that creates the guide's parameters; training\n",
"    iterates over `dataset`. Plots the mean loss (negative ELBO) of each epoch.\n",
"    \"\"\"\n",
"    pyro.clear_param_store()\n",
"    torch.manual_seed(seed)\n",
"    pyro.set_rng_seed(seed)\n",
"\n",
"    # Run the guide once (blocked, params-only trace) to create and collect its parameters\n",
"    with pyro.poutine.block(), pyro.poutine.trace(param_only=True) as param_capture:\n",
"        guide(x=X, full_size=X.shape[0], y=Y)\n",
"    params = [pyro.param(name).unconstrained() for name in param_capture.trace.nodes]\n",
"\n",
"    optimizer = torch.optim.Adam(params, **adam_params)\n",
"    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss\n",
"    # The model's plate must see the size of the whole dataset, not of one batch\n",
"    full_size = len(dataset)\n",
"    # Built once: a shuffling DataLoader draws a new permutation each time it is iterated\n",
"    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
"    losses = []\n",
"    for _ in tqdm(range(n_epochs)):\n",
"        batch_losses = []\n",
"        for x, y in loader:\n",
"            loss = loss_fn(model, guide, x=x, full_size=full_size, y=y)\n",
"            batch_losses.append(loss.item())\n",
"            optimizer.zero_grad()\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"        losses.append(sum(batch_losses) / len(batch_losses))\n",
"\n",
"    plt.plot(losses)\n",
"    plt.xlabel(\"epoch\")\n",
"    plt.ylabel(\"mean loss (negative ELBO)\")"
]
},
{
"cell_type": "markdown",
"id": "45470edd",
"metadata": {},
"source": [
"## Using an `EasyGuide` subclass in the custom training loop"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a086b33f",
"metadata": {},
"outputs": [],
"source": [
"class RegressionGuideAsClass(EasyGuide):\n",
"    \"\"\"EasyGuide placing one joint diagonal Normal over all latent sites matched by \\\".*\\\".\"\"\"\n",
"\n",
"    def __init__(self, model):\n",
"        super().__init__(model)\n",
"\n",
"    def guide(self, x, full_size, y=None):\n",
"        joint_group = self.group(match=\".*\")\n",
"        shape = joint_group.event_shape\n",
"        loc = pyro.param(\"loc\", torch.randn(shape))\n",
"        # small initial scale so early samples stay close to loc\n",
"        scale = pyro.param(\"scale\", 0.01 * torch.ones(shape), constraint=constraints.positive)\n",
"        posterior = dist.Normal(loc=loc, scale=scale).to_event(1)\n",
"        joint_group.sample(\"joint\", posterior)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f50d0bf3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5000/5000 [00:53<00:00, 93.05it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD7CAYAAACG50QgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAkrklEQVR4nO3deXxV9ZnH8c+ThbAFAQmLbAFkEagiREVRXHGt1dYZi50KTp2hWjut2o6CtmNtx461o50ytlpcRrSKWhVFxQVXtIIQZBeQVQhEiCA7hCzP/HFP4AoJCcm5uck93/frlVfOfc45v/P8eOmTc39n+Zm7IyIi0ZCW7ARERKT+qOiLiESIir6ISISo6IuIRIiKvohIhKjoi4hESLVF38y6mtm7ZrbEzBab2U+D+O/NbKmZLTCzyWbWOojnmtkeM5sX/DwY19YQM1toZivMbLyZWcJ6JiIih7Dq7tM3s05AJ3f/xMyygTnA5UAX4B13LzWz3wG4+61mlgu84u4DK2lrFvBTYCYwFRjv7q+F2B8RETmMjOo2cPdCoDBY3mFmS4DO7v5m3GYzgX84XDvBH49W7j4j+Pw4sT8ehy367dq189zc3OrSFBGROHPmzPnS3XMOjldb9OMFZ/EnAh8ftOoHwDNxn3uY2VxgO/ALd/8A6AwUxG1TEMQOKzc3l/z8/CNJU0Qk8szs88riNS76ZtYSeB640d23x8VvB0qBJ4NQIdDN3Teb2RDgRTMbAFQ2fl/p2JKZjQHGAHTr1q2mKYqISDVqdPeOmWUSK/hPuvsLcfHRwDeBf/Lg4oC7F7v75mB5DrAS6EPszL5LXLNdgA2VHc/dJ7h7nrvn5eQc8u1ERERqqSZ37xjwCLDE3e+Li18I3Ap8y913x8VzzCw9WO4J9AZWBdcGdpjZ0KDNUcBLofZGREQOqybDO8OAq4GFZjYviN0GjAeygGnBnZcz3f06YDjwazMrBcqA69x9S7Df9cBjQDNiF3B1546ISD2qyd07H1L5ePzUKrZ/nthQUGXr8oFDbuUUEZH6oSdyRUQiREVfRCRCUrbo79lXxgufFKCZwUREDjiih7Mak/989VOe/HgtHVs15bRj2yU7HRGRBiFlz/Q3bi8GYEdxaZIzERFpOFK26Fe8v1OjOyIiB6Ru0U92AiIiDVDKFv0DdKovIlIhZYu+pmcRETlUyhb9ChrTFxE5IGWLvmlUX0TkEKlb9Cvu3kluGiIiDUrqF31VfRGR/VK26IuIyKFSvui7BnhERPZL2aKvC7kiIoeqyXSJXc3sXTNbYmaLzeynQbytmU0zs+XB7zZx+4wzsxVmtszMLoiLDzGzhcG68cG0iQmlMX0RkQNqcqZfCvzM3Y8DhgI3mFl/YCzwtrv3Bt4OPhOsGwkMAC4E/lwxZy7wADCG2Ly5vYP1CfHqwkIAPlhelKhDiIg0OtUWfXcvdPdPguUdwBKgM3AZMDHYbCJwebB8GfC0uxe7+2pgBXCymXUCWrn7DI+95P7xuH0S5tn8gkQfQkSk0TiiMX0zywVOBD4GOrh7IcT+MADtg806A+vidisIYp2D5YPjIiJST2pc9M2sJbEJz2909+2H27SSmB8mXtmxxphZvpnlFxVpeEZEJCw1Kvpmlkms4D/p7i8E4Y3BkA3B701BvADoGrd7F2BDEO9SSfwQ7j7B3fPcPS8nJ6emfRERkWrU5O4dAx4Blrj7fXGrpgCjg+XRwEtx8ZFmlmVmPYhdsJ0VDAHtMLOhQZuj4vZJqMlzNa4vIgI1O9MfBlwNnGNm84Kfi4G7gRFmthwYEXzG3RcDzwKfAq8DN7h7WdDW9cDDxC7urgReC7MzVbnpmfn1cRgRkQav2onR3f1Dqp6I6twq9rkLuKuSeD4w8EgSFBGR8KTsE7kiInIoFX0RkQiJTNEvLStnb0lZ9RuKiKSwyBT9Kx6cQb9fvp7sNEREkioyRX/+uq3JTkFEJOkiU/RFRCSiRf+zjT
vYva802WmIiNS7yBX9iR+t4fw/TOeHT8xJdioiIvUuckX/jimLAfh41ZYkZyIiUv8iV/T302yKIhJB0S36IiIRpKIvIhIhkS36Gt0RkSiKbtFX1ReRCIps0RcRiSIVfRGRCKnJdImPmtkmM1sUF3smbhatNWY2L4jnmtmeuHUPxu0zxMwWmtkKMxsfTJmYNKZRfRGJoGpnzgIeA+4HHq8IuPt3K5bN7F5gW9z2K919UCXtPACMAWYCU4ELqafpEiuzp6SMsnInPU3FX0Sio9ozfXefDlT6+Gpwtn4lMOlwbZhZJ6CVu89wdyf2B+TyI842ZG8s/iLZKYiI1Ku6jumfAWx09+VxsR5mNtfM3jezM4JYZ6AgbpuCIJZUP3ryE3LHvqrJVUQkMmoyvHM4V/H1s/xCoJu7bzazIcCLZjaAym+L96oaNbMxxIaC6NatWx1TrN6u4lKaZqYn/DgiIslW6zN9M8sAvgM8UxFz92J33xwszwFWAn2Indl3idu9C7ChqrbdfYK757l7Xk5OTm1TFBGRg9RleOc8YKm77x+2MbMcM0sPlnsCvYFV7l4I7DCzocF1gFHAS3U4dqiq/MohIpJianLL5iRgBtDXzArM7Npg1UgOvYA7HFhgZvOB54Dr3L3iIvD1wMPACmLfAJJ2587Bxjyen+wURETqRbVj+u5+VRXxayqJPQ88X8X2+cDAI8yvXnyydmuyUxARqRd6IldEJEJU9EVEIkRFX0QkQlT0RUQiREVfRCRCVPRFRCJERV9EJEJStuj3ymmR7BRERBqclC36L94wLNkpiIg0OClb9LObZtKznc72RUTipWzRB5j8I53ti4jES+mif1TzzGSnICLSoKR00RcRka9T0RcRiRAV/cCCgq3kjn2VkrLyZKciIpIwKvqBb93/dwAmfrQmuYmIiCRQTWbOetTMNpnZorjYr8xsvZnNC34ujls3zsxWmNkyM7sgLj7EzBYG68YH0yY2OPt0pi8iKawmZ/qPARdWEv+Duw8KfqYCmFl/YtMoDgj2+XPFnLnAA8AYYvPm9q6iTRERSaBqi767Twe2VLdd4DLgaXcvdvfVxObDPdnMOgGt3H2GuzvwOHB5LXNOqG27S5KdgohIwtRlTP/HZrYgGP5pE8Q6A+vitikIYp2D5YPjDc5fpq9KdgoiIglT26L/ANALGAQUAvcG8crG6f0w8UqZ2Rgzyzez/KKiolqmGNMyq9q530VEIqNWRd/dN7p7mbuXAw8BJwerCoCucZt2ATYE8S6VxKtqf4K757l7Xk5OTm1S3O/Vn5xep/1FRFJJrYp+MEZf4dtAxZ09U4CRZpZlZj2IXbCd5e6FwA4zGxrctTMKeKkOeddY96OP/KVrK4t2JiATEZHkq8ktm5OAGUBfMysws2uBe4LbLxcAZwM3Abj7YuBZ4FPgdeAGdy8LmroeeJjYxd2VwGthdyYs5977frJTEBFJiGoHvN39qkrCjxxm+7uAuyqJ5wMDjyg7EREJlZ7IFRGJEBV9EZEIUdGvQq/bpu5fnvP5Frbt0UNbItL4Raron3dc+xpvW1buzFi5mbc+3cgVD8zgXybOTmBmIiL1IxJPLj31r6dQUuYM792OHuOmVr9D4KqHZu5fXrR+eyJSExGpV5E40z+tVzvO7JNDXV7suaekjN+9vjTErERE6l8kin68S084ptb7PvDeyhAzERGpf5Er+vddeUKyUxARSZrIFf3M9Mh1WURkP1VAEZEIUdEXEYmQSBb9D245O9kpiIgkRSSLfrMm6dVvJCKSgiJZ9Gt/t76ISOMWyaLfqllmslMQEUmKSBb9zPQ0fnRWr1rt+8sXF/HEjDXhJiQiUk9qMnPWo2a2ycwWxcV+b2ZLzWyBmU02s9ZBPNfM9pjZvODnwbh9hgSzba0ws/FWl3cihOCWC/txQtfWR7zfEzM/55cvLQ4/IRGRelCTM/3HgAsPik0DBrr78cBnwLi4dSvdfVDwc11c/AFgDLF5c3
tX0mb9c092BiIi9araou/u04EtB8XedPfS4ONMoMvh2ggmUm/l7jPc3YHHgctrlbGIiNRaGGP6P+Drk5z3MLO5Zva+mZ0RxDoDBXHbFASxRmv4Pe/y4fIvk52GiMgRqVPRN7PbgVLgySBUCHRz9xOBm4GnzKwVld8lWeXYipmNMbN8M8svKiqqS4qHVZfBnbVbdvP9Rz5m9KOzQstHRCTRal30zWw08E3gn4IhG9y92N03B8tzgJVAH2Jn9vFDQF2ADVW17e4T3D3P3fNycnJqm2K9eP+zxP1REhEJW62KvpldCNwKfMvdd8fFc8wsPVjuSeyC7Sp3LwR2mNnQ4K6dUcBLdc6+jq47s3a3bYqINFY1uWVzEjAD6GtmBWZ2LXA/kA1MO+jWzOHAAjObDzwHXOfuFReBrwceBlYQ+wYQfx0gKS7+Rie9h0dEIqXaOXLd/apKwo9Use3zwPNVrMsHBh5RdvWga9vmyU5BRKTeRPKJXBGRqFLRD8Hv39CE6SLSOKjoh+BP72rCdBFpHFT0Q/Knd1ckOwURkWqp6AOL77yAP44cVKc2fv/GsnCSERFJIBV9oEVWBhlpdf+n2FdaHkI2IiKJo6If6NW+BQDNMms/leK1E2eHlY6ISEKo6Af6dWzFnF+cx6e/vqDWbXyw/Euem1NQ/YYiIkmioh/n6JZZmBn3XHF8rdv4+d/mU1xaBsBrCwvZW1IWVnoiInWmol+JK0/qWuc25q79iuuf/IQ7X9YsWyLScKjoJ8j/vhO7hbPgqz1JzkRE5AAV/QR5Z+mm/csrNu1g256SJGYjIhKjop8Ae/d9/dbN8+6bzhUPfJSkbEREDlDRT4CL/jh9/3LF3OsrNu1MUjYiIgeo6CfAhm17k52CiEilVPSr8N7PzwqlnQ9XaPJ0EWk4VPSrkNuuBQ+NyuOsvuHN0XvOve/xLxPzQ2tPRORI1WS6xEfNbJOZLYqLtTWzaWa2PPjdJm7dODNbYWbLzOyCuPgQM1sYrBsfzJXboI3o34FWTTNDa29V0S7eWrIxtPZERI5UTc70HwMuPCg2Fnjb3XsDbwefMbP+wEhgQLDPnysmSgceAMYQmyy9dyVtNkie7AREREJUbdF39+nAloPClwETg+WJwOVx8afdvdjdVxObBP1kM+sEtHL3Ge7uwONx+zRo3z+lGwB3f+cbSc5ERKTuajum38HdCwGC3+2DeGdgXdx2BUGsc7B8cLxSZjbGzPLNLL+oqKiWKYbjlJ5Hs+buS/j24M50bdsslDZzx77Kib9+k6c+XhtKeyIiNRX2hdzKxun9MPFKufsEd89z97ycnPAupNZFVkY6H9xyDnd+a0Ao7X21u4TbJi8MpS0RkZqqbdHfGAzZEPyueOdAARD/trIuwIYg3qWSeKNzTr/Yl5r22VlJzkRE5MjVtuhPAUYHy6OBl+LiI80sy8x6ELtgOysYAtphZkODu3ZGxe3TqBzTuhmXHN+Jh0blcfOIPslOR0TkiNTkls1JwAygr5kVmNm1wN3ACDNbDowIPuPui4FngU+B14Eb3L3ihfLXAw8Tu7i7Engt5L7Ui/Q040/fG8wJXVtT7nW/t+cjPbwlIvUoo7oN3P2qKladW8X2dwF3VRLPBwYeUXYNXHl53Yv+9x7+mDV3XxJCNiIi1dMTuXUQQs0XEalXKvp1cEzr2C2cZ9fxVQ09x73Kwx+sYuN2vahNRBLLPIRx6UTKy8vz/PyG+b4ad+f9z4o4s08OPcZNDaVNDfWISBjMbI675x0cr3ZMX6pmZpzVt331G4qINBAa3mlgHnhvJQ3925eINF4q+g3M715fGtpQkYjIwVT0Q9KxVdNkpyAiUi0V/ZC8eMMw/u+ak8jr3qb6jWvg7SUb+e5fZlC4bY+Ge0QkNCr6Iel4VFPO7tee564/jZwQ3stz7cR8Pl69hVP/6x0e/fsaAJ7NX8fqL3fVuW0RiS4V/QR47+dn8c7Pzgytvb8Hr2q45bkFXPq/H4bWro
hEj4p+ArTIyqBnTsvQ2vt884Gz+53FpaG1KyLRo6KfQBOuHhJKOyuLdvHy/Eb5JmoRaWBU9BPo/AEdGXtRv1Da+rdJc0NpR0SiTUU/wcac0ZPzjusAQO7RzUNps7i0jGVf7AilLRGJFhX9BEtLMx4encdDo/JCO+u/ffIiLvif6RTtKA6lPRGJDhX9ejKifweaZBz4527dPLPWbT03JzbH/G2TF5I79tVDJljfvLOYj1dtrnX7IpK6al30zayvmc2L+9luZjea2a/MbH1c/OK4fcaZ2QozW2ZmF4TThcbpD1cOqnMb0z7dCHDIBOvfnTCT706YWef2RST11Lrou/sydx/k7oOAIcBuYHKw+g8V69x9KoCZ9QdGAgOAC4E/m1l6nbJvZE7r1W7/8tn9wn07Z+7YV/ef3a/YtDPUtkUkdYQ1vHMusNLdPz/MNpcBT7t7sbuvJjZX7skhHb9RaJqZzvK7LmL+HecDcN2ZvUJtf+KMNV/77O4Ul5ZVvrGIRFJYRX8kMCnu84/NbIGZPWpmFS+j6Qysi9umIIgdwszGmFm+meUXFRWFlGLDkJmexlHNYuP5Yy/qR58OBx7iOr7LUXVqe+rCL1i3Zff+z7dNXkTfX7zO7DVb6tSuiKSOOhd9M2sCfAv4WxB6AOgFDAIKgXsrNq1k90rfJObuE9w9z93zcnLqNhVhQ/fmTQde1zC4W91f1nbGPe/uX540K3aBVxd1RaRCGGf6FwGfuPtGAHff6O5l7l4OPMSBIZwCoGvcfl0APWYqIlKPwij6VxE3tGNmneLWfRtYFCxPAUaaWZaZ9QB6A7NCOH7KSNQrlP/7zc8S0q6IND51miPXzJoDI4AfxoXvMbNBxIZu1lSsc/fFZvYs8ClQCtzg7rrKCDz1L6fw/mdFpKdVNgImIhIea+gTdOTl5Xl+fn6y06gXe0vKuG3yQob3zmHy3PW8/1m4F7EfGpXHiP4dQm1TRBomM5vj7nmHxFX0G67csa+G3ubPRvThn0/vQcusr3/J27JrH62aZpCRroe0RVJBVUW/TsM7klgj+nfY/9RtWO6d9hkfrdzMpDFDuX3yQp6Me4XDVSd35b++c3yoxxORhkWndQ3YQ6MO+SMdihmrNvPy/A1fK/gAk2at48PlXybkmCLSMKjoN3D3XHE8g7u15q/XnsKwY49myo+HhdJuVe/n//4jH4fSvog0TBrTb2SWFG7noj9+kPDjrPztxbqbSKQRq2pMX2f6jUy/jtn1cpxd+zQXr0gqUtFvZMyM8VedCMDLPz6dOy7tn5Dj7CstT0i7IpJcGt5ppLbtKdn/4jYI//bOkSd15Zx+7Vm+aSfXnt6DppmRegu2SKOn4Z0UE1/wAWbffl6o7T89ex1jnpjD799Yxh0vLQZgQcFWRj06S98CRBoxFf0UkZOdxciTYu+zC+sOnwrP5K+jpKycf//bAqZ/VsTKIk3SItJYqeinkLuvOJ6Vv72Y47u05qFRefxx5KDQ2p67disbtu4B4N+fm8/eEr02SaQx0ph+iisrd3rdNjX0dptlpvPsD09lZ3Ep7y3bxLiLjwv9GCJSe3oNQ0Ql6l77PSVlXHr/h/s/VxT9TzdsZ29pWSgTwohI+FT0I+CDW85mT0kZvXJaMvb5BfxtTkHoxygpKyczPY2Lx8ceHFtz9yWhH0NE6k5j+hHQtW1z+nTIJj3N+P0/nsApPdqGfozBv5l2SGzj9r2sD64DiEjDoDH9CFu/dQ/D7n6Hpplp7C0J9zbMu749kNsnxyZNm3/H+WSmG1kZ6Xq1g0g9Scj79M1sDbADKANK3T3PzNoCzwC5xGbOutLdvwq2HwdcG2z/E3d/o7pjqOgn3idrv+I7f/4o4cf5zomdue+7gxJ+HBFJ7IXcs909/n28Y4G33f1uMxsbfL7VzPoDI4EBwDHAW2bWR1MmJt/gbm147J9PYnD3NhTtKGZBwVZuemZ+6Md5Ye56zu7XnktPOCb0tkWkZhIxpn8ZMD
FYnghcHhd/2t2L3X01sAI4OQHHl1o4q297WjXNpFdOS759Yhfm/ccIXvjRaaEf598mzaXgq90AvLawkCdmrKHPL15jV7Fe8CZSH+p6pu/Am2bmwF/cfQLQwd0LAdy90MzaB9t2BmbG7VsQxKQBat28CYO7NWHRnRfwxbY9nHff9NDaPv137x4SW7N5FwOOOYqK4UYzjf2LJEJdz/SHuftg4CLgBjMbfphtK/u/uNILCmY2xszyzSy/qCjcycHlyLTMyuDY9tm8ceNw7rkicVMpXjL+Q5Zv3EGPcVN5eva6hB1HJOrqdKbv7huC35vMbDKx4ZqNZtYpOMvvBGwKNi8Ausbt3gXYUEW7E4AJELuQW5ccJRx9O2bTt2M2J3ZrTdPMdLq2bc4pv32LjduLQzvGiD/Evk2Me2Ehizds46MVm5l8wzDS0+yQidxFpHZqfaZvZi3MLLtiGTgfWARMAUYHm40GXgqWpwAjzSzLzHoAvYFZtT2+JEfvDtl0bdscgL/feg5jhvcE4Budjwr1OH+duZZVX+7ihDvfZOAdb7Bx+1527C2huFTX/UXqota3bJpZT2By8DEDeMrd7zKzo4FngW7AWuAf3X1LsM/twA+AUuBGd3+tuuPols3Go6SsnOfnFPDYR2tY+sWOhBxjYOdWvPJvZySkbZFUkpD79OuDin7jtKRwOy/OW89f3l8FwHPXnco/PDgjlLZHn9qdOy4dwBfb9zJr9RbO7tf+kPkFRKJORV+S7tnZ67jl+QUJafutm4fTuXVz1m7ZTa+cFjz84WquOS1XM35JZKnoS4Pyt/x1nNrr6P23b7ZqmsH2veHcq/+LS47jP19dwvA+OYw+tTvnHteBmas2c3JuW9L0GgiJCBV9aZC27S6hRVY6GelpvLawkHGTF7J1d0lCjjXs2KOZcHUeLXQnkESAir40Oq8s2EDn1s3ISEvjzpcXk//5V6G0+83jO3Hjeb05777p/PDMnlw9tDtd2jQPpW2RhkJFX1LCnM+38Nyc9UyatTYh7f/DkC48N6eAR0bnce5xHRJyDJH6oKIvKaVijH715l0sLdxBy6YZjHt+ARu27Q3tGMe2b8lJuW343sndKXcnt10L3SUkjYaKvkRC7AGucgwY9egsikvLWbFpZ6jH+PvYc8hpmcXjM9Zw6QnHsGNvCVkZ6fsfWvti214cp9NRzUI9rsiRUNGXyFtSuJ3WzTNZWLCNMU/MScgxBnZuxaL12wGYMe4cdhWXcWz7lgk5lsjhqOiLVGJncSktszL4fPMu1m/dQ0ZaGlf+JZyHyCrzm8sG0KVNcwZ1bU2zJrFnCJpmplNe7rqdVEKloi9SS6Vl5bzwyXrWbN7FksLtpKel8daSjQk95vPXnxq8apr9fxxEjkQiZ84SSWkZ6WlceVLXStftKi6leZN0VhbtZMbKzeRkZ/Hi3A0Ul5bx7rLavxb8igcO/22jfXYWF3+jE306ZNOjXQuG9oxNdq95CKQ6OtMXqUd7S8rYvreEj1Zs5pUFhaQZvPlpYr41DOramvMHdMAdhvY8mt37SunWtjnHtI5dYE4305BSCtPwjkgjsnH7XvaVllNSVs6Lc9ezZvNupsyvdPqJUJzTrz1jhvekT4ds1n+1h94dWpKVkUZpuZORZvoG0Qip6IukmC279pGeZry9ZCOF2/by9Oy1rNuyJ2HHG9G/A0O6t+F7p3SjVVM9r9DQqeiLRMjO4lI+37yLPh2yeXVBIX07ZnPRHz9I2PF+fn4fjm2fTbe2zel/TCv2lpTRNDOdfaXlbN5VfETPLLg7byz+gnOP60Bmel1ndI0uFX0RAWJF9bVFX3B673bc9PQ83l66ibP65vBeHS4819RrPz2DppnptGqawVtLNjK4Wxu6Hd2cNDMu/d8PGXtRP/aVljPmiTncPKIPPzm3d8JzSlWhF30z6wo8DnQEyoEJ7v5HM/sV8K9AxX9Bt7n71GCfccC1QBnwE3d/o7rjqOiLJMevpizmsY
/WJO34Vw/tzm8uHwjAg++vZNbqLTx6zUlJy6exSUTR7wR0cvdPgrly5wCXA1cCO939vw/avj8widjk6ccAbwF93P2wk56q6Iskz559ZRR8tZveHbLZW1LG+X+YTt+O2UyLu+OoV04LVhbtqte8mmSk8e/n9+Vfh/dk2+4S7p22jMdnfM5HY89h+aadnNitNS2aZLD0i+1s31PKqb2OpqSsPFJ3LCV8eMfMXgLuB4ZRedEfB+Du/xV8fgP4lbsf9oZkFX2RhmvTjr20z24KxF6Cd+1js3l49Elc9dDMJGf2dX/63mBueOoTAB78/hCu++scbr2wH92Pbs7pvdvRqmkmu4pLyUg3sjIqfxiuuLSMzLS0Sv9obNi6h6OaZTaouRoSWvTNLBeYDgwEbgauAbYD+cDP3P0rM7sfmOnufw32eQR4zd2fO1zbKvoijc/bSzays7iUjLQ0crKz2LJrH13bNuOS8R/u36ZDqyw2bi9OYpaV++GZPfnL+6u4/3snsm7LHk7ochSDu7eh3y9f/9p2U39yBu2ym/Dphu1c83+zAVhz9yUAlJc7Zgcelnt36SaG5Lap17ueElb0zawl8D5wl7u/YGYdgC8BB35DbAjoB2b2J2DGQUV/qrs/X0mbY4AxAN26dRvy+eef1ylHEWkYikvLWFiwjSHd2+wviO5OcWk5melpjH97OSs27eSib3Tkx0/NpV/HbJZ+sSPJWdfcid1aM3ft1v2ffzCsBy8v2EDRjmLO6N2OWy/sx50vL+aa03owvE87Nu/cx659paSnGf06tgo1l4QUfTPLBF4B3nD3+ypZnwu84u4DNbwjIrVVVu4s3rCNAcccRXowvLK3pIyPV2+hcOseHnx/JWs27+ayQcfwxba9fLx6S5IzPnJ//qfBbN1dwtOz19KnQzY3j+iz/+np2kjEhVwDJgJb3P3GuHgndy8Mlm8CTnH3kWY2AHiKAxdy3wZ660KuiIRtV3Hp/vH1ncWl3PP6Uq4Y3IVj27dk7ZbdCX1mIUwVw0W1kYgXrg0DrgYWmtm8IHYbcJWZDSI2vLMG+CGAuy82s2eBT4FS4IbqCr6ISG3EX1BtmZXBry8buP/zcZ1asebuS9hbUkaT9NiF2b0lZaSn2f5XTrzwSQH/89Zy3rxpOCVl5Qz69TTKyp0pPx7GpFnreGb2Wsrr4RGnRLxyWw9niYhUY19pOWXl/rXXXBd8tZsubZp/bTt3Z/aar2jdPJNeOS1JM3CH305dQkZ6Guu37iGnZRY52Vnc88ZSqiu/c385gjYtmtQqZz2RKyLSAC1av41fTVnM+QM6MOrUXCZ+tIYWWRl8f2j3OrWroi8iEiFVFX29zUhEJEJU9EVEIkRFX0QkQlT0RUQiREVfRCRCVPRFRCJERV9EJEJU9EVEIqTBP5xlZkVAbd+t3I7Ya56jRH2Ohqj1OWr9hbr3ubu75xwcbPBFvy7MLL+yJ9JSmfocDVHrc9T6C4nrs4Z3REQiREVfRCRCUr3oT0h2AkmgPkdD1Poctf5Cgvqc0mP6IiLydal+pi8iInFSsuib2YVmtszMVpjZ2GTnUxdm9qiZbTKzRXGxtmY2zcyWB7/bxK0bF/R7mZldEBcfYmYLg3XjgzmOGyQz62pm75rZEjNbbGY/DeIp228za2pms8xsftDnO4N4yvYZwMzSzWyumb0SfE71/q4Jcp1nZvlBrH777O4p9QOkAyuBnkATYD7QP9l51aE/w4HBwKK42D3A2GB5LPC7YLl/0N8soEfw75AerJsFnAoY8BpwUbL7dpg+dwIGB8vZwGdB31K230F+LYPlTOBjYGgq9znI9WbgKeCViPy3vQZod1CsXvucimf6JwMr3H2Vu+8DngYuS3JOtebu04EtB4UvAyYGyxOBy+PiT7t7sbuvBlYAJ5tZJ6CVu8/w2H8xj8ft0+C4e6G7fxIs7wCWAJ1J4X57zM7gY2bw46Rwn82sC3AJ8HBcOGX7exj12udULPqdgXVxnw
uCWCrp4O6FECuQQPsgXlXfOwfLB8cbPDPLBU4kduab0v0OhjrmAZuAae6e6n3+H+AWoDwulsr9hdgf8jfNbI6ZjQli9drnjFom3pBVNrYVlVuUqup7o/w3MbOWwPPAje6+/TDDlinRb3cvAwaZWWtgspkNPMzmjbrPZvZNYJO7zzGzs2qySyWxRtPfOMPcfYOZtQemmdnSw2ybkD6n4pl+AdA17nMXYEOSckmUjcFXPILfm4J4VX0vCJYPjjdYZpZJrOA/6e4vBOGU7zeAu28F3gMuJHX7PAz4lpmtITYEe46Z/ZXU7S8A7r4h+L0JmExsOLpe+5yKRX820NvMephZE2AkMCXJOYVtCjA6WB4NvBQXH2lmWWbWA+gNzAq+Mu4ws6HBVf5Rcfs0OEGOjwBL3P2+uFUp228zywnO8DGzZsB5wFJStM/uPs7du7h7LrH/R99x9++Tov0FMLMWZpZdsQycDyyivvuc7KvZifgBLiZ2x8dK4PZk51PHvkwCCoESYn/hrwWOBt4Glge/28Ztf3vQ72XEXdEH8oL/wFYC9xM8mNcQf4DTiX1dXQDMC34uTuV+A8cDc4M+LwL+I4inbJ/j8j2LA3fvpGx/id1ROD/4WVxRm+q7z3oiV0QkQlJxeEdERKqgoi8iEiEq+iIiEaKiLyISISr6IiIRoqIvIhIhKvoiIhGioi8iEiH/D1NIdrSLFL4EAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Fresh model instance, wrapped by the class-based EasyGuide\n",
"base_regression_model = BayesianRegression(in_features, out_features)\n",
"regression_guide_as_class = RegressionGuideAsClass(base_regression_model)\n",
"train(\n",
"    model=base_regression_model,\n",
"    guide=regression_guide_as_class,\n",
"    X=x_data,\n",
"    Y=y_data,\n",
"    dataset=dataset,\n",
"    adam_params=adam_params,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5a6a737e",
"metadata": {},
"source": [
"## Using the `easy_guide` decorator in the custom training loop"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2891c910",
"metadata": {},
"outputs": [],
"source": [
"@easy_guide(base_regression_model)\n",
"def regression_guide_with_decorator(self, x, full_size, y=None):\n",
"    \"\"\"Functional equivalent of RegressionGuideAsClass.guide: one joint diagonal Normal.\"\"\"\n",
"    joint_group = self.group(match=\".*\")\n",
"    shape = joint_group.event_shape\n",
"    loc = pyro.param(\"loc\", torch.randn(shape))\n",
"    scale = pyro.param(\"scale\", torch.ones(shape) * 0.01, constraint=constraints.positive)\n",
"    joint_group.sample(\"joint\", dist.Normal(loc=loc, scale=scale).to_event(1))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "df1fa914",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5000/5000 [00:54<00:00, 91.90it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD7CAYAAACG50QgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAkrklEQVR4nO3deXxV9ZnH8c+ThbAFAQmLbAFkEagiREVRXHGt1dYZi50KTp2hWjut2o6CtmNtx461o50ytlpcRrSKWhVFxQVXtIIQZBeQVQhEiCA7hCzP/HFP4AoJCcm5uck93/frlVfOfc45v/P8eOmTc39n+Zm7IyIi0ZCW7ARERKT+qOiLiESIir6ISISo6IuIRIiKvohIhKjoi4hESLVF38y6mtm7ZrbEzBab2U+D+O/NbKmZLTCzyWbWOojnmtkeM5sX/DwY19YQM1toZivMbLyZWcJ6JiIih7Dq7tM3s05AJ3f/xMyygTnA5UAX4B13LzWz3wG4+61mlgu84u4DK2lrFvBTYCYwFRjv7q+F2B8RETmMjOo2cPdCoDBY3mFmS4DO7v5m3GYzgX84XDvBH49W7j4j+Pw4sT8ehy367dq189zc3OrSFBGROHPmzPnS3XMOjldb9OMFZ/EnAh8ftOoHwDNxn3uY2VxgO/ALd/8A6AwUxG1TEMQOKzc3l/z8/CNJU0Qk8szs88riNS76ZtYSeB640d23x8VvB0qBJ4NQIdDN3Teb2RDgRTMbAFQ2fl/p2JKZjQHGAHTr1q2mKYqISDVqdPeOmWUSK/hPuvsLcfHRwDeBf/Lg4oC7F7v75mB5DrAS6EPszL5LXLNdgA2VHc/dJ7h7nrvn5eQc8u1ERERqqSZ37xjwCLDE3e+Li18I3Ap8y913x8VzzCw9WO4J9AZWBdcGdpjZ0KDNUcBLofZGREQOqybDO8OAq4GFZjYviN0GjAeygGnBnZcz3f06YDjwazMrBcqA69x9S7Df9cBjQDNiF3B1546ISD2qyd07H1L5ePzUKrZ/nthQUGXr8oFDbuUUEZH6oSdyRUQiREVfRCRCUrbo79lXxgufFKCZwUREDjiih7Mak/989VOe/HgtHVs15bRj2yU7HRGRBiFlz/Q3bi8GYEdxaZIzERFpOFK26Fe8v1OjOyIiB6Ru0U92AiIiDVDKFv0DdKovIlIhZYu+pmcRETlUyhb9ChrTFxE5IGWLvmlUX0TkEKlb9Cvu3kluGiIiDUrqF31VfRGR/VK26IuIyKFSvui7BnhERPZL2aKvC7kiIoeqyXSJXc3sXTNbYmaLzeynQbytmU0zs+XB7zZx+4wzsxVmtszMLoiLDzGzhcG68cG0iQmlMX0RkQNqcqZfCvzM3Y8DhgI3mFl/YCzwtrv3Bt4OPhOsGwkMAC4E/lwxZy7wADCG2Ly5vYP1CfHqwkIAPlhelKhDiIg0OtUWfXcvdPdPguUdwBKgM3AZMDHYbCJwebB8GfC0uxe7+2pgBXCymXUCWrn7DI+95P7xuH0S5tn8gkQfQkSk0TiiMX0zywVOBD4GOrh7IcT+MADtg806A+vidisIYp2D5YPjIiJST2pc9M2sJbEJz2909+2H27SSmB8mXtmxxphZvpnlFxVpeEZEJCw1Kvpmlkms4D/p7i8E4Y3BkA3B701BvADoGrd7F2BDEO9SSfwQ7j7B3fPcPS8nJ6emfRERkWrU5O4dAx4Blrj7fXGrpgCjg+XRwEtx8ZFmlmVmPYhdsJ0VDAHtMLOhQZuj4vZJqMlzNa4vIgI1O9MfBlwNnGNm84Kfi4G7gRFmthwYEXzG3RcDzwKfAq8DN7h7WdDW9cDDxC7urgReC7MzVbnpmfn1cRgRkQav2onR3f1Dqp6I6twq9rkLuKuSeD4w8EgSFBGR8KTsE7kiInIoFX0RkQiJTNEvLStnb0lZ9RuKiKSwyBT9Kx6cQb9fvp7sNEREkioyRX/+uq3JTkFEJOkiU/RFRCSiRf+zjT
vYva802WmIiNS7yBX9iR+t4fw/TOeHT8xJdioiIvUuckX/jimLAfh41ZYkZyIiUv8iV/T302yKIhJB0S36IiIRpKIvIhIhkS36Gt0RkSiKbtFX1ReRCIps0RcRiSIVfRGRCKnJdImPmtkmM1sUF3smbhatNWY2L4jnmtmeuHUPxu0zxMwWmtkKMxsfTJmYNKZRfRGJoGpnzgIeA+4HHq8IuPt3K5bN7F5gW9z2K919UCXtPACMAWYCU4ELqafpEiuzp6SMsnInPU3FX0Sio9ozfXefDlT6+Gpwtn4lMOlwbZhZJ6CVu89wdyf2B+TyI842ZG8s/iLZKYiI1Ku6jumfAWx09+VxsR5mNtfM3jezM4JYZ6AgbpuCIJZUP3ryE3LHvqrJVUQkMmoyvHM4V/H1s/xCoJu7bzazIcCLZjaAym+L96oaNbMxxIaC6NatWx1TrN6u4lKaZqYn/DgiIslW6zN9M8sAvgM8UxFz92J33xwszwFWAn2Indl3idu9C7ChqrbdfYK757l7Xk5OTm1TFBGRg9RleOc8YKm77x+2MbMcM0sPlnsCvYFV7l4I7DCzocF1gFHAS3U4dqiq/MohIpJianLL5iRgBtDXzArM7Npg1UgOvYA7HFhgZvOB54Dr3L3iIvD1wMPACmLfAJJ2587Bxjyen+wURETqRbVj+u5+VRXxayqJPQ88X8X2+cDAI8yvXnyydmuyUxARqRd6IldEJEJU9EVEIkRFX0QkQlT0RUQiREVfRCRCVPRFRCJERV9EJEJStuj3ymmR7BRERBqclC36L94wLNkpiIg0OClb9LObZtKznc72RUTipWzRB5j8I53ti4jES+mif1TzzGSnICLSoKR00RcRka9T0RcRiRAV/cCCgq3kjn2VkrLyZKciIpIwKvqBb93/dwAmfrQmuYmIiCRQTWbOetTMNpnZorjYr8xsvZnNC34ujls3zsxWmNkyM7sgLj7EzBYG68YH0yY2OPt0pi8iKawmZ/qPARdWEv+Duw8KfqYCmFl/YtMoDgj2+XPFnLnAA8AYYvPm9q6iTRERSaBqi767Twe2VLdd4DLgaXcvdvfVxObDPdnMOgGt3H2GuzvwOHB5LXNOqG27S5KdgohIwtRlTP/HZrYgGP5pE8Q6A+vitikIYp2D5YPjDc5fpq9KdgoiIglT26L/ANALGAQUAvcG8crG6f0w8UqZ2Rgzyzez/KKiolqmGNMyq9q530VEIqNWRd/dN7p7mbuXAw8BJwerCoCucZt2ATYE8S6VxKtqf4K757l7Xk5OTm1S3O/Vn5xep/1FRFJJrYp+MEZf4dtAxZ09U4CRZpZlZj2IXbCd5e6FwA4zGxrctTMKeKkOeddY96OP/KVrK4t2JiATEZHkq8ktm5OAGUBfMysws2uBe4LbLxcAZwM3Abj7YuBZ4FPgdeAGdy8LmroeeJjYxd2VwGthdyYs5977frJTEBFJiGoHvN39qkrCjxxm+7uAuyqJ5wMDjyg7EREJlZ7IFRGJEBV9EZEIUdGvQq/bpu5fnvP5Frbt0UNbItL4Raron3dc+xpvW1buzFi5mbc+3cgVD8zgXybOTmBmIiL1IxJPLj31r6dQUuYM792OHuOmVr9D4KqHZu5fXrR+eyJSExGpV5E40z+tVzvO7JNDXV7suaekjN+9vjTErERE6l8kin68S084ptb7PvDeyhAzERGpf5Er+vddeUKyUxARSZrIFf3M9Mh1WURkP1VAEZEIUdEXEYmQSBb9D245O9kpiIgkRSSLfrMm6dVvJCKSgiJZ9Gt/t76ISOMWyaLfqllmslMQEUmKSBb9zPQ0fnRWr1rt+8sXF/HEjDXhJiQiUk9qMnPWo2a2ycwWxcV+b2ZLzWyBmU02s9ZBPNfM9pjZvODnwbh9hgSzba0ws/FWl3cihOCWC/txQtfWR7zfEzM/55cvLQ4/IRGRelCTM/3HgAsPik0DBrr78cBnwLi4dSvdfVDwc11c/AFgDLF5c3
tX0mb9c092BiIi9araou/u04EtB8XedPfS4ONMoMvh2ggmUm/l7jPc3YHHgctrlbGIiNRaGGP6P+Drk5z3MLO5Zva+mZ0RxDoDBXHbFASxRmv4Pe/y4fIvk52GiMgRqVPRN7PbgVLgySBUCHRz9xOBm4GnzKwVld8lWeXYipmNMbN8M8svKiqqS4qHVZfBnbVbdvP9Rz5m9KOzQstHRCTRal30zWw08E3gn4IhG9y92N03B8tzgJVAH2Jn9vFDQF2ADVW17e4T3D3P3fNycnJqm2K9eP+zxP1REhEJW62KvpldCNwKfMvdd8fFc8wsPVjuSeyC7Sp3LwR2mNnQ4K6dUcBLdc6+jq47s3a3bYqINFY1uWVzEjAD6GtmBWZ2LXA/kA1MO+jWzOHAAjObDzwHXOfuFReBrwceBlYQ+wYQfx0gKS7+Rie9h0dEIqXaOXLd/apKwo9Use3zwPNVrMsHBh5RdvWga9vmyU5BRKTeRPKJXBGRqFLRD8Hv39CE6SLSOKjoh+BP72rCdBFpHFT0Q/Knd1ckOwURkWqp6AOL77yAP44cVKc2fv/GsnCSERFJIBV9oEVWBhlpdf+n2FdaHkI2IiKJo6If6NW+BQDNMms/leK1E2eHlY6ISEKo6Af6dWzFnF+cx6e/vqDWbXyw/Euem1NQ/YYiIkmioh/n6JZZmBn3XHF8rdv4+d/mU1xaBsBrCwvZW1IWVnoiInWmol+JK0/qWuc25q79iuuf/IQ7X9YsWyLScKjoJ8j/vhO7hbPgqz1JzkRE5AAV/QR5Z+mm/csrNu1g256SJGYjIhKjop8Ae/d9/dbN8+6bzhUPfJSkbEREDlDRT4CL/jh9/3LF3OsrNu1MUjYiIgeo6CfAhm17k52CiEilVPSr8N7PzwqlnQ9XaPJ0EWk4VPSrkNuuBQ+NyuOsvuHN0XvOve/xLxPzQ2tPRORI1WS6xEfNbJOZLYqLtTWzaWa2PPjdJm7dODNbYWbLzOyCuPgQM1sYrBsfzJXboI3o34FWTTNDa29V0S7eWrIxtPZERI5UTc70HwMuPCg2Fnjb3XsDbwefMbP+wEhgQLDPnysmSgceAMYQmyy9dyVtNkie7AREREJUbdF39+nAloPClwETg+WJwOVx8afdvdjdVxObBP1kM+sEtHL3Ge7uwONx+zRo3z+lGwB3f+cbSc5ERKTuajum38HdCwGC3+2DeGdgXdx2BUGsc7B8cLxSZjbGzPLNLL+oqKiWKYbjlJ5Hs+buS/j24M50bdsslDZzx77Kib9+k6c+XhtKeyIiNRX2hdzKxun9MPFKufsEd89z97ycnPAupNZFVkY6H9xyDnd+a0Ao7X21u4TbJi8MpS0RkZqqbdHfGAzZEPyueOdAARD/trIuwIYg3qWSeKNzTr/Yl5r22VlJzkRE5MjVtuhPAUYHy6OBl+LiI80sy8x6ELtgOysYAtphZkODu3ZGxe3TqBzTuhmXHN+Jh0blcfOIPslOR0TkiNTkls1JwAygr5kVmNm1wN3ACDNbDowIPuPui4FngU+B14Eb3L3ihfLXAw8Tu7i7Engt5L7Ui/Q040/fG8wJXVtT7nW/t+cjPbwlIvUoo7oN3P2qKladW8X2dwF3VRLPBwYeUXYNXHl53Yv+9x7+mDV3XxJCNiIi1dMTuXUQQs0XEalXKvp1cEzr2C2cZ9fxVQ09x73Kwx+sYuN2vahNRBLLPIRx6UTKy8vz/PyG+b4ad+f9z4o4s08OPcZNDaVNDfWISBjMbI675x0cr3ZMX6pmZpzVt331G4qINBAa3mlgHnhvJQ3925eINF4q+g3M715fGtpQkYjIwVT0Q9KxVdNkpyAiUi0V/ZC8eMMw/u+ak8jr3qb6jWvg7SUb+e5fZlC4bY+Ge0QkNCr6Iel4VFPO7tee564/jZwQ3stz7cR8Pl69hVP/6x0e/fsaAJ7NX8fqL3fVuW0RiS4V/QR47+dn8c7Pzgytvb8Hr2q45bkFXPq/H4bWro
hEj4p+ArTIyqBnTsvQ2vt884Gz+53FpaG1KyLRo6KfQBOuHhJKOyuLdvHy/Eb5JmoRaWBU9BPo/AEdGXtRv1Da+rdJc0NpR0SiTUU/wcac0ZPzjusAQO7RzUNps7i0jGVf7AilLRGJFhX9BEtLMx4encdDo/JCO+u/ffIiLvif6RTtKA6lPRGJDhX9ejKifweaZBz4527dPLPWbT03JzbH/G2TF5I79tVDJljfvLOYj1dtrnX7IpK6al30zayvmc2L+9luZjea2a/MbH1c/OK4fcaZ2QozW2ZmF4TThcbpD1cOqnMb0z7dCHDIBOvfnTCT706YWef2RST11Lrou/sydx/k7oOAIcBuYHKw+g8V69x9KoCZ9QdGAgOAC4E/m1l6nbJvZE7r1W7/8tn9wn07Z+7YV/ef3a/YtDPUtkUkdYQ1vHMusNLdPz/MNpcBT7t7sbuvJjZX7skhHb9RaJqZzvK7LmL+HecDcN2ZvUJtf+KMNV/77O4Ul5ZVvrGIRFJYRX8kMCnu84/NbIGZPWpmFS+j6Qysi9umIIgdwszGmFm+meUXFRWFlGLDkJmexlHNYuP5Yy/qR58OBx7iOr7LUXVqe+rCL1i3Zff+z7dNXkTfX7zO7DVb6tSuiKSOOhd9M2sCfAv4WxB6AOgFDAIKgXsrNq1k90rfJObuE9w9z93zcnLqNhVhQ/fmTQde1zC4W91f1nbGPe/uX540K3aBVxd1RaRCGGf6FwGfuPtGAHff6O5l7l4OPMSBIZwCoGvcfl0APWYqIlKPwij6VxE3tGNmneLWfRtYFCxPAUaaWZaZ9QB6A7NCOH7KSNQrlP/7zc8S0q6IND51miPXzJoDI4AfxoXvMbNBxIZu1lSsc/fFZvYs8ClQCtzg7rrKCDz1L6fw/mdFpKdVNgImIhIea+gTdOTl5Xl+fn6y06gXe0vKuG3yQob3zmHy3PW8/1m4F7EfGpXHiP4dQm1TRBomM5vj7nmHxFX0G67csa+G3ubPRvThn0/vQcusr3/J27JrH62aZpCRroe0RVJBVUW/TsM7klgj+nfY/9RtWO6d9hkfrdzMpDFDuX3yQp6Me4XDVSd35b++c3yoxxORhkWndQ3YQ6MO+SMdihmrNvPy/A1fK/gAk2at48PlXybkmCLSMKjoN3D3XHE8g7u15q/XnsKwY49myo+HhdJuVe/n//4jH4fSvog0TBrTb2SWFG7noj9+kPDjrPztxbqbSKQRq2pMX2f6jUy/jtn1cpxd+zQXr0gqUtFvZMyM8VedCMDLPz6dOy7tn5Dj7CstT0i7IpJcGt5ppLbtKdn/4jYI//bOkSd15Zx+7Vm+aSfXnt6DppmRegu2SKOn4Z0UE1/wAWbffl6o7T89ex1jnpjD799Yxh0vLQZgQcFWRj06S98CRBoxFf0UkZOdxciTYu+zC+sOnwrP5K+jpKycf//bAqZ/VsTKIk3SItJYqeinkLuvOJ6Vv72Y47u05qFRefxx5KDQ2p67disbtu4B4N+fm8/eEr02SaQx0ph+iisrd3rdNjX0dptlpvPsD09lZ3Ep7y3bxLiLjwv9GCJSe3oNQ0Ql6l77PSVlXHr/h/s/VxT9TzdsZ29pWSgTwohI+FT0I+CDW85mT0kZvXJaMvb5BfxtTkHoxygpKyczPY2Lx8ceHFtz9yWhH0NE6k5j+hHQtW1z+nTIJj3N+P0/nsApPdqGfozBv5l2SGzj9r2sD64DiEjDoDH9CFu/dQ/D7n6Hpplp7C0J9zbMu749kNsnxyZNm3/H+WSmG1kZ6Xq1g0g9Scj79M1sDbADKANK3T3PzNoCzwC5xGbOutLdvwq2HwdcG2z/E3d/o7pjqOgn3idrv+I7f/4o4cf5zomdue+7gxJ+HBFJ7IXcs909/n28Y4G33f1uMxsbfL7VzPoDI4EBwDHAW2bWR1MmJt/gbm147J9PYnD3NhTtKGZBwVZuemZ+6Md5Ye56zu7XnktPOCb0tkWkZhIxpn8ZMD
FYnghcHhd/2t2L3X01sAI4OQHHl1o4q297WjXNpFdOS759Yhfm/ccIXvjRaaEf598mzaXgq90AvLawkCdmrKHPL15jV7Fe8CZSH+p6pu/Am2bmwF/cfQLQwd0LAdy90MzaB9t2BmbG7VsQxKQBat28CYO7NWHRnRfwxbY9nHff9NDaPv137x4SW7N5FwOOOYqK4UYzjf2LJEJdz/SHuftg4CLgBjMbfphtK/u/uNILCmY2xszyzSy/qCjcycHlyLTMyuDY9tm8ceNw7rkicVMpXjL+Q5Zv3EGPcVN5eva6hB1HJOrqdKbv7huC35vMbDKx4ZqNZtYpOMvvBGwKNi8Ausbt3gXYUEW7E4AJELuQW5ccJRx9O2bTt2M2J3ZrTdPMdLq2bc4pv32LjduLQzvGiD/Evk2Me2Ehizds46MVm5l8wzDS0+yQidxFpHZqfaZvZi3MLLtiGTgfWARMAUYHm40GXgqWpwAjzSzLzHoAvYFZtT2+JEfvDtl0bdscgL/feg5jhvcE4Budjwr1OH+duZZVX+7ihDvfZOAdb7Bx+1527C2huFTX/UXqota3bJpZT2By8DEDeMrd7zKzo4FngW7AWuAf3X1LsM/twA+AUuBGd3+tuuPols3Go6SsnOfnFPDYR2tY+sWOhBxjYOdWvPJvZySkbZFUkpD79OuDin7jtKRwOy/OW89f3l8FwHPXnco/PDgjlLZHn9qdOy4dwBfb9zJr9RbO7tf+kPkFRKJORV+S7tnZ67jl+QUJafutm4fTuXVz1m7ZTa+cFjz84WquOS1XM35JZKnoS4Pyt/x1nNrr6P23b7ZqmsH2veHcq/+LS47jP19dwvA+OYw+tTvnHteBmas2c3JuW9L0GgiJCBV9aZC27S6hRVY6GelpvLawkHGTF7J1d0lCjjXs2KOZcHUeLXQnkESAir40Oq8s2EDn1s3ISEvjzpcXk//5V6G0+83jO3Hjeb05777p/PDMnlw9tDtd2jQPpW2RhkJFX1LCnM+38Nyc9UyatTYh7f/DkC48N6eAR0bnce5xHRJyDJH6oKIvKaVijH715l0sLdxBy6YZjHt+ARu27Q3tGMe2b8lJuW343sndKXcnt10L3SUkjYaKvkRC7AGucgwY9egsikvLWbFpZ6jH+PvYc8hpmcXjM9Zw6QnHsGNvCVkZ6fsfWvti214cp9NRzUI9rsiRUNGXyFtSuJ3WzTNZWLCNMU/MScgxBnZuxaL12wGYMe4cdhWXcWz7lgk5lsjhqOiLVGJncSktszL4fPMu1m/dQ0ZaGlf+JZyHyCrzm8sG0KVNcwZ1bU2zJrFnCJpmplNe7rqdVEKloi9SS6Vl5bzwyXrWbN7FksLtpKel8daSjQk95vPXnxq8apr9fxxEjkQiZ84SSWkZ6WlceVLXStftKi6leZN0VhbtZMbKzeRkZ/Hi3A0Ul5bx7rLavxb8igcO/22jfXYWF3+jE306ZNOjXQuG9oxNdq95CKQ6OtMXqUd7S8rYvreEj1Zs5pUFhaQZvPlpYr41DOramvMHdMAdhvY8mt37SunWtjnHtI5dYE4305BSCtPwjkgjsnH7XvaVllNSVs6Lc9ezZvNupsyvdPqJUJzTrz1jhvekT4ds1n+1h94dWpKVkUZpuZORZvoG0Qip6IukmC279pGeZry9ZCOF2/by9Oy1rNuyJ2HHG9G/A0O6t+F7p3SjVVM9r9DQqeiLRMjO4lI+37yLPh2yeXVBIX07ZnPRHz9I2PF+fn4fjm2fTbe2zel/TCv2lpTRNDOdfaXlbN5VfETPLLg7byz+gnOP60Bmel1ndI0uFX0RAWJF9bVFX3B673bc9PQ83l66ibP65vBeHS4819RrPz2DppnptGqawVtLNjK4Wxu6Hd2cNDMu/d8PGXtRP/aVljPmiTncPKIPPzm3d8JzSlWhF30z6wo8DnQEyoEJ7v5HM/sV8K9AxX9Bt7n71GCfccC1QBnwE3d/o7rjqOiLJMevpizmsY
/WJO34Vw/tzm8uHwjAg++vZNbqLTx6zUlJy6exSUTR7wR0cvdPgrly5wCXA1cCO939vw/avj8widjk6ccAbwF93P2wk56q6Iskz559ZRR8tZveHbLZW1LG+X+YTt+O2UyLu+OoV04LVhbtqte8mmSk8e/n9+Vfh/dk2+4S7p22jMdnfM5HY89h+aadnNitNS2aZLD0i+1s31PKqb2OpqSsPFJ3LCV8eMfMXgLuB4ZRedEfB+Du/xV8fgP4lbsf9oZkFX2RhmvTjr20z24KxF6Cd+1js3l49Elc9dDMJGf2dX/63mBueOoTAB78/hCu++scbr2wH92Pbs7pvdvRqmkmu4pLyUg3sjIqfxiuuLSMzLS0Sv9obNi6h6OaZTaouRoSWvTNLBeYDgwEbgauAbYD+cDP3P0rM7sfmOnufw32eQR4zd2fO1zbKvoijc/bSzays7iUjLQ0crKz2LJrH13bNuOS8R/u36ZDqyw2bi9OYpaV++GZPfnL+6u4/3snsm7LHk7ochSDu7eh3y9f/9p2U39yBu2ym/Dphu1c83+zAVhz9yUAlJc7Zgcelnt36SaG5Lap17ueElb0zawl8D5wl7u/YGYdgC8BB35DbAjoB2b2J2DGQUV/qrs/X0mbY4AxAN26dRvy+eef1ylHEWkYikvLWFiwjSHd2+wviO5OcWk5melpjH97OSs27eSib3Tkx0/NpV/HbJZ+sSPJWdfcid1aM3ft1v2ffzCsBy8v2EDRjmLO6N2OWy/sx50vL+aa03owvE87Nu/cx659paSnGf06tgo1l4QUfTPLBF4B3nD3+ypZnwu84u4DNbwjIrVVVu4s3rCNAcccRXowvLK3pIyPV2+hcOseHnx/JWs27+ayQcfwxba9fLx6S5IzPnJ//qfBbN1dwtOz19KnQzY3j+iz/+np2kjEhVwDJgJb3P3GuHgndy8Mlm8CTnH3kWY2AHiKAxdy3wZ660KuiIRtV3Hp/vH1ncWl3PP6Uq4Y3IVj27dk7ZbdCX1mIUwVw0W1kYgXrg0DrgYWmtm8IHYbcJWZDSI2vLMG+CGAuy82s2eBT4FS4IbqCr6ISG3EX1BtmZXBry8buP/zcZ1asebuS9hbUkaT9NiF2b0lZaSn2f5XTrzwSQH/89Zy3rxpOCVl5Qz69TTKyp0pPx7GpFnreGb2Wsrr4RGnRLxyWw9niYhUY19pOWXl/rXXXBd8tZsubZp/bTt3Z/aar2jdPJNeOS1JM3CH305dQkZ6Guu37iGnZRY52Vnc88ZSqiu/c385gjYtmtQqZz2RKyLSAC1av41fTVnM+QM6MOrUXCZ+tIYWWRl8f2j3OrWroi8iEiFVFX29zUhEJEJU9EVEIkRFX0QkQlT0RUQiREVfRCRCVPRFRCJERV9EJEJU9EVEIqTBP5xlZkVAbd+t3I7Ya56jRH2Ohqj1OWr9hbr3ubu75xwcbPBFvy7MLL+yJ9JSmfocDVHrc9T6C4nrs4Z3REQiREVfRCRCUr3oT0h2AkmgPkdD1Poctf5Cgvqc0mP6IiLydal+pi8iInFSsuib2YVmtszMVpjZ2GTnUxdm9qiZbTKzRXGxtmY2zcyWB7/bxK0bF/R7mZldEBcfYmYLg3XjgzmOGyQz62pm75rZEjNbbGY/DeIp228za2pms8xsftDnO4N4yvYZwMzSzWyumb0SfE71/q4Jcp1nZvlBrH777O4p9QOkAyuBnkATYD7QP9l51aE/w4HBwKK42D3A2GB5LPC7YLl/0N8soEfw75AerJsFnAoY8BpwUbL7dpg+dwIGB8vZwGdB31K230F+LYPlTOBjYGgq9znI9WbgKeCViPy3vQZod1CsXvucimf6JwMr3H2Vu+8DngYuS3JOtebu04EtB4UvAyYGyxOBy+PiT7t7sbuvBlYAJ5tZJ6CVu8/w2H8xj8ft0+C4e6G7fxIs7wCWAJ1J4X57zM7gY2bw46Rwn82sC3AJ8HBcOGX7exj12udULPqdgXVxnw
uCWCrp4O6FECuQQPsgXlXfOwfLB8cbPDPLBU4kduab0v0OhjrmAZuAae6e6n3+H+AWoDwulsr9hdgf8jfNbI6ZjQli9drnjFom3pBVNrYVlVuUqup7o/w3MbOWwPPAje6+/TDDlinRb3cvAwaZWWtgspkNPMzmjbrPZvZNYJO7zzGzs2qySyWxRtPfOMPcfYOZtQemmdnSw2ybkD6n4pl+AdA17nMXYEOSckmUjcFXPILfm4J4VX0vCJYPjjdYZpZJrOA/6e4vBOGU7zeAu28F3gMuJHX7PAz4lpmtITYEe46Z/ZXU7S8A7r4h+L0JmExsOLpe+5yKRX820NvMephZE2AkMCXJOYVtCjA6WB4NvBQXH2lmWWbWA+gNzAq+Mu4ws6HBVf5Rcfs0OEGOjwBL3P2+uFUp228zywnO8DGzZsB5wFJStM/uPs7du7h7LrH/R99x9++Tov0FMLMWZpZdsQycDyyivvuc7KvZifgBLiZ2x8dK4PZk51PHvkwCCoESYn/hrwWOBt4Glge/28Ztf3vQ72XEXdEH8oL/wFYC9xM8mNcQf4DTiX1dXQDMC34uTuV+A8cDc4M+LwL+I4inbJ/j8j2LA3fvpGx/id1ROD/4WVxRm+q7z3oiV0QkQlJxeEdERKqgoi8iEiEq+iIiEaKiLyISISr6IiIRoqIvIhIhKvoiIhGioi8iEiH/D1NIdrSLFL4EAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"base_regression_model = BayesianRegression(in_features, out_features)\n",
"train(base_regression_model, regression_guide_with_decorator, x_data, y_data, dataset, adam_params)"
]
},
{
"cell_type": "markdown",
"id": "ba1b13b6",
"metadata": {},
"source": [
"## Using AutoNormal"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "ae4845aa",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5000/5000 [01:12<00:00, 69.17it/s]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAee0lEQVR4nO3deXwV9b3/8dcneyDsJDEENIAoggsiIopaF0TUtqj9taXtrcu15XftcmttbcHWe23VW2ur9Vqrv1qsyr3WpVUrFbXiiiiCQUF2CBAhJpKNQEgg2/n+/siABwhkOXNyTs68n4/HeWTO98x85/Ol9p3Jd+bMmHMOEREJhqRYFyAiIt1HoS8iEiAKfRGRAFHoi4gEiEJfRCRAUmJdQHsGDx7sCgoKYl2GiEiPsmzZskrnXPbB7XEf+gUFBRQWFsa6DBGRHsXMPm6rXdM7IiIBotAXEQkQhb6ISIAo9EVEAkShLyISIAp9EZEAUeiLiARIQof+tup63lxfHusyRETiRtx/OSsS59z1BgDFd14W40pEROJDQh/pi4jIgRT6IiIBotAXEQmQhA39bdX1sS5BRCTuJGzo7zuJKyIin0nY0BcRkUMp9EVEAkShLyISIAp9EZEAUeiLiARIIEL/3aLKWJcgIhIXAhH6S7ZUx7oEEZG4EIjQd7EuQEQkTgQi9LdU1sW6BBGRuBCI0C+t2RPrEkRE4kIgQr+mvjHWJYiIxIV2Q9/MhpnZG2a21sxWm9kPvPaBZrbAzDZ6PweEbTPbzIrMbL2ZXRzWfpqZrfQ+u8/MLDrDOtCmCk3viIhAx470m4EfOedOACYB3zWzMcAs4DXn3CjgNe893mczgLHANOABM0v2+noQmAmM8l7TfByLiIi0o93Qd86VOec+8JZrgbVAPjAdeMxb7THgcm95OvCkc67BObcFKAImmlke0Nc5t9g554C5YduIiEg36NScvpkVAKcCS4Bc51wZtP5iAHK81fKBbWGblXht+d7ywe1t7WemmRWaWWFFRUVnSjws5xx/LdxGY3PIl/5ERHqiDoe+mWUBzwA3OOd2HWnVNtrcEdoPbXTuIefcBOfchOzs7I6WeET/+KiMm/72Efe/UeRLfyIiPVGHQt/MUmkN/Medc896zdu9KRu8n+VeewkwLGzzoUCp1z60jfZusXNPEwBVuxu6a5ciInGnI1fvGPAwsNY5d0/YR/OAq73lq4Hnw9pnmFm6mQ2n9YTtUm8KqNbMJnl9XhW2jYiIdIOUDqwzGfgmsNLMlnttNwN3Ak+b2XXAVuDLAM651Wb2NLCG1it/vuuca/G2ux54FMgEXvJe3eK2F9Z0165EROJWu6HvnFtE2/PxABceZps7gDvaaC8ETuxMgX7RCVwRkYB8I1dERFop9EVEAiRwoV9UvjvWJYiIxEzgQn/JlmpdtikigRW40Ae4+N6FFMyar8coikjgBDL0K3e33mr59XXl7awpIpJYAhn6+3TPjZ1FROJHoEPf6eG5IhIwgQ79OYu2xLoEEZFuFejQFxEJGoW+iEiAKPRFRAJEoS8iEiAKfRGRAFHoi4gEiEJfRCRAFPoiIgGi0BcRCRCFvohIgAQ+9GvqG2NdgohIt2n3weiJ7v7Xi0hONiYWDOTCE3JjXY6ISFQFPvT33XTtj29tpvjOy2JcjYhIdCXs9M6kEQNjXYKISNxJ2ND/t8+NjHUJIiJxJ2FD/3PHZce6BBGRuJOwoW96FqKIyCESNvRFRORQCn0RkQBR6IuIBIhCX0QkQBT6IiIBotAXEQkQhb6ISIAo9EVEAkShH2bFtppYlyAiElUK/TDvbKqMdQkiIlGV0KE/918nxroEEZG4ktChn5LUufvvOBelQkRE4kS7oW9mfzazcjNbFdZ2q5l9YmbLvdelYZ/NNrMiM1tvZheHtZ9mZiu9z+4z3RFNRKTbdeRI/1FgWhvtv3POjfNeLwKY2RhgBjDW2+YBM0v21n8QmAmM8l
5t9emrzh64/+af66NSh4hIvGg39J1zC4HqDvY3HXjSOdfgnNsCFAETzSwP6OucW+ycc8Bc4PIu1txhef0yALhhyqhOb/vB1h0s19U8IpJgIpnT/56ZfeRN/wzw2vKBbWHrlHht+d7ywe1tMrOZZlZoZoUVFRVdLnBEdhZv3XQe/35B50P/ygfe5fI/vNPlfYuIxKOuhv6DwEhgHFAG3O21tzVP747Q3ibn3EPOuQnOuQnZ2ZE9AeuYQb1J6uQJXRGRRNWl0HfObXfOtTjnQsCfgH3XRpYAw8JWHQqUeu1D22iPSytLdsa6BBGRqOhS6Htz9PtcAey7smceMMPM0s1sOK0nbJc658qAWjOb5F21cxXwfAR1d9qIwb07tN7ybTV84f5FUa5GRCQ2OnLJ5hPAYuB4Mysxs+uAu7zLLz8Czgd+COCcWw08DawBXga+65xr8bq6HphD68ndTcBLfg/mSF698XMdWq+0Zk+UKxERiZ2U9lZwzn2tjeaHj7D+HcAdbbQXAid2qjofaV5fRCTBv5HbFSF9LVdEEligQv+n00a3u84Pn1oe/UJERGIkUKHfEU0tOtIXkcQVqNAf0j8j1iWIiMRUoEL/i6cMiXUJIiIxFajQ78qNPct27uFXL60lFNK0j4j0fIEK/a744VPL+eNbm/lw245YlyIiEjGFfjv2ndjVlZwikggU+u3Y29TS/koiIj2EQr8dq0t3xboEERHfBDb0O3oDNhGRRBK40H/m+jMB+MM3xse4EhGR7he40D/tmIEU33kZJ+T17dR2/+f/LaZsp+7AKSI9W+BCPxKzn125f3nO25v5y5KtMaxGRKTzFPqd8Ob6Ct7bXAXA7fPXcvNzK9vZQkQkvij0O+mZZSXtryQiEqcU+iIiAaLQ7yR9MVdEejKFfict2VIV6xJERLpMod9J26p12aaI9FwKfRGRAAl06Of3z4x1CSIi3SrQof/sd87ih1OOi3UZIiLdJtChn9s3g3OPGxzrMkREuk2gQx+gK09BPPbmF/0vRESkGwQ+9F0XHonVrOflikgPFfjQ752eAsCV4/MZ0i+j09v/TbdlEJEeJPChf0JeX+ZcNYHbLz+xS9v/+K8rfK5IRCR6UmJdQDyYMiY31iWIiHSLwB/ph/u380ZGtP2qT3ayrbrep2pERPyn0A9z1ZkFEW3/+d8v4py73vCnGBGRKFDo+yCkq3lEpIdQ6Pvg1/9cR0VtQ6zLEBFpl07k+uCPb21mZcnOWJchItIuHen7REf6ItITKPR9srF8d6xLEBFpl0JfRCRAFPoHyevCrRjaUzBrPtc8stT3fkVEOqvd0DezP5tZuZmtCmsbaGYLzGyj93NA2GezzazIzNab2cVh7aeZ2Urvs/vMzPwfTuRevfFz+5fX3TaNL40fyvij+3eqj531TYe0vbm+ItLSREQi1pEj/UeBaQe1zQJec86NAl7z3mNmY4AZwFhvmwfMLNnb5kFgJjDKex3cZ1zYdwM2gIzUZO7+yink9Onc0f+KkhqfqxIR8Ue7oe+cWwhUH9Q8HXjMW34MuDys/UnnXINzbgtQBEw0szygr3NusWu9l/HcsG3izsGPUXR07stX+qqWiMSrrs7p5zrnygC8nzleez6wLWy9Eq8t31s+uL1NZjbTzArNrLCiovunReZ9bzLzvjd5//vO3nLfOcd/v7qRglnzfa5MRCQyfp/IbWue3h2hvU3OuYeccxOccxOys7N9K66jBmWlc/LQ/p/V08ntHfC7VzcA8Og7W/a3V+7WtfwiEltdDf3t3pQN3s9yr70EGBa23lCg1Gsf2kZ7jzBsQK9Ore+cY99p6lv/sWZ/+4TbX6W5JeRnaSIindLV0J8HXO0tXw08H9Y+w8zSzWw4rSdsl3pTQLVmNsm7aueqsG3i3k8vOb5T679bVHXYKSE9alFEYqkjl2w+ASwGjjezEjO7DrgTuMjMNgIXee9xzq0GngbWAC8D33XOtXhdXQ/MofXk7ibgJZ
/HEjXpKcl8/uQ8fvvlUzq0/pxFW9pfSUQkBtq94Zpz7muH+ejCw6x/B3BHG+2FQNeeSRgH7v/6eKp8mJPvwnPYRUR8o2/kdrPOXv4pIuInhX4nZKQmH/D+2skFne5jzH/8k5/8TQ9TF5HYUOh3Qu/0FF6+4Zz972++9IQu9fN0YQmX3fe2X2WJiHSYQr+TRh/Vd/9yanLX//lWl+7yoxwRkU5R6MdQwaz5jLz5xViXISIBoscldsHi2RfQK82ff7qWkGPeilK+eMoQX/oTETkSHel3QV6/TPplpvrW3w1Pfsi7RZUUzJrPJzV7fOtXRORgCv0Izb5kdMR9pCYn8fU5SwAoLD74hqYiIv5R6MeB8BPCG7frWbsiEj0K/QhlZUQ+t7+7oXn/8v1vFEXcn4jI4Sj0IzTj9KNjXYKISIcp9COUnGRMHZMLwH9dcZJv/RbMmk/BrPls3F7rW58iIgp9H9zz1XH85dtn8PUzjuYv3z7D175//vdV7a8kItJBCn0fZKWncNbIwQCcmN8v4v7ufmV9xH2IiLRFoe+zzLCbst31pZO71MfvX//sZO6qT3YC8D+Li9laVR9ZcSISeAp9n4VffpnbLyPi/uoaW9jT2MItz6/mK39cHHF/IhJsCv0o6N+r9du6zqcnpuy7B3/NnkZf+hOR4NK9d6Jgyc0X4hws2ljpS39X/3kpACEH1XWNJBn075XmS98iEiw60o+C9JRkMlKTOWZQrwPab7+8a0+LfL94BwCNzSHG37aAcb9cEHGNIhJMCv0oGpXbh8KfT9n//kvjh/q+j90NzSzZXOV7vyKSmBT6UTY4K53BWa1TMZlpye2s3XHFlXUA/PsTH/LVh95j+669vvUtIolLod8NFv30AtbdNu2AtkeuPT2iPs/77Zu8ub6c19eVA3DGf71GTb1O9IrIkSn0u0FGavIhD1U///gcbr40stsyX/PI+we8/+ULayLqT0QSn0I/hq47e4Sv/TU2h3ztT0QSjy7ZjKHkJOMn047nrpf9ue3CCx+VsXH7QtZvr+WC0TlMHzeE6ePyfelbRBKD+fUFomiZMGGCKywsjHUZvtm1twkXgn69Pnvc4uQ7X4/aYxKL77wsKv2KSHwzs2XOuQkHt+tIv5v1zTj02bpPzpzEoqJKRgzuzVOF23j2g0983ee6T3exa08zE4cP9LVfEel5NKcfB4YN7MXXJh7NGSMGMf7oAb72XbW7gWn3vq379ogIoNCPO2b+9nfa7a/626GI9GgK/TgzbEDrrRv8Dn+AtzZUsKNO1/KLBJlO5MahZR/v4NRh/VlRUsMVD7zra9+Deqfx3Hcmc/RB9wUSkcRyuBO5OtKPQ6cdM4CkJGPcsP5cd/ZwRh/VhyvH55PfPzPivqvqGjn3N28A8HFVHbsbmiPuU0R6Dh3p9yA3PrWcZz/098qek/L78Y/vn82E218lp086L/7gHF/7F5HY0JF+Arj+vJEcM6gX782+0Lc+V36yk4JZ86nc3cCasl2+9Ssi8UnX6fcgo3L78NZN5xMKRe+vs2UfV5ORmszYIZE/4F1E4o+O9HsgM7ji1HwG9fb/6VlfenAxl923iMWbqqjTfL9IwlHo90Bmxu++Oo5lt1zEvV8dB8AD3xjPJScexY+nHufLPr72p/f40dMr9r9fsGY79Y36JSDS02l6p4e7/NR8pozJJSs9hUtPygPgt69s8KXvDeW1QOttHL49t5DLxw3h3hmn+tK3iMRGREf6ZlZsZivNbLmZFXptA81sgZlt9H4OCFt/tpkVmdl6M7s40uKlVVb6gb+7Lz3pKF/63VxRR8Gs+azYVgPAx9X1VNc1snhTVVTPK4hI9ER0yaaZFQMTnHOVYW13AdXOuTvNbBYwwDn3UzMbAzwBTASGAK8CxznnWo60D12y2XnOOUIO6hubWbK5mm/NLeTc47JZuKHC1/1cMDqHP18T2RPARCQ6uvOSzenAY97yY8DlYe1POucanHNbgCJafwGIz8
yM5CSjT0YqU8bksuI/p/Lw1Yf8bx+x19eV8+Cbm3DO8cb6clp09C8S9yINfQe8YmbLzGym15brnCsD8H7meO35wLawbUu8NomyfpmppCYnseoXFzOwdxoDfbzq59cvr2P47Be59pH3eXjRZt/6FZHoiPRE7mTnXKmZ5QALzGzdEdZt6xZibR4aer9AZgIcffTREZYo+2Slp/DBLRcBUF67lxufWsH3LziW2r3NfGtu5FNoTyzdxt6mEH9auJnHrpvo+22iRSRyEYW+c67U+1luZs/ROl2z3czynHNlZpYHlHurlwDDwjYfCpQept+HgIegdU4/khqlbTl9Mvjfb53ha59bKuu4Z0HrlUNfevBd/vD18WSmJnP+6Jx2thSR7tLl6R0z621mffYtA1OBVcA84GpvtauB573lecAMM0s3s+HAKGBpV/cv/vro1qnc85VT2PKrSzkhr2/E/TkH33n8A6599H3mvL2Zu19Zz5bKOgBue2ENl933dsT7EJHOi2ROPxdYZGYraA3v+c65l4E7gYvMbCNwkfce59xq4GlgDfAy8N32rtyR7tM3I5Urxw/FzJj3vckHfPanqyI7CXz7/LX8/vUizv/tm2yu2M3Di7awulT3+RGJhS5P7zjnNgOntNFeBbR5RzDn3B3AHV3dp3SPZDNy+6Zz7qhshvTP5KIxuTx67elc88j7Efd9wd1v7V/eUllHY3MIh2NvU4hxw/pH3L+IHJlurSwd0twS4vb5a7nu7OEMG9iLglnzfd/HzZeOZua5I9mwvZaj+mW0+RB5EemYw12nr9CXLpk5t5AxQ/ry5QnDGNIvg+GzX/Sl3+w+6VTUNjAiuze3fmEs54waDLR+90BEOk6hL1EVCjk+rq7HOceI7CwAzrnrdbZV7/Gl/zuuOJHJIwdTMLi3L/2JJDqFvnS7huYWKmob6JeZyuJNVcz8n2W+9r/wpvMPedbve5urODYni8FZ6b7uS6SnOVzo6y6bEjXpKckMHdAaylPHtt4EbmLBQO6dMY75H5Vxx4trI+p/37N+97ny1Pz9j5N8/UefIyUpSQ+AFzmIjvQlZorKdzNvRSmLNlbwk2mjmfHQe77vY3BWOmeMGMgNF45iQO80WkKOkh17GJWbpRPFktA0vSM9xq3zVvPe5iq+MekYbvn7Km75/Bhue2FNVPZV+PMppCYn0dwS4r3N1Uwdm8vW6npGeuclRHoqhb4khH+sKOWeBRuYNGIgTyzd1v4GESgY1Iu7v3IKef0y6ZWWjGH066W/DqRnUOhLwtl3xdDwwb2pa2jmlTWf8vh7WxmR3Zua+iZeWbM9qvt/5NrTqW9o4fzR2fRK0+kxiS8KfQmcxuYQe5pa2FpVT2ZaEvNWlJGeksSCNdtxsP+JYNHwzPVnsaWyjj4ZKZw7KpuUZCPZjKQkfd9AuodCX6QNqz7ZyfFH9WFHXSMlNXtY/2ktpx0zgKm/Wxj1fQ/OSudfzy7gmrMK9JeC+E6hLxKhppYQzS2OxuYQVXUNPF1Ywpvry5lyQi73v1Hk676OzcmiqSXEL6efyKQRA0lLTtK3kqVTFPoi3WRvUwvpKUns3NNEdV0ja8p28di7xbxfvMOX/r94yhCq6xrJ6ZPONZNb/0o4NkdXG8mBFPoicSIUcvvn9str99ISctzzygb+uqzEt32cXjCAX04/kWNzskhNjsajsCXeKfRFepDdDc28tb6CY3OyWF26kxufXhFxn/n9M7n9ihMZNiCT1OQkMtOSyemT4UO1Eo8U+iIJwDlHeW0D355biAGVuxv5pKbrN7UbOiCTkh17+OX0sZx97GB6paUwKCtNfx0kAIW+SILa29RCQ1OIjLQkPtxa49vtLAoG9aK4qh6AJTdfyIBeaVTXNfL35Z/wf88doRPLcU6hLxIQTS0hmlpCJCcZZTV7eXzJxxx/VF9+/NfIp4j2uWB0Dt86ZziVuxv5wsl5mBnFlXUMzErTPY3ihEJfRGhuCbF8Ww2nHTOA0p
17mXzn61Hb16QRA/nvGaeS27f1vMEO70qmyccOjto+5TMKfRE5hHOOj6vqKRjcm5aQo7RmD7/553rmrSiN2j5/deVJDOqdRv6ATP5lzhL+8wtjmTo2l15pKbSEHK+u3c7UMbkHTB81NodIS9F5hs5Q6ItIp7xfXE1pzR4uPSmPt9ZX8PbGCt7ZVEVR+e6o7C8tOYnGlhDQeoJ59iUn8OHWHRx/VB9u+ttHAKz+xcX0Tj/w28vNLSHKdu4lKz2FAb3TolJbT6TQF5GIOeeoqG0gp28GZTv3cOavWqeHfnDhKPpkpHD7/MgejNMZ4Q/N2WfZz6cwqJNPTZu7uJglm6v5wzfG+1lezCn0RcR3u/Y24Rz0yzz05O3ephZeWlXGiys/ZUGU73h6sDF5fVlTtguAH089ji+cMoSqukbe2VjJsTlZXHJSHjvqGumbmcrIm18EoPjOy7q1xmhT6ItIzOzLGTPDOUfl7kZ6pyeztylEXUMz59z1BslJRn7/TLZW10e9nmvOKuDRd4vb/CyvXwav/PBcNlfUcVJ+P3buafJt2mhrVT0Pvb2JX3zxRJKjfMdVhb6I9AjOtT7Ssqh8N/16pZKdlc45d7U+D3lEdm82V9TFrLYzhg9kygm5vLmhnDlXnc5ji4s57/hsRh/Vt0PbX/nAO3ywtYbnvnMWpx49IKq1KvRFJGGs2FbDyJwsstJbr/hZuqWa6x9fRk19U0zq+drEo0lPSWLogEwuOSmP1Z/sZE9TCxW1DeT3z2TKmFwWbqjghieXU9vQzDPXn8lpxwwEoKG5hQ2f7uakof18rUmhLyKBFAo5Rv7sRZyDp2ZOYsP2Wm55fnWsywLgue+cxV+XlfCXJVtZeNP5bNheS0ZqMmePivy7DAp9EZEwzjlCDpIMtlbXYxgtzrF0SxV5/TK59R+rYzaVNOeqCVwwOieiJ60p9EVEuuiFj0rJzkonOcmormvkP55fzae79kZ9v5FcUXS40Ncz2kRE2vH5k4cc8H7q2KMOWee9zVV8unMvL64s47KT83hjXTm7G1p4de12kgxCXTi+bgk536/yUeiLiPhg0ohBAFx+aj4A08flH7JOYXE1O+qbeKeokgtPyOGbDy89Yp+NzSEy05J9rVOhLyLSTSYUtF6xc9GYXABumz6WYwb15tzjsgFYuKGCuYs/5vsXHMspw/pHpQaFvohIjHzzzIID3p97XPb+XwDRotvWiYgEiEJfRCRAFPoiIgGi0BcRCRCFvohIgCj0RUQCRKEvIhIgCn0RkQCJ+xuumVkF8HEXNx8MVPpYTk+gMQdD0MYctPFC5GM+xjl3yDe94j70I2FmhW3dZS6RaczBELQxB228EL0xa3pHRCRAFPoiIgGS6KH/UKwLiAGNORiCNuagjReiNOaEntMXEZEDJfqRvoiIhFHoi4gESEKGvplNM7P1ZlZkZrNiXU8kzOzPZlZuZqvC2gaa2QIz2+j9HBD22Wxv3OvN7OKw9tPMbKX32X1m5u+DN31kZsPM7A0zW2tmq83sB157wo7bzDLMbKmZrfDG/AuvPWHHDGBmyWb2oZm94L1P9PEWe7UuN7NCr617x+ycS6gXkAxsAkYAacAKYEys64pgPOcC44FVYW13AbO85VnAr73lMd5404Hh3r9DsvfZUuBMwICXgEtiPbYjjDkPGO8t9wE2eGNL2HF79WV5y6nAEmBSIo/Zq/VG4C/ACwH5b7sYGHxQW7eOORGP9CcCRc65zc65RuBJYHqMa+oy59xCoPqg5unAY97yY8DlYe1POucanHNbgCJgopnlAX2dc4td638xc8O2iTvOuTLn3Afeci2wFsgngcftWu323qZ6L0cCj9nMhgKXAXPCmhN2vEfQrWNOxNDPB7aFvS/x2hJJrnOuDFoDEsjx2g839nxv+eD2uGdmBcCptB75JvS4vamO5UA5sMA5l+hjvhf4CRAKa0vk8ULrL/JXzGyZmc302rp1zIn4YPS25raCcl3q4cbeI/
9NzCwLeAa4wTm36wjTlgkxbudcCzDOzPoDz5nZiUdYvUeP2cw+D5Q755aZ2Xkd2aSNth4z3jCTnXOlZpYDLDCzdUdYNypjTsQj/RJgWNj7oUBpjGqJlu3en3h4P8u99sONvcRbPrg9bplZKq2B/7hz7lmvOeHHDeCcqwHeBKaRuGOeDHzRzIppnYK9wMz+l8QdLwDOuVLvZznwHK3T0d065kQM/feBUWY23MzSgBnAvBjX5Ld5wNXe8tXA82HtM8ws3cyGA6OApd6fjLVmNsk7y39V2DZxx6vxYWCtc+6esI8Sdtxmlu0d4WNmmcAUYB0JOmbn3Gzn3FDnXAGt/x993Tn3LyToeAHMrLeZ9dm3DEwFVtHdY4712exovIBLab3iYxPws1jXE+FYngDKgCZaf8NfBwwCXgM2ej8Hhq3/M2/c6wk7ow9M8P4D2wTcj/dt7Hh8AWfT+ufqR8By73VpIo8bOBn40BvzKuA/vPaEHXNYvefx2dU7CTteWq8oXOG9Vu/Lpu4es27DICISIIk4vSMiIoeh0BcRCRCFvohIgCj0RUQCRKEvIhIgCn0RkQBR6IuIBMj/ByEg1m2hDVo7AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"base_regression_model = BayesianRegression(in_features, out_features)\n",
"auto_guide = pyro.infer.autoguide.AutoNormal(base_regression_model)\n",
"train(base_regression_model, auto_guide, x_data, y_data, dataset, adam_params)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment