Skip to content

Instantly share code, notes, and snippets.

@Sayam753
Last active May 1, 2021 07:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Sayam753/e449f44c0d4e1d9070805cf6728a1b1d to your computer and use it in GitHub Desktop.
Save Sayam753/e449f44c0d4e1d9070805cf6728a1b1d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d22371e8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.6.0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import pyro\n",
"import torch\n",
"import torch.nn as nn\n",
"import pyro.distributions as dist\n",
"from pyro.contrib.easyguide import easy_guide, EasyGuide\n",
"from pyro.nn import PyroModule, PyroSample, PyroParam\n",
"from pyro.distributions import constraints\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"# pyro.set_rng_seed seeds torch, Python's `random` module and numpy in one call,\n",
"# which also makes the np.random.randn noise in the data cell reproducible.\n",
"pyro.set_rng_seed(42)\n",
"torch.manual_seed(42)\n",
"pyro.__version__"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "88d4cef4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([100, 1]) torch.Size([100])\n"
]
}
],
"source": [
"# Synthetic data: y = 5 x + 0.1 + Gaussian noise with standard deviation 0.5.\n",
"x_data = np.linspace(0, 10, 100)\n",
"ep = 0.5 * np.random.randn(x_data.shape[0])\n",
"y_data = 5 * x_data + 0.1 + ep\n",
"# nn.Linear expects features of shape (N, in_features); targets stay shape (N,).\n",
"# Convert straight to float32 instead of building a float64 tensor and casting it.\n",
"x_data = torch.tensor(x_data[:, None], dtype=torch.float32)\n",
"y_data = torch.tensor(y_data, dtype=torch.float32)\n",
"print(x_data.shape, y_data.shape)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e5ee1da1",
"metadata": {},
"outputs": [],
"source": [
"class BayesianRegression(PyroModule):\n",
"    \"\"\"Bayesian linear regression: y ~ Normal(linear(x), sigma).\n",
"\n",
"    Priors: weight ~ Normal(0, 1), bias ~ Normal(0, 10), sigma ~ Uniform(0, 10).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_features, out_features):\n",
"        super().__init__()\n",
"        self.linear = PyroModule[nn.Linear](in_features, out_features)\n",
"        # Replace the layer's weight and bias with priors so they become latent sample sites.\n",
"        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))\n",
"        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))\n",
"\n",
"    def forward(self, x, full_size, y=None):\n",
"        \"\"\"Score (or, with y=None, sample) the observations of one mini-batch.\n",
"\n",
"        Args:\n",
"            x: batch of inputs for ``self.linear``; ``x.shape[0]`` is the batch size.\n",
"            full_size: number of points in the whole data set (the plate size).\n",
"            y: observed targets for this batch, or None.\n",
"\n",
"        Returns:\n",
"            ``self.linear(x)`` with its last dimension squeezed (the regression mean).\n",
"        \"\"\"\n",
"        sigma = pyro.sample(\"sigma\", dist.Uniform(0., 10.))\n",
"        mean = self.linear(x).squeeze(-1)\n",
"        # The batch is selected outside the model (by the DataLoader), so only the plate's\n",
"        # scaling matters. With subsample_size=x.shape[0], Pyro multiplies the batch\n",
"        # log-likelihood by full_size / x.shape[0]. This is an unbiased estimate of the\n",
"        # full-data term, and it also holds for a smaller final batch. The plate still\n",
"        # draws random subsample indices, but they are never used here. Passing\n",
"        # subsample=torch.arange(x.shape[0]) instead would avoid drawing them.\n",
"        with pyro.plate(\"data\", size=full_size, subsample_size=x.shape[0]):\n",
"            pyro.sample(\"obs\", dist.Normal(mean, sigma), obs=y)\n",
"        return mean"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a0700b08",
"metadata": {},
"outputs": [],
"source": [
"# Model dimensions: one input feature, one regression output.\n",
"in_features, out_features = 1, 1\n",
"# Keyword arguments handed to torch.optim.Adam inside train().\n",
"adam_params = dict(lr=0.0005, betas=(0.9, 0.999))"
]
},
{
"cell_type": "markdown",
"id": "09f69cd9",
"metadata": {},
"source": [
"## Base setup for the custom training loop"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dd1a72d1",
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
"    \"\"\"Map-style dataset yielding ``(x[i], y[i])`` pairs from two indexable tensors.\"\"\"\n",
"\n",
"    def __init__(self, x, y):\n",
"        super().__init__()\n",
"        self.x, self.y = x, y\n",
"\n",
"    def __len__(self):\n",
"        # Samples are indexed along the first dimension of x.\n",
"        return self.x.shape[0]\n",
"\n",
"    def __getitem__(self, index):\n",
"        features, target = self.x[index], self.y[index]\n",
"        return features, target\n",
"\n",
"\n",
"dataset = Dataset(x_data, y_data)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f5a27f60",
"metadata": {},
"outputs": [],
"source": [
"def train(model, guide, X, Y, dataset, adam_params, n_epochs=5000, batch_size=32):\n",
"    \"\"\"Fit ``guide`` to ``model`` with mini-batch SVI in a plain torch Adam loop, then plot the loss.\n",
"\n",
"    Args:\n",
"        model, guide: callables taking ``(x, full_size, y)``.\n",
"        X, Y: the full data set. They are used only to run the guide once, so that its\n",
"            parameters get created, and to set ``full_size = X.shape[0]``.\n",
"        dataset: torch ``Dataset`` yielding ``(x, y)`` pairs for the mini-batches.\n",
"        adam_params: keyword arguments for ``torch.optim.Adam``.\n",
"        n_epochs: number of passes over ``dataset``.\n",
"        batch_size: mini-batch size for the ``DataLoader`` (default 32, as before).\n",
"    \"\"\"\n",
"    pyro.clear_param_store()\n",
"    torch.manual_seed(42)\n",
"    pyro.set_rng_seed(42)\n",
"    full_size = X.shape[0]\n",
"\n",
"    # Run the guide once, hidden from outer handlers and recording only param sites,\n",
"    # so that every guide parameter exists in the param store before optimisation.\n",
"    with pyro.poutine.block(), pyro.poutine.trace(param_only=True) as param_capture:\n",
"        guide(x=X, full_size=full_size, y=Y)\n",
"    # Optimise the unconstrained leaf tensors so that constrained params (e.g. the\n",
"    # positive scales) stay valid after every Adam step.\n",
"    params = [pyro.param(name).unconstrained() for name in param_capture.trace]\n",
"\n",
"    optimizer = torch.optim.Adam(params, **adam_params)\n",
"    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss\n",
"    # Built once; with shuffle=True it reshuffles the data each time it is iterated.\n",
"    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
"    losses = []\n",
"    for _ in tqdm(range(n_epochs)):\n",
"        epoch_loss = []\n",
"        for x, y in loader:\n",
"            loss = loss_fn(model, guide, x, full_size, y)\n",
"            epoch_loss.append(loss.item())\n",
"            optimizer.zero_grad()\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"        losses.append(sum(epoch_loss) / len(epoch_loss))\n",
"\n",
"    plt.plot(losses)"
]
},
{
"cell_type": "markdown",
"id": "45470edd",
"metadata": {},
"source": [
"## Using an EasyGuide subclass in the custom loop"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a086b33f",
"metadata": {},
"outputs": [],
"source": [
"class RegressionGuideAsClass(EasyGuide):\n",
"    \"\"\"EasyGuide with one joint diagonal-Normal posterior over every latent site of the model.\n",
"\n",
"    Construct it as ``RegressionGuideAsClass(model)``. The inherited EasyGuide\n",
"    constructor binds the model, so no ``__init__`` override is needed.\n",
"    \"\"\"\n",
"\n",
"    def guide(self, x, full_size, y=None):\n",
"        # The regex \".*\" puts every latent site into a single group.\n",
"        group = self.group(match=\".*\")\n",
"        # Variational parameters: random initial means and small (0.01) initial scales.\n",
"        loc = pyro.param(\"loc\", torch.randn(group.event_shape))\n",
"        scale = pyro.param(\"scale\", torch.ones(group.event_shape)*0.01, constraint=constraints.positive)\n",
"        group.sample(\"joint\", dist.Normal(loc=loc, scale=scale).to_event(1))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f50d0bf3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5000/5000 [00:53<00:00, 93.05it/s]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Fresh model plus a class-based EasyGuide bound to it.\n",
"base_regression_model = BayesianRegression(in_features=in_features, out_features=out_features)\n",
"regression_guide_as_class = RegressionGuideAsClass(base_regression_model)\n",
"train(\n",
"    base_regression_model,\n",
"    regression_guide_as_class,\n",
"    X=x_data,\n",
"    Y=y_data,\n",
"    dataset=dataset,\n",
"    adam_params=adam_params,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5a6a737e",
"metadata": {},
"source": [
"## Using the `easy_guide` decorator in the custom loop"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2891c910",
"metadata": {},
"outputs": [],
"source": [
"# The decorator binds the guide to whatever object base_regression_model names right now.\n",
"@easy_guide(base_regression_model)\n",
"def regression_guide_with_decorator(self, x, full_size, y=None):\n",
"    \"\"\"Joint diagonal-Normal guide over all latent sites, defined via @easy_guide.\"\"\"\n",
"    latents = self.group(match=\".*\")\n",
"    shape = latents.event_shape\n",
"    loc = pyro.param(\"loc\", torch.randn(shape))\n",
"    scale = pyro.param(\"scale\", 0.01 * torch.ones(shape), constraint=constraints.positive)\n",
"    latents.sample(\"joint\", dist.Normal(loc=loc, scale=scale).to_event(1))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "df1fa914",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5000/5000 [00:54<00:00, 91.90it/s]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# @easy_guide bound regression_guide_with_decorator to the model object that existed\n",
"# when its cell ran. Train that same object rather than a fresh BayesianRegression, so\n",
"# that model and guide agree. Reuse is safe here: train() clears the param store, and the\n",
"# model's weight and bias are PyroSample priors rather than fitted parameters.\n",
"train(base_regression_model, regression_guide_with_decorator, x_data, y_data, dataset, adam_params)"
]
},
{
"cell_type": "markdown",
"id": "ba1b13b6",
"metadata": {},
"source": [
"## Using AutoNormal"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "ae4845aa",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5000/5000 [01:12<00:00, 69.17it/s]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Baseline: Pyro's automatic AutoNormal guide on a fresh model.\n",
"base_regression_model = BayesianRegression(in_features=in_features, out_features=out_features)\n",
"auto_guide = pyro.infer.autoguide.AutoNormal(base_regression_model)\n",
"train(base_regression_model, auto_guide, X=x_data, Y=y_data, dataset=dataset, adam_params=adam_params)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment