Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active April 11, 2024 03:06
Show Gist options
  • Save ricardoV94/1bbb1d44f491b4917d63c005b8ecab78 to your computer and use it in GitHub Desktop.
Save ricardoV94/1bbb1d44f491b4917d63c005b8ecab78 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "1hBg4YmsqvSz"
},
"source": [
"# PyMC Optimization workflow\n",
"\n",
"Notebook created for the following Discourse thread: https://discourse.pymc.io/t/pymc-for-bayesian-optimization/11293"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "xhmWGgXSe5aC"
},
"outputs": [],
"source": [
"import pymc as pm\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"from pytensor.graph.rewriting.utils import rewrite_graph\n",
"\n",
"import numpy as np\n",
"from scipy.optimize import minimize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "s0M4UuGngCLg"
},
"outputs": [],
"source": [
"# Reproducible seed: the sum of the code points of the notebook title\n",
"seed = sum(ord(char) for char in \"PyMC Optimization\")\n",
"rng = np.random.default_rng(seed)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "ARMR6-3De97W"
},
"outputs": [],
"source": [
"with pm.Model() as generative_model:\n",
"    # Covariates come from the seeded `rng` rather than NumPy's global state,\n",
"    # so a fresh Restart & Run All regenerates the same design matrix.\n",
"    x = pm.MutableData(\"x\", rng.normal(size=(100, 5)), dims=[\"batch\", \"features\"])\n",
"    betas = pm.Normal(\"betas\", shape=5, dims=\"features\")\n",
"\n",
"    # Linear predictor and observation noise\n",
"    mu = pm.Deterministic(\"mu\", x @ betas, dims=[\"batch\"])\n",
"    sigma = pm.HalfNormal(\"sigma\")\n",
"    y = pm.Normal(\"y\", mu, sigma, shape=mu.shape, dims=[\"batch\"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "blf77lysfpmr"
},
"outputs": [],
"source": [
"# Fix the parameters at known values and draw one synthetic dataset from the model\n",
"fixed_parameters = dict(betas=[-2, -1, 0, 1, 2], sigma=0.5)\n",
"synthetic_model = pm.do(generative_model, fixed_parameters)\n",
"with synthetic_model:\n",
"    synthetic_y = pm.draw(synthetic_model[\"y\"], random_seed=rng)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 332
},
"id": "gfKm-lwpgQtw",
"outputId": "91261faa-9f44-4290-b636-e73ca8c9ada0"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [betas, sigma]\n"
]
},
{
"data": {
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <progress value='8000' class='' max='8000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [8000/8000 00:04&lt;00:00 Sampling 4 chains, 0 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 4 seconds.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>betas[0]</th>\n",
" <td>-1.997</td>\n",
" <td>0.054</td>\n",
" <td>-2.095</td>\n",
" <td>-1.890</td>\n",
" <td>0.001</td>\n",
" <td>0.001</td>\n",
" <td>4625.0</td>\n",
" <td>3487.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>betas[1]</th>\n",
" <td>-0.990</td>\n",
" <td>0.064</td>\n",
" <td>-1.108</td>\n",
" <td>-0.870</td>\n",
" <td>0.001</td>\n",
" <td>0.001</td>\n",
" <td>6048.0</td>\n",
" <td>3008.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>betas[2]</th>\n",
" <td>0.038</td>\n",
" <td>0.060</td>\n",
" <td>-0.071</td>\n",
" <td>0.156</td>\n",
" <td>0.001</td>\n",
" <td>0.001</td>\n",
" <td>6143.0</td>\n",
" <td>3044.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>betas[3]</th>\n",
" <td>0.958</td>\n",
" <td>0.052</td>\n",
" <td>0.857</td>\n",
" <td>1.054</td>\n",
" <td>0.001</td>\n",
" <td>0.001</td>\n",
" <td>4936.0</td>\n",
" <td>3266.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>betas[4]</th>\n",
" <td>2.022</td>\n",
" <td>0.055</td>\n",
" <td>1.925</td>\n",
" <td>2.131</td>\n",
" <td>0.001</td>\n",
" <td>0.000</td>\n",
" <td>6202.0</td>\n",
" <td>3688.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>sigma</th>\n",
" <td>0.536</td>\n",
" <td>0.039</td>\n",
" <td>0.465</td>\n",
" <td>0.609</td>\n",
" <td>0.001</td>\n",
" <td>0.000</td>\n",
" <td>5813.0</td>\n",
" <td>3339.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n",
"betas[0] -1.997 0.054 -2.095 -1.890 0.001 0.001 4625.0 \n",
"betas[1] -0.990 0.064 -1.108 -0.870 0.001 0.001 6048.0 \n",
"betas[2] 0.038 0.060 -0.071 0.156 0.001 0.001 6143.0 \n",
"betas[3] 0.958 0.052 0.857 1.054 0.001 0.001 4936.0 \n",
"betas[4] 2.022 0.055 1.925 2.131 0.001 0.000 6202.0 \n",
"sigma 0.536 0.039 0.465 0.609 0.001 0.000 5813.0 \n",
"\n",
" ess_tail r_hat \n",
"betas[0] 3487.0 1.0 \n",
"betas[1] 3008.0 1.0 \n",
"betas[2] 3044.0 1.0 \n",
"betas[3] 3266.0 1.0 \n",
"betas[4] 3688.0 1.0 \n",
"sigma 3339.0 1.0 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Condition the generative model on the synthetic observations, then sample the posterior\n",
"inference_model = pm.observe(generative_model, {\"y\": synthetic_y})\n",
"with inference_model:\n",
"    idata = pm.sample(random_seed=rng)\n",
"\n",
"# Tabulate posterior summaries for the regression coefficients and the noise scale\n",
"pm.stats.summary(idata, var_names=[\"betas\", \"sigma\"])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "fwA9BXCfgc4d"
},
"outputs": [],
"source": [
"from typing import Sequence\n",
"from pymc import Model\n",
"from arviz import InferenceData\n",
"from pytensor.tensor import TensorVariable\n",
"\n",
"def rewrite(*variables: TensorVariable) -> Sequence[TensorVariable]:\n",
"    \"\"\"Run the ShapeOpt, canonicalize, stabilize and specialize rewrites on `variables`.\n",
"\n",
"    The varargs parameter is deliberately not called `vars`, which would shadow\n",
"    the Python builtin of the same name.\n",
"\n",
"    Returns whatever `rewrite_graph` produces for the tuple of inputs; callers in\n",
"    this notebook unpack it positionally, one output per input.\n",
"    \"\"\"\n",
"    rewrites = (\"ShapeOpt\", \"canonicalize\", \"stabilize\", \"specialize\")\n",
"    return rewrite_graph(variables, include=rewrites)\n",
"\n",
"def posterior_predictive_fn(\n",
"    model: Model,\n",
"    idata: InferenceData,\n",
"    var_names: Sequence[str],\n",
"    replace: dict[TensorVariable, TensorVariable],\n",
") -> Sequence[TensorVariable]:\n",
"    \"\"\"Build graphs of `var_names` evaluated at every posterior draw in `idata`.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    model\n",
"        Model whose variables are looked up by name.\n",
"    idata\n",
"        Inference data; `idata.posterior` supplies the draws, indexed by variable name.\n",
"    var_names\n",
"        Names of the model variables to predict.\n",
"    replace\n",
"        Extra graph substitutions, merged on top of the placeholder substitutions.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The output of `vectorize_graph`: one variable per entry of `var_names`, with the\n",
"    posterior arrays (including their leading chain/draw dims) substituted in.\n",
"    \"\"\"\n",
"    # Stand-in inputs for the free RVs that are not being predicted. Each one has the\n",
"    # shape of its posterior array without the first two (chain, draw) dimensions.\n",
"    placeholders = {}\n",
"    for free_rv in model.free_RVs:\n",
"        if free_rv.name in var_names:\n",
"            continue\n",
"        core_shape = idata.posterior[free_rv.name].shape[2:]\n",
"        placeholders[free_rv] = pt.tensor(free_rv.name, shape=core_shape)\n",
"\n",
"    requested = [model[name] for name in var_names]\n",
"    substituted = pytensor.clone_replace(\n",
"        requested,\n",
"        replace=placeholders | replace,\n",
"        rebuild_strict=False,\n",
"    )\n",
"\n",
"    # Simplify the shape graph first (or else blockwise fails)\n",
"    substituted = rewrite(*substituted)\n",
"\n",
"    # Swap each placeholder for the full array of posterior draws;\n",
"    # vectorize_graph handles the new batch dims.\n",
"    draws_by_placeholder = {}\n",
"    for placeholder in placeholders.values():\n",
"        draws = idata.posterior[placeholder.name]\n",
"        draws_by_placeholder[placeholder] = pt.constant(draws, name=placeholder.name)\n",
"\n",
"    return pytensor.graph.replace.vectorize_graph(\n",
"        substituted,\n",
"        replace=draws_by_placeholder,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "-TW5EsJ3TlP5",
"outputId": "58b44ca5-3e8c-4825-9472-5b1ecad9613d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DropDims{axes=[2, 3]} [id A] <Matrix(float64, shape=(4, 1000))>\n",
" └─ SpecifyShape [id B] <Tensor4(float64, shape=(4, 1000, 1, 1))>\n",
" ├─ Transpose{axes=[1, 2, 0, 3]} [id C] <Tensor4(float64, shape=(?, ?, 1, 1))>\n",
" │ └─ Reshape{4} [id D] <Tensor4(float64, shape=(1, ?, ?, 1))>\n",
" │ ├─ dot [id E] <Matrix(float64, shape=(1, 4000))>\n",
" │ │ ├─ ExpandDims{axis=0} [id F] <Matrix(float64, shape=(1, 5))>\n",
" │ │ │ └─ x [id G] <Vector(float64, shape=(5,))>\n",
" │ │ └─ [[-2.00337 ... 6598e+00]] [id H] <Matrix(float64, shape=(5, 4000))>\n",
" │ └─ [ 1 4 ... 1000 1] [id I] <Vector(int64, shape=(4,))>\n",
" ├─ 4 [id J] <Scalar(int8, shape=())>\n",
" ├─ 1000 [id K] <Scalar(int16, shape=())>\n",
" ├─ 1 [id L] <Scalar(int8, shape=())>\n",
" └─ 1 [id M] <Scalar(int8, shape=())>\n",
"DropDims{axis=2} [id N] <Matrix(float64, shape=(4, 1000))>\n",
" └─ normal_rv{0, (0, 0), floatX, False}.1 [id O] <Tensor3(float64, shape=(4, 1000, 1))>\n",
" ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F2FEBC302E0>) [id P] <RandomGeneratorType>\n",
" ├─ [ 4 1000 1] [id Q] <Vector(int64, shape=(3,))>\n",
" ├─ 11 [id R] <Scalar(int64, shape=())>\n",
" ├─ DropDims{axis=3} [id S] <Tensor3(float64, shape=(4, 1000, 1))>\n",
" │ └─ SpecifyShape [id B] <Tensor4(float64, shape=(4, 1000, 1, 1))>\n",
" │ └─ ···\n",
" └─ [[[0.51133 ... 4335869]]] [id T] <Tensor3(float64, shape=(4, 1000, 1))>\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f30760e1840>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create graphs from x to (y, mu), for every draw from the posterior\n",
"x = pt.tensor(\"x\", shape=(5,))\n",
"\n",
"# Add the batch dim\n",
"replace_x = x[None]\n",
"# Alternative: replace_x = pt.broadcast_to(x, idata.constant_data[\"x\"].shape)\n",
"\n",
"predicted_y, predicted_mu = posterior_predictive_fn(\n",
"    model=generative_model,\n",
"    idata=idata,\n",
"    var_names=[\"y\", \"mu\"],\n",
"    replace={generative_model[\"x\"]: replace_x},\n",
")\n",
"\n",
"# Drop the batch dim\n",
"predicted_y, predicted_mu = predicted_y.squeeze(-1), predicted_mu.squeeze(-1)\n",
"# Alternative: predicted_y, predicted_mu = predicted_y[..., 0], predicted_mu[..., 0]\n",
"\n",
"predicted_y, predicted_mu = rewrite(predicted_y, predicted_mu)\n",
"\n",
"# The trailing semicolon keeps dprint's return value (the output stream object)\n",
"# out of the cell output, leaving only the printed graphs.\n",
"pytensor.dprint([predicted_mu, predicted_y], print_type=True);"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Hco5BGG3UjFD",
"outputId": "de1e46d0-e96c-45e3-fbb6-aebb04f9c490"
},
"outputs": [
{
"data": {
"text/plain": [
"(2.5119952383345927, (4, 1000))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# For debugging, we check the graph looks correct and that evaluating returns what we expect\n",
"x_test = np.array([0, 0.25, 0.5, 0.75, 1.0])\n",
"# Named `mu_draws` rather than `res` so it is not confused with the optimizer\n",
"# result that a later cell stores under `res`.\n",
"mu_draws = predicted_mu.eval({x: x_test})\n",
"mu_draws.mean(), mu_draws.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XqoDLzlYjVV_",
"outputId": "6fdaa625-fec2-4ca3-a443-9d9e9e1b9433"
},
"outputs": [
{
"data": {
"text/plain": [
"array(24786.39509671)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Cost: summed squared distance between the target and predicted_mu over all draws.\n",
"# target_y could be a SharedVariable so we can optimize for different values\n",
"# without having to recompile the cost and grad functions\n",
"target_y = 5.0\n",
"\n",
"squared_error = (target_y - predicted_mu) ** 2\n",
"cost = squared_error.sum()\n",
"cost_fn = pytensor.function(inputs=[x], outputs=cost)\n",
"cost_fn(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8inHy0kZkyiS",
"outputId": "0df87060-9f1b-4c9e-db61-b333ecd7ae49"
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 39743.84888322, 19721.50155962, -740.94660199, -19045.16624961,\n",
" -40223.42767489])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Compile the gradient of the cost with respect to x, to hand to scipy as `jac`\n",
"cost_grad = pt.grad(cost, wrt=x)\n",
"grad_fn = pytensor.function(inputs=[x], outputs=cost_grad)\n",
"grad_fn(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lR82IxHukE9q",
"outputId": "75ef9aaa-03a1-43ff-9cef-65c6b1fb2c65"
},
"outputs": [
{
"data": {
"text/plain": [
" fun: 27.73603988253558\n",
" hess_inv: array([[ 0.02245723, -0.00485788, -0.0006937 , 0.00593183, 0.01698518],\n",
" [-0.00485788, 0.02792761, -0.00248036, 0.00198485, 0.00798208],\n",
" [-0.0006937 , -0.00248036, 0.03571901, 0.00078844, -0.00294736],\n",
" [ 0.00593183, 0.00198485, 0.00078844, 0.0409628 , -0.01258027],\n",
" [ 0.01698518, 0.00798208, -0.00294736, -0.01258027, 0.02670575]])\n",
" jac: array([-2.58878825e-06, -1.25715929e-06, -8.10551371e-08, 1.27604829e-06,\n",
" 2.66765740e-06])\n",
" message: 'Optimization terminated successfully.'\n",
" nfev: 14\n",
" nit: 10\n",
" njev: 14\n",
" status: 0\n",
" success: True\n",
" x: array([-1.07362961, -0.42703831, -0.07130571, 0.55570669, 0.94079541])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Minimize the cost over x, starting from x_test and using the compiled gradient\n",
"res = minimize(cost_fn, x_test, jac=grad_fn)\n",
"res"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Compile a function from x to predicted_y, with a fixed random seed\n",
"posterior_pred_fn = pm.pytensorf.compile_pymc(inputs=[x], outputs=predicted_y, random_seed=5)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Distribution of predicted y at the optimized x, with the target marked as reference\n",
"optimized_y_draws = posterior_pred_fn(res.x)\n",
"pm.plots.plot_posterior(optimized_y_draws, ref_val=target_y);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RQ87MQE2o13F"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "pymc",
"language": "python",
"name": "pymc"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment