Created
June 24, 2014 21:56
-
-
Save aflaxman/cb43e6b876ec78cb7ec4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "", | |
"signature": "sha256:0d5c95c7ee7305f3517824521e605e051812eb8f3d713fb76fe2ce71bf4e10ff" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"import numpy as np, pymc as pm, pandas" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# set random seed for reproducibility\n", | |
"np.random.seed(12345)\n", | |
"\n", | |
"# Generate the synthetic data\n", | |
"a = 2.0 \n", | |
"b = 8.0\n", | |
"c = 6.0\n", | |
"d = c + (b-a)\n", | |
"d1 = np.random.uniform(a, b, 100) \n", | |
"d2 = np.random.uniform(c, d, 100) \n", | |
"data = d1 + d2" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"#Setup the priors\n", | |
"\n", | |
"pa = pm.Normal(\"pa\", 0.0, .01, value=0)\n", | |
"pb = pm.Normal(\"pb\", 0.0, .01, value=10)\n", | |
"pc = pm.Normal(\"pc\", 0.0, .01, value=0)\n", | |
"pd = pm.Normal(\"pd\", 0.0, .01, value=10)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 3 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# if data was just d1, this would be more familiar\n", | |
"data = d1\n", | |
"\n", | |
"# here is how you could fit it with an \"observed\" stochastic\n", | |
"@pm.observed\n", | |
"def pdata(value=data, pa=pa, pb=pb, pc=pc, pd=pd):\n", | |
" return pm.uniform_like(value, pa, pb)\n", | |
"\n", | |
"m = pm.MCMC(dict(pa=pa, pb=pb, pc=pc, pd=pd, pdata=pdata))\n", | |
"m.sample(100000, 50000, 50)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [- 3% ] 3159 of 100000 complete in 0.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-- 6% ] 6297 of 100000 complete in 1.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [--- 9% ] 9445 of 100000 complete in 1.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [---- 12% ] 12604 of 100000 complete in 2.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [----- 15% ] 15769 of 100000 complete in 2.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [------- 18% ] 18928 of 100000 complete in 3.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-------- 22% ] 22093 of 100000 complete in 3.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [--------- 25% ] 25258 of 100000 complete in 4.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [---------- 28% ] 28380 of 100000 complete in 4.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [----------- 31% ] 31509 of 100000 complete in 5.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [------------- 34% ] 34652 of 100000 complete in 5.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-------------- 37% ] 37798 of 100000 complete in 6.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [--------------- 40% ] 40943 of 100000 complete in 6.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [---------------- 44% ] 44074 of 100000 complete in 7.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------47% ] 47182 of 100000 complete in 7.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------50% ] 50327 of 100000 complete in 8.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------53% ] 53423 of 100000 complete in 8.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------56%- ] 56516 of 100000 complete in 9.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------59%-- ] 59561 of 100000 complete in 9.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------62%--- ] 62629 of 100000 complete in 10.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------65%---- ] 65704 of 100000 complete in 10.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------68%------ ] 68783 of 100000 complete in 11.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------71%------- ] 71882 of 100000 complete in 11.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------74%-------- ] 74975 of 100000 complete in 12.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------78%--------- ] 78066 of 100000 complete in 12.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------81%---------- ] 81153 of 100000 complete in 13.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------84%------------ ] 84261 of 100000 complete in 13.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------87%------------- ] 87335 of 100000 complete in 14.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------90%-------------- ] 90420 of 100000 complete in 14.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------93%--------------- ] 93452 of 100000 complete in 15.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------96%---------------- ] 96555 of 100000 complete in 15.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------99%----------------- ] 99632 of 100000 complete in 16.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------100%-----------------] 100000 of 100000 complete in 16.1 sec" | |
] | |
} | |
], | |
"prompt_number": 4 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"print pandas.DataFrame(m.stats()).loc[['mean', 'standard deviation']]" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" pa pb pc pd\n", | |
"mean 1.992644 8.023612 -0.3094084 0.09822238\n", | |
"standard deviation 0.0577922 0.05885509 10.1256 10.20577\n" | |
] | |
} | |
], | |
"prompt_number": 5 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"#Setup the priors\n", | |
"\n", | |
"pa = pm.Normal(\"pa\", 0.0, .01, value=0)\n", | |
"pb = pm.Normal(\"pb\", 0.0, .01, value=10)\n", | |
"pc = pm.Normal(\"pc\", 0.0, .01, value=0)\n", | |
"pd = pm.Normal(\"pd\", 0.0, .01, value=10)\n", | |
"\n", | |
"# In general, you can use this pattern with a more complicated\n", | |
"# calculation of the log-likelihood, e.g.\n", | |
"\n", | |
"@pm.observed\n", | |
"def pdata(value=data, pa=pa, pb=pb, pc=pc, pd=pd):\n", | |
" logpr = 0.\n", | |
" # code to calculate logpr from data and parameters\n", | |
" return logpr\n", | |
"\n", | |
"\n", | |
"# Use MCMC to sample and obtain traces\n", | |
"m = pm.MCMC(dict(pa=pa, pb=pb, pc=pc, pd=pd, pdata=pdata))\n", | |
"m.sample(100000, 50000, 50)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [- 3% ] 3943 of 100000 complete in 0.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [--- 7% ] 7910 of 100000 complete in 1.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [---- 11% ] 11870 of 100000 complete in 1.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [------ 15% ] 15811 of 100000 complete in 2.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [------- 19% ] 19724 of 100000 complete in 2.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-------- 23% ] 23673 of 100000 complete in 3.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [---------- 27% ] 27650 of 100000 complete in 3.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [------------ 31% ] 31594 of 100000 complete in 4.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [------------- 35% ] 35525 of 100000 complete in 4.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-------------- 39% ] 39443 of 100000 complete in 5.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [---------------- 43% ] 43413 of 100000 complete in 5.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------47% ] 47386 of 100000 complete in 6.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------51% ] 51372 of 100000 complete in 6.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------55%- ] 55334 of 100000 complete in 7.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------59%-- ] 59276 of 100000 complete in 7.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------63%---- ] 63174 of 100000 complete in 8.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------67%----- ] 67071 of 100000 complete in 8.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------71%------ ] 71019 of 100000 complete in 9.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------74%-------- ] 74963 of 100000 complete in 9.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------78%--------- ] 78895 of 100000 complete in 10.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------82%----------- ] 82802 of 100000 complete in 10.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------86%------------ ] 86732 of 100000 complete in 11.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------90%-------------- ] 90651 of 100000 complete in 11.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------94%--------------- ] 94571 of 100000 complete in 12.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------98%----------------- ] 98495 of 100000 complete in 12.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------100%-----------------] 100000 of 100000 complete in 12.7 sec" | |
] | |
} | |
], | |
"prompt_number": 6 | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"To keep this example simple, let us restrict the model to have $c-a = b-d$. Then $d1_i+d2_i$ has a symmetric triangular distribution, with support $[a+c, b+d]$. In other words,\n", | |
"$$\n", | |
"p(data_i=x|a,b,c,d) \\propto \\begin{cases}\n", | |
"x - (a+c) &\\quad \\text{if } a+c \\leq x \\leq \\frac{a+b+c+d}{2};\\\\\n", | |
"(b+d) - x &\\quad \\text{if } \\frac{a+b+c+d}{2} \\leq x \\leq b+d;\\\\\n", | |
"0 &\\quad {\\text{otherwise.}}\n", | |
"\\end{cases}\n", | |
"$$\n", | |
"\n", | |
"Did I say \"simple\"?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# set random seed for reproducibility\n", | |
"np.random.seed(12345)\n", | |
"\n", | |
"# Generate the synthetic data\n", | |
"a = 2.0 \n", | |
"b = 8.0\n", | |
"c = 6.0\n", | |
"d = c + (b-a)\n", | |
"d1 = np.random.uniform(a, b, 100) \n", | |
"d2 = np.random.uniform(c, d, 100) \n", | |
"data = d1 + d2\n", | |
"\n", | |
"#Setup the priors\n", | |
"\n", | |
"pa = pm.Normal(\"pa\", 0.0, .01, value=0)\n", | |
"pb = pm.Normal(\"pb\", 0.0, .01, value=10)\n", | |
"pc = pm.Uniform(\"pc\", 5, 10, value=5) # <-- changed prior to break symmetry\n", | |
"pd = pm.Normal(\"pd\", 0.0, .01, value=10)\n", | |
"\n", | |
"\n", | |
"@pm.observed\n", | |
"def pdata(value=data, pa=pa, pb=pb, pc=pc, pd=pd):\n", | |
" pd = pc + (pb - pa) # don't use pd value\n", | |
" \n", | |
" # make sure order is acceptible\n", | |
" if pb < pa or pd < pc:\n", | |
" return -np.inf\n", | |
"\n", | |
" x = value\n", | |
" pr = \\\n", | |
" np.where(x < pa+pc, 0,\n", | |
" np.where(x <= (pa+pb+pc+pd)/2, x - (pa+pc),\n", | |
" np.where(x <= (pb+pd), (pb+pd) - x,\n", | |
" 0))) \\\n", | |
" / (.5 * ((pb+pd) - (pa+pc)) * ((pb-pa) + (pd-pc))/2)\n", | |
" return np.sum(np.log(pr))\n", | |
"\n", | |
"m = pm.MCMC(dict(pa=pa, pb=pb, pc=pc, pd=pd, pdata=pdata))\n", | |
"m.use_step_method(pm.AdaptiveMetropolis, [m.pa, m.pb, m.pc])\n", | |
"m.sample(20000, 10000, 10)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [--- 8% ] 1630 of 20000 complete in 0.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [------ 16% ] 3315 of 20000 complete in 1.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [--------- 25% ] 5028 of 20000 complete in 1.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [------------ 33% ] 6739 of 20000 complete in 2.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [---------------- 42% ] 8462 of 20000 complete in 2.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------50% ] 10177 of 20000 complete in 3.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------59%-- ] 11895 of 20000 complete in 3.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------68%----- ] 13622 of 20000 complete in 4.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------76%--------- ] 15343 of 20000 complete in 4.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------85%------------ ] 17060 of 20000 complete in 5.0 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------93%--------------- ] 18784 of 20000 complete in 5.5 sec" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"\r", | |
" [-----------------100%-----------------] 20000 of 20000 complete in 5.9 sec" | |
] | |
} | |
], | |
"prompt_number": 7 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"print pandas.DataFrame(m.stats()).loc[['mean', 'standard deviation']]" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
" pa pb pc pd\n", | |
"mean 1.068921 6.471076 7.74488 -0.4584424\n", | |
"standard deviation 1.477786 1.451357 1.439883 9.85236\n" | |
] | |
} | |
], | |
"prompt_number": 8 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [] | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment