@ricardoV94
Created June 9, 2023 14:44
{
"cells": [
{
"cell_type": "markdown",
"id": "1d614fbd",
"metadata": {},
"source": [
"# Automatic probability\n",
"\n",
"Slides for the accompanying talk: [link](https://docs.google.com/presentation/d/1xLvEOdnEC2nZ0jqBSbssP2vWsiNHHmMX-kprhuNH5Cg/edit?usp=sharing)\n",
"\n",
"Source code: [link](https://github.com/pymc-devs/pymc/tree/main/pymc/logprob)\n",
"\n",
"Versioned source code: [link](https://github.com/pymc-devs/pymc/tree/2ac88afa4212dcbeaf9471c6f54c7f50b5a3db53/pymc/logprob)\n",
"\n",
"Previous related talk: [link](https://www.youtube.com/watch?v=_APNiXTfYJw)\n",
"\n",
"Relevant documentation:\n",
"* PyMC and PyTensor: [link](https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_pytensor.html)\n",
"* CustomDist: [link](https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.CustomDist.html)\n",
"\n",
"Compatible PyMC version: 5.5.0"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0fc1322a",
"metadata": {},
"outputs": [],
"source": [
"import pymc as pm\n",
"import pytensor\n",
"import numpy as np\n",
"\n",
"import pytensor.tensor as pt\n",
"\n",
"pytensor.config.mode = \"FAST_COMPILE\""
]
},
{
"cell_type": "markdown",
"id": "ae0cff50",
"metadata": {},
"source": [
"# Single random variable transformation"
]
},
{
"cell_type": "markdown",
"id": "db7aede3",
"metadata": {},
"source": [
"## 1-to-1 transformations"
]
},
{
"cell_type": "markdown",
"id": "a2388f43",
"metadata": {},
"source": [
"### Linear shift"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1c797a5d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array(-1.10557431), array(3.89442569))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pm.Normal.dist()\n",
"z = x + 5\n",
"\n",
"x_draw, z_draw = pm.draw([x, z])\n",
"x_draw, z_draw"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3a2f9cbf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(0.21651709)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
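"def prob(rv, value):\n",
"    # Density (or probability mass) of `rv` at `value`, by exponentiating the logp graph PyMC derives\n",
"    return pm.logp(rv, value).exp()\n",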
"\n",
"prob(x, x_draw).eval()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ad1c43a2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(0.21651709)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob(z, z_draw).eval()"
]
},
{
"cell_type": "markdown",
"id": "5ade8be6",
"metadata": {},
"source": [
"### Exponentiation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2b13f751",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array(-0.28362562), array(0.75304852))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pm.Normal.dist()\n",
"z = pt.exp(x)\n",
"\n",
"x_draw, z_draw = pm.draw([x, z])\n",
"x_draw, z_draw"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1b5b6c9b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(0.38321454)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob(x, x_draw).eval()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "13ef22b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(0.50888427)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob(z, z_draw).eval()"
]
},
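{
"cell_type": "markdown",
"id": "check-exp-md",
"metadata": {},
"source": [
"For a monotonic transform $z = g(x)$, the derived density follows the change-of-variables rule $p_Z(z) = p_X(g^{-1}(z)) \\left| \\frac{d}{dz} g^{-1}(z) \\right|$. For $z = e^x$ this gives $p_Z(z) = p_X(\\log z) / z$, which is the log-normal density. The cell below is a minimal sanity check of that formula, assuming SciPy is installed (the test point `0.75` is arbitrary)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-exp-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"from scipy import stats\n",
"\n",
"x = pm.Normal.dist()\n",
"z = pt.exp(x)\n",
"z_point = 0.75  # arbitrary positive test point\n",
"\n",
"derived = pm.logp(z, z_point).exp().eval()\n",
"# Change of variables by hand: p_X(log z) * |d log(z) / dz| = p_X(log z) / z\n",
"by_hand = stats.norm.pdf(np.log(z_point)) / z_point\n",
"# SciPy's log-normal with shape s=1 is the distribution of exp(N(0, 1))\n",
"reference = stats.lognorm.pdf(z_point, s=1)\n",
"\n",
"np.isclose(derived, by_hand), np.isclose(derived, reference)"
]
},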
{
"cell_type": "markdown",
"id": "834f9841",
"metadata": {},
"source": [
"### Inverse"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a6b60ffd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([3.96040008e-09, 1.41976891e-01])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pm.Normal.dist(shape=(2,))\n",
"z = 1 / x\n",
"\n",
"value = np.array([-0.15, 1.5])\n",
"prob(z, value).eval()"
]
},
{
"cell_type": "markdown",
"id": "97f1f260",
"metadata": {},
"source": [
"### Others"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "49bcac15",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.24197072451914337\n",
"0.48394144903828673\n",
"0.07820853879509117\n",
"0.026958231758816044\n",
"0.39614036832836785\n"
]
}
],
"source": [
"x = pm.Normal.dist()\n",
"\n",
"print(\n",
" prob(x ** 2, 1).eval(),\n",
" prob(pt.sqrt(x), 1).eval(),\n",
" prob(x * 5, 1).eval(),\n",
" prob(pt.log(x), 1).eval(),\n",
" prob(pt.erf(x), 0.5).eval(),\n",
" sep=\"\\n\"\n",
")"
]
},
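{
"cell_type": "markdown",
"id": "check-others-md",
"metadata": {},
"source": [
"Each printed value can be reproduced from the standard normal density $\\phi$ with the same rule, summing over all preimages when the map is not one-to-one (for $x^2 = 1$ the preimages are $\\pm 1$). A hand-derived cross-check, again a sketch that assumes SciPy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-others-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"from scipy import special, stats\n",
"\n",
"x = pm.Normal.dist()\n",
"phi = stats.norm.pdf  # standard normal density\n",
"x_erf = special.erfinv(0.5)  # preimage of 0.5 under erf\n",
"\n",
"by_hand = [\n",
"    # x**2 = 1: preimages -1 and +1, each with |dx/dz| = 1 / (2 * sqrt(z)) = 1/2\n",
"    (phi(-1) + phi(1)) / 2,\n",
"    # sqrt(x) = 1: preimage x = 1, |dx/dz| = 2 * z = 2\n",
"    phi(1) * 2,\n",
"    # 5 * x = 1: preimage x = 0.2, |dx/dz| = 1/5\n",
"    phi(0.2) / 5,\n",
"    # log(x) = 1: preimage x = e, |dx/dz| = e\n",
"    phi(np.e) * np.e,\n",
"    # erf(x) = 0.5: divide by d erf(x) / dx = 2 / sqrt(pi) * exp(-x**2)\n",
"    phi(x_erf) / (2 / np.sqrt(np.pi) * np.exp(-x_erf**2)),\n",
"]\n",
"derived = [\n",
"    pm.logp(x**2, 1).exp().eval(),\n",
"    pm.logp(pt.sqrt(x), 1).exp().eval(),\n",
"    pm.logp(x * 5, 1).exp().eval(),\n",
"    pm.logp(pt.log(x), 1).exp().eval(),\n",
"    pm.logp(pt.erf(x), 0.5).exp().eval(),\n",
"]\n",
"np.allclose(derived, by_hand)"
]
},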
{
"cell_type": "markdown",
"id": "e7f3c223",
"metadata": {},
"source": [
"## Many-to-1 transformations"
]
},
{
"cell_type": "markdown",
"id": "105c8bd0",
"metadata": {},
"source": [
"### Abs"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "39a7a3c3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pm.Normal.dist(0, 1)\n",
"z = pt.abs(x)\n",
"\n",
"np.isclose(\n",
" prob(x, 0.3).eval(),\n",
" (prob(z, 0.3) / 2).eval(),\n",
")"
]
},
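{
"cell_type": "markdown",
"id": "check-abs-md",
"metadata": {},
"source": [
"The check above leans on symmetry: for a standard normal, both preimages $\\pm z$ of $|x| = z$ have the same density, so the density simply doubles. In general the two branches are summed, $p_Z(z) = p_X(z) + p_X(-z)$ for $z \\ge 0$. A sketch of the same check with a normal centered at 1, where the two branches differ (SciPy assumed):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-abs-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"from scipy import stats\n",
"\n",
"x = pm.Normal.dist(1, 1)  # not symmetric around 0\n",
"z = pt.abs(x)\n",
"\n",
"derived = pm.logp(z, 0.3).exp().eval()\n",
"# Sum the densities of both preimages, x = 0.3 and x = -0.3\n",
"by_hand = stats.norm.pdf(0.3, loc=1) + stats.norm.pdf(-0.3, loc=1)\n",
"np.isclose(derived, by_hand)"
]
},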
{
"cell_type": "markdown",
"id": "577cbfad",
"metadata": {},
"source": [
"### Clipping"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "03cb9b86",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x=-0.84 → z=-0.84 \n",
"x=-1.16 → z=-1.00 (clipped)\n",
"x=-0.76 → z=-0.76 \n",
"x=0.61 → z=0.61 \n",
"x=2.07 → z=2.07 \n",
"x=1.24 → z=1.24 \n",
"x=0.90 → z=0.90 \n",
"x=-0.89 → z=-0.89 \n",
"x=-0.14 → z=-0.14 \n",
"x=-0.78 → z=-0.78 \n"
]
}
],
"source": [
"x = pm.Normal.dist()\n",
"z = pt.clip(x, -1, np.inf)\n",
"\n",
"draws = pm.draw([x, z], draws=10)\n",
"\n",
"for x_draw, z_draw in zip(*draws):\n",
" print(\n",
" f\"x={x_draw: <5.2f} → z={z_draw:.2f}\"\n",
" f\" {'(clipped)' if x_draw < -1 else ''}\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e6a7ec97",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.15865525 0.35206533 0.02275013]\n",
"[0.15865525 0.35206533 0.02275013]\n"
]
}
],
"source": [
"rv = pm.Normal.dist(shape=(3,))\n",
"\n",
"clipped_rv = pt.clip(rv, -1, 2)\n",
"censored_rv = pm.Censored.dist(rv, lower=-1, upper=2)\n",
"\n",
"clipped_value = [-1, 0.5, 2]\n",
"print(\n",
" prob(clipped_rv, clipped_value).eval(),\n",
" prob(censored_rv, clipped_value).eval(),\n",
" sep=\"\\n\"\n",
")"
]
},
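{
"cell_type": "markdown",
"id": "check-clip-md",
"metadata": {},
"source": [
"Both variables put point masses on the bounds: the value `-1` gets $P(X \\le -1)$, the value `2` gets $P(X \\ge 2)$, and interior values get the ordinary density. A SciPy sketch that should reproduce the printed numbers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-clip-code",
"metadata": {},
"outputs": [],
"source": [
"from scipy import stats\n",
"\n",
"print(\n",
"    stats.norm.cdf(-1),  # mass clipped to the lower bound, P(X <= -1)\n",
"    stats.norm.pdf(0.5),  # interior value: plain density\n",
"    stats.norm.sf(2),  # mass clipped to the upper bound, P(X >= 2)\n",
")"
]
},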
{
"cell_type": "markdown",
"id": "4ab7ba28",
"metadata": {},
"source": [
"### Others"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d7b7543e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.24197072451914337\n",
"0.24173033519680465\n",
"0.3413447502096956\n",
"0.1359051181327496\n"
]
}
],
"source": [
"rv = pm.Normal.dist(1)\n",
"\n",
"print(\n",
" prob(rv, 0).eval(),\n",
" prob(pt.round(rv), 0).eval(),\n",
" prob(pt.floor(rv), 0).eval(),\n",
" prob(pt.ceil(rv), 0).eval(),\n",
" sep=\"\\n\",\n",
")"
]
},
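{
"cell_type": "markdown",
"id": "check-round-md",
"metadata": {},
"source": [
"Rounding operations make the variable discrete, so each value's probability is the mass of the interval that maps to it, i.e. a difference of CDF values. With $X \\sim \\mathcal{N}(1, 1)$, `round` sends $(-0.5, 0.5)$ to 0, `floor` sends $[0, 1)$ to 0 and `ceil` sends $(-1, 0]$ to 0. A SciPy sketch of those interval masses:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-round-code",
"metadata": {},
"outputs": [],
"source": [
"from scipy import stats\n",
"\n",
"X = stats.norm(loc=1, scale=1)  # same distribution as `rv` above\n",
"\n",
"def interval_mass(lower, upper):\n",
"    # P(lower < X < upper); the endpoints carry no mass for a continuous X\n",
"    return X.cdf(upper) - X.cdf(lower)\n",
"\n",
"print(\n",
"    X.pdf(0),  # density of the untransformed variable\n",
"    interval_mass(-0.5, 0.5),  # round(X) == 0\n",
"    interval_mass(0, 1),  # floor(X) == 0\n",
"    interval_mass(-1, 0),  # ceil(X) == 0\n",
"    sep=\"\\n\",\n",
")"
]
},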
{
"cell_type": "markdown",
"id": "6b8d9b78",
"metadata": {},
"source": [
"## Chained transformations"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6b549a5f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(0.19394715)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pm.Laplace.dist(0, 1)\n",
"y = x + 2\n",
"z = pt.abs(y)\n",
"\n",
"prob(z, 0.9).eval()"
]
},
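{
"cell_type": "markdown",
"id": "check-chain-md",
"metadata": {},
"source": [
"Chained transformations are inverted from the outside in: $|y| = 0.9$ has the preimages $y = \\pm 0.9$, and undoing the shift gives $x = y - 2 \\in \\{-1.1, -2.9\\}$. Neither step rescales volume (each Jacobian has absolute value 1), so the result is the sum of two Laplace densities. A SciPy sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-chain-code",
"metadata": {},
"outputs": [],
"source": [
"from scipy import stats\n",
"\n",
"z_point = 0.9\n",
"y_preimages = [z_point, -z_point]  # invert the abs\n",
"x_preimages = [y_pre - 2 for y_pre in y_preimages]  # invert the shift by +2\n",
"by_hand = sum(stats.laplace.pdf(x_pre, loc=0, scale=1) for x_pre in x_preimages)\n",
"print(by_hand)  # should match prob(z, 0.9) above"
]
},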
{
"cell_type": "markdown",
"id": "edb0473c",
"metadata": {},
"source": [
"# Multiple random variable transformations"
]
},
{
"cell_type": "markdown",
"id": "51444a8a",
"metadata": {},
"source": [
"## Conditional probability"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "67f828c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array(1.30169683), array(0.03433906), array(1.26735777)]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pm.Normal.dist()\n",
"y = pm.Normal.dist()\n",
"z = x - y\n",
"\n",
"pm.draw([x, y, z])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "bd697b50",
"metadata": {},
"outputs": [],
"source": [
"from pymc.logprob.basic import conditional_logp\n",
"\n",
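"def conditional_prob(rvs_to_values: dict, fn=True):\n",
"    # One logp term per valued variable, each conditioned on the values of the others\n",
"    logps = conditional_logp(rvs_to_values).values()\n",
"    probs = [logp.exp() for logp in logps]\n",
"    if fn:\n",
"        # Compile a function whose inputs are the value variables\n",
"        values = list(rvs_to_values.values())\n",
"        return pytensor.function(values, probs)\n",
"    else:\n",
"        # Return the symbolic probability graphs without compiling\n",
"        return probs\n",
"\n",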
"x_value = pt.scalar(\"x_value\")\n",
"y_value = pt.scalar(\"y_value\")\n",
"z_value = pt.scalar(\"z_value\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "e0076461",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array(0.35206533), array(0.26608525)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob_fn = conditional_prob({x: x_value, y: y_value})\n",
"prob_fn(x_value=0.5, y_value=0.9)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "c4fd46ac",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array(0.35206533), array(0.26608525)]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob_fn = conditional_prob({x: x_value, z: z_value})\n",
"prob_fn(x_value=0.5, z_value=0.5-0.9)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "92794504",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array(0.26608525), array(0.35206533)]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob_fn = conditional_prob({y: y_value, z: z_value})\n",
"prob_fn(y_value=0.9, z_value=0.5-0.9)"
]
},
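{
"cell_type": "markdown",
"id": "check-cond-md",
"metadata": {},
"source": [
"Conditioning on `x` turns $z = x - y$ into a negated, shifted copy of `y`. That map has a Jacobian of absolute value 1, so the joint density factorizes as $p(x)\\,p(z \\mid x) = p_X(x)\\,p_Y(x - z)$. A sketch checking this at $x = 0.5$, $z = -0.4$ (so $y = 0.9$), summing the terms returned by `conditional_logp` directly rather than going through the helper (SciPy assumed):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-cond-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"from pymc.logprob.basic import conditional_logp\n",
"from scipy import stats\n",
"\n",
"x = pm.Normal.dist()\n",
"y = pm.Normal.dist()\n",
"z = x - y\n",
"x_value = pt.scalar(\"x_value\")\n",
"z_value = pt.scalar(\"z_value\")\n",
"\n",
"# One logp term per value variable; the joint logp is their sum\n",
"logp_terms = conditional_logp({x: x_value, z: z_value})\n",
"joint_logp = sum(logp_terms.values())\n",
"\n",
"derived = np.exp(joint_logp.eval({x_value: 0.5, z_value: -0.4}))\n",
"by_hand = stats.norm.pdf(0.5) * stats.norm.pdf(0.9)  # p(x) * p(y) with y = x - z\n",
"np.isclose(derived, by_hand)"
]
},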
{
"cell_type": "code",
"execution_count": 20,
"id": "8511c691",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RuntimeError: The logprob terms of the following value variables could not be derived: {z_value}\n"
]
}
],
"source": [
"try:\n",
" prob_fn = conditional_prob({z: z_value})\n",
"except RuntimeError as err:\n",
" print(f\"RuntimeError: {err}\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7f42dd87",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RuntimeError: The logprob terms of the following value variables could not be derived: {z_value}\n"
]
}
],
"source": [
"try:\n",
" prob_fn = conditional_prob({x: x_value, y: y_value, z: z_value})\n",
"except RuntimeError as err:\n",
" print(f\"RuntimeError: {err}\")"
]
},
{
"cell_type": "markdown",
"id": "ad65290b",
"metadata": {},
"source": [
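"**Note:** Once the other variable has a value, the logp of `z` in the working examples is just a single-variable transformation: given `x`, `z = x_value - y` is a negated, shifted `y`. The two failing examples have no such reduction: on its own, `z` would require integrating out both `x` and `y`, and with `x` and `y` both valued, `z` is fully determined by them and has no density of its own."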
]
},
{
"cell_type": "markdown",
"id": "f0d9c89a",
"metadata": {},
"source": [
"## Control flow"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "3c3f7f3c",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"switch=0 → x=0.91\n",
"switch=0 → x=1.05\n",
"switch=1 → x=-0.45\n",
"switch=0 → x=1.96\n",
"switch=1 → x=-2.03\n"
]
}
],
"source": [
"from pytensor.ifelse import ifelse\n",
"\n",
"switch = pm.Bernoulli.dist(p=0.7)\n",
"x1 = pm.Normal.dist(-1)\n",
"x2 = pm.Laplace.dist(1, 1)\n",
"x = ifelse(switch, x1, x2)\n",
"\n",
"for switch_draw, x_draw in zip(*pm.draw([switch, x], draws=5)):\n",
" print(f\"switch={switch_draw} → x={x_draw:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "2584d27f",
"metadata": {},
"outputs": [],
"source": [
"switch_value = pt.scalar(\"switch_value\", dtype=int)\n",
"x_value = pt.scalar(\"x_value\", dtype=float)\n",
"prob_fn = conditional_prob({switch: switch_value, x: x_value})"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "481da31e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array(0.7), array(0.39695255)]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob_fn(switch_value=1, x_value=-0.9)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "3855115b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array(0.3), array(0.07478431)]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob_fn(switch_value=0, x_value=-0.9)"
]
},
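{
"cell_type": "markdown",
"id": "check-ifelse-md",
"metadata": {},
"source": [
"`prob_fn` returns the two factors of the joint density, $P(\\text{switch} = s)$ and $p(x \\mid \\text{switch} = s)$. Summing their product over both switch values marginalizes the switch out, which should give the two-component mixture density $0.7\\,\\mathcal{N}(x; -1, 1) + 0.3\\,\\text{Laplace}(x; 1, 1)$. A sketch, assuming the `prob_fn` compiled above and SciPy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-ifelse-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"x_point = -0.9\n",
"# Sum P(switch = s) * p(x | switch = s) over s in {0, 1}\n",
"marginal = sum(np.prod(prob_fn(switch_value=s, x_value=x_point)) for s in (0, 1))\n",
"by_hand = (\n",
"    0.7 * stats.norm.pdf(x_point, loc=-1, scale=1)\n",
"    + 0.3 * stats.laplace.pdf(x_point, loc=1, scale=1)\n",
")\n",
"np.isclose(marginal, by_hand)"
]
},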
{
"cell_type": "markdown",
"id": "28a2d551",
"metadata": {},
"source": [
"## Stacking operations"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "f29f8e0f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([1., 1., 1.])]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x0 = pm.Uniform.dist(0, 1)\n",
"x1 = pm.Uniform.dist(x0, x0+1)\n",
"x2 = pm.Uniform.dist(x1, x1+1)\n",
"xs = pt.stack([x0, x1, x2])\n",
"\n",
"xs_values = pt.vector(\"xs\", shape=(3,))\n",
"conditional_prob_fn = conditional_prob({xs: xs_values})\n",
"\n",
"conditional_prob_fn([0.5, 1.5, 2.5])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "1256d2bd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([1., 1., 0.])]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conditional_prob_fn([0.5, 1.5, 0.5])"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "a52245d8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.03677558, 1.32534683, 1.25944947, 0.96523192, 2.29075459])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mu = 0\n",
"sigma = 1\n",
"\n",
"y0 = np.array(0.5)\n",
"ys = []\n",
"y_tm1 = y0\n",
"for i in range(5):\n",
" y = y_tm1 + pm.Normal.dist(mu, sigma)\n",
" ys = pt.concatenate([ys, [y]])\n",
" y_tm1 = y\n",
" \n",
"pm.draw(ys)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "48146b70",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"62"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pytensor.graph.basic import ancestors\n",
"len(list(ancestors(ys)))"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "c09289b5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array(1.43436108),\n",
" array(0.48492827),\n",
" array([1.79509336, 2.27295269, 3.73986879, 5.14309914, 6.35157192])]"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pymc.pytensorf import collect_default_updates\n",
"\n",
"mu = pm.Normal.dist()\n",
"sigma = pm.HalfNormal.dist()\n",
"y0 = np.array(0.5)\n",
"\n",
"def rw_step(*args):\n",
" y_tm1, mu, sigma = args\n",
" y = y_tm1 + pm.Normal.dist(mu=mu, sigma=sigma)\n",
" return y, collect_default_updates(inputs=args, outputs=[y])\n",
"\n",
"ys, _ = pytensor.scan(\n",
" fn=rw_step,\n",
" outputs_info=[y0],\n",
" non_sequences=[mu, sigma],\n",
" n_steps=5,\n",
")\n",
"\n",
"pm.draw([mu, sigma, ys])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "906c79bd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array(0.24197072),\n",
" array(0.10798193),\n",
" array([0.19333406, 0.19947114, 0.19947114, 0.00043634, 0.19947114])]"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mu_value = pt.scalar(\"mu_value\")\n",
"sigma_value = pt.scalar(\"sigma_value\")\n",
"ys_values = pt.vector(\"ys_values\")\n",
"\n",
"prob_fn = conditional_prob({mu: mu_value, sigma: sigma_value, ys: ys_values})\n",
"\n",
"prob_fn(mu_value=1, sigma_value=2, ys_values=[1, 2, 3, -3, -2])"
]
},
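{
"cell_type": "markdown",
"id": "check-scan-md",
"metadata": {},
"source": [
"Each step adds an independent $\\mathcal{N}(\\mu, \\sigma)$ innovation, so the density of `ys` given `mu` and `sigma` is a product of normal densities of the successive differences, starting from $y_0 = 0.5$. A SciPy sketch that should reproduce the three outputs above for $\\mu = 1$, $\\sigma = 2$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-scan-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"y0_point = 0.5\n",
"ys_point = np.array([1, 2, 3, -3, -2])\n",
"mu_point, sigma_point = 1, 2\n",
"\n",
"innovations = np.diff(np.concatenate([[y0_point], ys_point]))  # y_t - y_{t-1}\n",
"print(\n",
"    stats.norm.pdf(mu_point),  # mu ~ Normal(0, 1)\n",
"    stats.halfnorm.pdf(sigma_point),  # sigma ~ HalfNormal(1)\n",
"    stats.norm.pdf(innovations, loc=mu_point, scale=sigma_point),  # one term per step\n",
"    sep=\"\\n\",\n",
")"
]
},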
{
"cell_type": "markdown",
"id": "2b7d577e",
"metadata": {},
"source": [
"# Extras"
]
},
{
"cell_type": "markdown",
"id": "3ab72c17",
"metadata": {},
"source": [
"## CustomDist"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "5ace2829",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-1.93039708, 4.40054909, -3.67255424, 2.06813603, -4.47753441])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def dist1(loc, lam, size):\n",
" return loc + pm.Exponential.dist(lam, shape=size)\n",
" \n",
"def dist2(alpha, beta, lower, upper, size):\n",
" range_ = upper - lower\n",
" return pm.Beta.dist(alpha, beta, shape=size) * (range_) + lower\n",
" \n",
"comp1 = pm.CustomDist.dist(-5, 1, dist=dist1)\n",
"comp2 = pm.CustomDist.dist(1, 1, -5, 5, dist=dist2)\n",
"mix = pm.Mixture.dist([0.3, 0.7], comp_dists=[comp1, comp2], shape=(5,))\n",
"pm.draw(mix)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "5780cc39",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(0.00408677)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob(comp1, 0.5).eval()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "c89059c1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.08493612, 0.07549469, 0.07202138, 0.07074363, 0.07027356])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob(mix, [-2, -1, 0, 1, 2]).eval()"
]
},
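{
"cell_type": "markdown",
"id": "check-mixture-md",
"metadata": {},
"source": [
"The mixture density is the weighted sum of the two derived component densities: an exponential shifted to start at $-5$, and a uniform on $[-5, 5]$ (a $\\text{Beta}(1, 1)$ scaled by 10 and shifted by $-5$). A SciPy sketch of the same weighted sum:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-mixture-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"values_point = np.array([-2, -1, 0, 1, 2])\n",
"by_hand = (\n",
"    0.3 * stats.expon.pdf(values_point, loc=-5, scale=1)  # comp1: -5 + Exponential(lam=1)\n",
"    + 0.7 * stats.beta.pdf(values_point, 1, 1, loc=-5, scale=10)  # comp2: Beta(1, 1) * 10 - 5\n",
")\n",
"print(by_hand)  # should match prob(mix, [-2, -1, 0, 1, 2]) above"
]
},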
{
"cell_type": "markdown",
"id": "2a23c513",
"metadata": {},
"source": [
"## Indexing"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "2aeb18f3",
"metadata": {},
"outputs": [],
"source": [
"# Monty Hall\n",
"from pytensor.ifelse import ifelse\n",
"\n",
"first_pick = pt.scalar(\"first_pick\", dtype=int)\n",
"correct = pm.Categorical.dist([1/3, 1/3, 1/3])\n",
"opened_door = ( \n",
" ifelse(\n",
" pt.eq(first_pick, 0),\n",
" pt.stack([\n",
" pm.Categorical.dist([0, 1/2, 1/2]), # correct = 0\n",
" pm.Categorical.dist([0, 0, 1]), # correct = 1\n",
" pm.Categorical.dist([0, 1, 0]), # correct = 2\n",
" ])[correct],\n",
" # else (first_pick != 0)\n",
" ifelse( \n",
" pt.eq(first_pick, 1),\n",
" pt.stack([\n",
" pm.Categorical.dist([0, 0, 1]), # correct = 0\n",
" pm.Categorical.dist([1/2, 0, 1/2]), # correct = 1\n",
" pm.Categorical.dist([1, 0, 0]), # correct = 2\n",
" ])[correct],\n",
" # else (first_pick == 2)\n",
" pt.stack([\n",
" pm.Categorical.dist([0, 1, 0]), # correct = 0\n",
" pm.Categorical.dist([1, 0, 0]), # correct = 1\n",
" pm.Categorical.dist([1/2, 1/2, 0]), # correct = 2\n",
" ])[correct], \n",
" ), \n",
" )\n",
")\n",
"\n",
"correct_value = pt.scalar(\"correct\", dtype=int)\n",
"opened_door_value = pt.scalar(\"opened_door\", dtype=int)\n",
"\n",
"prob_correct, prob_opened_door = conditional_prob({correct: correct_value, opened_door: opened_door_value}, fn=False)\n",
"total_prob = prob_correct * prob_opened_door"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "0ba756ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.16666666666666666\n",
"0.0\n",
"0.3333333333333333\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ricardo/miniconda3/envs/pymc/lib/python3.10/site-packages/pytensor/tensor/elemwise.py:779: RuntimeWarning: divide by zero encountered in log\n",
" variables = ufunc(*ufunc_args, **ufunc_kwargs)\n"
]
}
],
"source": [
"for c in (0, 1, 2):\n",
" print(total_prob.eval({first_pick: 0, opened_door_value: 1, correct_value:c}))"
]
}
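,
{
"cell_type": "markdown",
"id": "check-monty-md",
"metadata": {},
"source": [
"The printed numbers are joint probabilities $P(\\text{correct} = c, \\text{opened} = 1 \\mid \\text{first pick} = 0)$. Normalizing them over $c$ gives the posterior location of the prize: the first pick keeps its prior $1/3$, while the other closed door (2) has posterior probability $2/3$, so switching wins twice as often. A sketch, assuming `total_prob` and the value variables defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "check-monty-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"joint = np.array([\n",
"    total_prob.eval({first_pick: 0, opened_door_value: 1, correct_value: c})\n",
"    for c in (0, 1, 2)\n",
"])\n",
"# Bayes' rule: divide by P(opened = 1 | first pick = 0), the sum over c\n",
"posterior = joint / joint.sum()\n",
"print(posterior)  # expected to be close to [1/3, 0, 2/3]"
]
}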
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "pymc",
"language": "python",
"name": "pymc"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}