Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save aflaxman/cb43e6b876ec78cb7ec4 to your computer and use it in GitHub Desktop.
Save aflaxman/cb43e6b876ec78cb7ec4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"name": "",
"signature": "sha256:0d5c95c7ee7305f3517824521e605e051812eb8f3d713fb76fe2ce71bf4e10ff"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np, pymc as pm, pandas"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# set random seed for reproducibility\n",
"np.random.seed(12345)\n",
"\n",
"# Generate the synthetic data\n",
"a = 2.0 \n",
"b = 8.0\n",
"c = 6.0\n",
"d = c + (b-a)\n",
"d1 = np.random.uniform(a, b, 100) \n",
"d2 = np.random.uniform(c, d, 100) \n",
"data = d1 + d2"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"#Setup the priors\n",
"\n",
"pa = pm.Normal(\"pa\", 0.0, .01, value=0)\n",
"pb = pm.Normal(\"pb\", 0.0, .01, value=10)\n",
"pc = pm.Normal(\"pc\", 0.0, .01, value=0)\n",
"pd = pm.Normal(\"pd\", 0.0, .01, value=10)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 3
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# if data was just d1, this would be more familiar\n",
"data = d1\n",
"\n",
"# here is how you could fit it with an \"observed\" stochastic\n",
"@pm.observed\n",
"def pdata(value=data, pa=pa, pb=pb, pc=pc, pd=pd):\n",
" return pm.uniform_like(value, pa, pb)\n",
"\n",
"m = pm.MCMC(dict(pa=pa, pb=pb, pc=pc, pd=pd, pdata=pdata))\n",
"m.sample(100000, 50000, 50)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 3159 of 100000 complete in 0.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-- 6% ] 6297 of 100000 complete in 1.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [--- 9% ] 9445 of 100000 complete in 1.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [---- 12% ] 12604 of 100000 complete in 2.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [----- 15% ] 15769 of 100000 complete in 2.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [------- 18% ] 18928 of 100000 complete in 3.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-------- 22% ] 22093 of 100000 complete in 3.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [--------- 25% ] 25258 of 100000 complete in 4.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [---------- 28% ] 28380 of 100000 complete in 4.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [----------- 31% ] 31509 of 100000 complete in 5.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [------------- 34% ] 34652 of 100000 complete in 5.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-------------- 37% ] 37798 of 100000 complete in 6.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [--------------- 40% ] 40943 of 100000 complete in 6.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [---------------- 44% ] 44074 of 100000 complete in 7.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------47% ] 47182 of 100000 complete in 7.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------50% ] 50327 of 100000 complete in 8.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------53% ] 53423 of 100000 complete in 8.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------56%- ] 56516 of 100000 complete in 9.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------59%-- ] 59561 of 100000 complete in 9.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------62%--- ] 62629 of 100000 complete in 10.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------65%---- ] 65704 of 100000 complete in 10.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------68%------ ] 68783 of 100000 complete in 11.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------71%------- ] 71882 of 100000 complete in 11.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------74%-------- ] 74975 of 100000 complete in 12.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------78%--------- ] 78066 of 100000 complete in 12.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------81%---------- ] 81153 of 100000 complete in 13.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------84%------------ ] 84261 of 100000 complete in 13.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------87%------------- ] 87335 of 100000 complete in 14.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------90%-------------- ] 90420 of 100000 complete in 14.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------93%--------------- ] 93452 of 100000 complete in 15.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------96%---------------- ] 96555 of 100000 complete in 15.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------99%----------------- ] 99632 of 100000 complete in 16.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------100%-----------------] 100000 of 100000 complete in 16.1 sec"
]
}
],
"prompt_number": 4
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"print pandas.DataFrame(m.stats()).loc[['mean', 'standard deviation']]"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
" pa pb pc pd\n",
"mean 1.992644 8.023612 -0.3094084 0.09822238\n",
"standard deviation 0.0577922 0.05885509 10.1256 10.20577\n"
]
}
],
"prompt_number": 5
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"#Setup the priors\n",
"\n",
"pa = pm.Normal(\"pa\", 0.0, .01, value=0)\n",
"pb = pm.Normal(\"pb\", 0.0, .01, value=10)\n",
"pc = pm.Normal(\"pc\", 0.0, .01, value=0)\n",
"pd = pm.Normal(\"pd\", 0.0, .01, value=10)\n",
"\n",
"# In general, you can use this pattern with a more complicated\n",
"# calculation of the log-likelihood, e.g.\n",
"\n",
"@pm.observed\n",
"def pdata(value=data, pa=pa, pb=pb, pc=pc, pd=pd):\n",
" logpr = 0.\n",
" # code to calculate logpr from data and parameters\n",
" return logpr\n",
"\n",
"\n",
"# Use MCMC to sample and obtain traces\n",
"m = pm.MCMC(dict(pa=pa, pb=pb, pc=pc, pd=pd, pdata=pdata))\n",
"m.sample(100000, 50000, 50)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [- 3% ] 3943 of 100000 complete in 0.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [--- 7% ] 7910 of 100000 complete in 1.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [---- 11% ] 11870 of 100000 complete in 1.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [------ 15% ] 15811 of 100000 complete in 2.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [------- 19% ] 19724 of 100000 complete in 2.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-------- 23% ] 23673 of 100000 complete in 3.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [---------- 27% ] 27650 of 100000 complete in 3.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [------------ 31% ] 31594 of 100000 complete in 4.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [------------- 35% ] 35525 of 100000 complete in 4.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-------------- 39% ] 39443 of 100000 complete in 5.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [---------------- 43% ] 43413 of 100000 complete in 5.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------47% ] 47386 of 100000 complete in 6.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------51% ] 51372 of 100000 complete in 6.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------55%- ] 55334 of 100000 complete in 7.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------59%-- ] 59276 of 100000 complete in 7.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------63%---- ] 63174 of 100000 complete in 8.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------67%----- ] 67071 of 100000 complete in 8.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------71%------ ] 71019 of 100000 complete in 9.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------74%-------- ] 74963 of 100000 complete in 9.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------78%--------- ] 78895 of 100000 complete in 10.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------82%----------- ] 82802 of 100000 complete in 10.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------86%------------ ] 86732 of 100000 complete in 11.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------90%-------------- ] 90651 of 100000 complete in 11.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------94%--------------- ] 94571 of 100000 complete in 12.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------98%----------------- ] 98495 of 100000 complete in 12.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------100%-----------------] 100000 of 100000 complete in 12.7 sec"
]
}
],
"prompt_number": 6
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To keep this example simple, let us restrict the model to have $c-a = b-d$. Then $d1_i+d2_i$ has a symmetric triangular distribution, with support $[a+c, b+d]$. In other words,\n",
"$$\n",
"p(data_i=x|a,b,c,d) \\propto \\begin{cases}\n",
"x - (a+c) &\\quad \\text{if } a+c \\leq x \\leq \\frac{a+b+c+d}{2};\\\\\n",
"(b+d) - x &\\quad \\text{if } \\frac{a+b+c+d}{2} \\leq x \\leq b+d;\\\\\n",
"0 &\\quad {\\text{otherwise.}}\n",
"\\end{cases}\n",
"$$\n",
"\n",
"Did I say \"simple\"?"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# set random seed for reproducibility\n",
"np.random.seed(12345)\n",
"\n",
"# Generate the synthetic data\n",
"a = 2.0 \n",
"b = 8.0\n",
"c = 6.0\n",
"d = c + (b-a)\n",
"d1 = np.random.uniform(a, b, 100) \n",
"d2 = np.random.uniform(c, d, 100) \n",
"data = d1 + d2\n",
"\n",
"#Setup the priors\n",
"\n",
"pa = pm.Normal(\"pa\", 0.0, .01, value=0)\n",
"pb = pm.Normal(\"pb\", 0.0, .01, value=10)\n",
"pc = pm.Uniform(\"pc\", 5, 10, value=5) # <-- changed prior to break symmetry\n",
"pd = pm.Normal(\"pd\", 0.0, .01, value=10)\n",
"\n",
"\n",
"@pm.observed\n",
"def pdata(value=data, pa=pa, pb=pb, pc=pc, pd=pd):\n",
" pd = pc + (pb - pa) # don't use pd value\n",
" \n",
" # make sure order is acceptible\n",
" if pb < pa or pd < pc:\n",
" return -np.inf\n",
"\n",
" x = value\n",
" pr = \\\n",
" np.where(x < pa+pc, 0,\n",
" np.where(x <= (pa+pb+pc+pd)/2, x - (pa+pc),\n",
" np.where(x <= (pb+pd), (pb+pd) - x,\n",
" 0))) \\\n",
" / (.5 * ((pb+pd) - (pa+pc)) * ((pb-pa) + (pd-pc))/2)\n",
" return np.sum(np.log(pr))\n",
"\n",
"m = pm.MCMC(dict(pa=pa, pb=pb, pc=pc, pd=pd, pdata=pdata))\n",
"m.use_step_method(pm.AdaptiveMetropolis, [m.pa, m.pb, m.pc])\n",
"m.sample(20000, 10000, 10)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [--- 8% ] 1630 of 20000 complete in 0.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [------ 16% ] 3315 of 20000 complete in 1.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [--------- 25% ] 5028 of 20000 complete in 1.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [------------ 33% ] 6739 of 20000 complete in 2.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [---------------- 42% ] 8462 of 20000 complete in 2.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------50% ] 10177 of 20000 complete in 3.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------59%-- ] 11895 of 20000 complete in 3.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------68%----- ] 13622 of 20000 complete in 4.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------76%--------- ] 15343 of 20000 complete in 4.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------85%------------ ] 17060 of 20000 complete in 5.0 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------93%--------------- ] 18784 of 20000 complete in 5.5 sec"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"\r",
" [-----------------100%-----------------] 20000 of 20000 complete in 5.9 sec"
]
}
],
"prompt_number": 7
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"print pandas.DataFrame(m.stats()).loc[['mean', 'standard deviation']]"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
" pa pb pc pd\n",
"mean 1.068921 6.471076 7.74488 -0.4584424\n",
"standard deviation 1.477786 1.451357 1.439883 9.85236\n"
]
}
],
"prompt_number": 8
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": []
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment