Skip to content

Instantly share code, notes, and snippets.

@junpenglao
Last active February 12, 2018 15:52
Show Gist options
  • Save junpenglao/aa6ea70d55d1c93fc11e45e3bf8666bd to your computer and use it in GitHub Desktop.
Save junpenglao/aa6ea70d55d1c93fc11e45e3bf8666bd to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Runing on PyMC3 v3.3\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import os\n",
"from collections import OrderedDict\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pymc3 as pm\n",
"import pystan\n",
"import pystan.chains\n",
"\n",
"# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and\n",
"# later releases dropped the old names, so use whichever one is installed.\n",
"# The style is cosmetic only: if neither name exists the default is kept.\n",
"darkgrid_styles = [style_name\n",
"                   for style_name in ('seaborn-darkgrid', 'seaborn-v0_8-darkgrid')\n",
"                   if style_name in plt.style.available]\n",
"if darkgrid_styles:\n",
"    plt.style.use(darkgrid_styles[0])\n",
"\n",
"# Record the library version next to the results.\n",
"print('Runing on PyMC3 v{}'.format(pm.__version__))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Locate the two chain CSV files inside the installed pystan package rather\n",
"# than hard-coding one machine's site-packages directory.\n",
"PYSTAN_TEST_DATA = os.path.join(os.path.dirname(pystan.__file__), 'tests', 'data')\n",
"f1 = os.path.join(PYSTAN_TEST_DATA, 'blocker.1.csv')\n",
"f2 = os.path.join(PYSTAN_TEST_DATA, 'blocker.2.csv')\n",
"\n",
"HEADER_LINE = 36      # zero-based index of the line holding the column names\n",
"N_SKIPPED_LINES = 41  # lines np.loadtxt skips before the numeric rows start\n",
"N_LEADING_COLS = 4    # leading columns dropped from both the header and the rows\n",
"\n",
"\n",
"def read_chain_csv(path):\n",
"    \"\"\"Read one chain CSV file.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    path : str\n",
"        Location of the CSV file.\n",
"\n",
"    Returns\n",
"    -------\n",
"    colnames : list of str\n",
"        Header entries after the first ``N_LEADING_COLS`` columns.\n",
"    draws : numpy.ndarray\n",
"        2-D array of the numeric rows, with the same leading columns removed.\n",
"    \"\"\"\n",
"    # The context manager closes the handle; the original left it open.\n",
"    with open(path, 'r') as csv_file:\n",
"        header = csv_file.readlines()[HEADER_LINE]\n",
"    colnames = header.strip().split(',')[N_LEADING_COLS:]\n",
"    draws = np.loadtxt(path, skiprows=N_SKIPPED_LINES, delimiter=',')[:, N_LEADING_COLS:]\n",
"    return colnames, draws\n",
"\n",
"\n",
"c1_colnames, c1 = read_chain_csv(f1)\n",
"np.testing.assert_equal(c1_colnames[0], 'd')\n",
"c2_colnames, c2 = read_chain_csv(f2)\n",
"np.testing.assert_equal(c1_colnames, c2_colnames)\n",
"np.testing.assert_equal(len(c1_colnames), c1.shape[1])\n",
"# n_save below claims n_samples rows for BOTH chains, so check chain 2 as well.\n",
"np.testing.assert_equal(c2.shape, c1.shape)\n",
"\n",
"n_samples = len(c1)\n",
"np.testing.assert_equal(n_samples, 1000)\n",
"\n",
"# Re-key each chain by column name: one 1-D array of draws per column.\n",
"c1 = OrderedDict(zip(c1_colnames, c1.T))\n",
"c2 = OrderedDict(zip(c2_colnames, c2.T))\n",
"\n",
"# Container handed to pystan.chains.ess in a later cell.\n",
"lst = dict(fnames_oi=c1_colnames, samples=[{'chains': c1}, {'chains': c2}],\n",
"           n_save=np.repeat(n_samples, 2), permutation=None,\n",
"           warmup=0, warmup2=[0, 0], chains=2, n_flatnames=len(c1))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"odict_keys(['d', 'sigmasq_delta', 'mu.1', 'mu.2', 'mu.3', 'mu.4', 'mu.5', 'mu.6', 'mu.7', 'mu.8', 'mu.9', 'mu.10', 'mu.11', 'mu.12', 'mu.13', 'mu.14', 'mu.15', 'mu.16', 'mu.17', 'mu.18', 'mu.19', 'mu.20', 'mu.21', 'mu.22', 'delta.1', 'delta.2', 'delta.3', 'delta.4', 'delta.5', 'delta.6', 'delta.7', 'delta.8', 'delta.9', 'delta.10', 'delta.11', 'delta.12', 'delta.13', 'delta.14', 'delta.15', 'delta.16', 'delta.17', 'delta.18', 'delta.19', 'delta.20', 'delta.21', 'delta.22', 'delta_new', 'sigma_delta'])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Draws of the first chain, keyed by column name; show the available keys.\n",
"first_chain = lst['samples'][0]\n",
"slst = first_chain['chains']\n",
"slst.keys()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Iterating the mapping yields its keys in insertion order.\n",
"param_names = list(slst)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Pull the bookkeeping entries out of the `lst` container.\n",
"m = lst['chains']\n",
"ns_save, ns_warmup2 = lst['n_save'], lst['warmup2']\n",
"\n",
"# Per chain: the 'n_save' entry minus the matching 'warmup2' entry.\n",
"ns_kept = [saved - warm for saved, warm in zip(ns_save, ns_warmup2)]\n",
"\n",
"# The smallest per-chain count is the one used from here on.\n",
"n_samples = min(ns_kept)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Target effective sample sizes (the 'Target' column of the comparison table).\n",
"# Position i is compared against pystan.chains.ess(lst, i) below.\n",
"n_eff = [\n",
"    466.099, 136.953, 1170.390, 541.256,\n",
"    518.051, 589.244, 764.813, 688.294,\n",
"    323.777, 502.892, 353.823, 588.142,\n",
"    654.336, 480.914, 176.978, 182.649,\n",
"    642.389, 470.949, 561.947, 581.187,\n",
"    446.389, 397.641, 338.511, 678.772,\n",
"    1442.250, 837.956, 869.865, 951.124,\n",
"    619.336, 875.805, 233.260, 786.568,\n",
"    910.144, 231.582, 907.666, 747.347,\n",
"    720.660, 195.195, 944.547, 767.271,\n",
"    723.665, 1077.030, 470.903, 954.924,\n",
"    497.338, 583.539, 697.204, 98.421\n",
"]\n",
"\n",
"# Recompute every entry with PyStan and require agreement to 2 decimal places.\n",
"ess = []\n",
"for i, n_eff_target in enumerate(n_eff):\n",
"    ess.append(pystan.chains.ess(lst, i))\n",
"    np.testing.assert_almost_equal(ess[i], n_eff_target, 2)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Same quantity via PyMC3: for each parameter, collect its draws from the\n",
"# first m chains and pass the list to pm.effective_n.\n",
"chain_samples = lst['samples']\n",
"ess2 = []\n",
"for varname in param_names:\n",
"    trace_values = [chain_samples[chain_idx]['chains'][varname]\n",
"                    for chain_idx in range(m)]\n",
"    ess2.append(pm.effective_n(trace_values))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Target</th>\n",
" <th>PyStan</th>\n",
" <th>PyMC3</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>466.099</td>\n",
" <td>466.098810</td>\n",
" <td>465.994306</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>136.953</td>\n",
" <td>136.952532</td>\n",
" <td>136.941283</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1170.390</td>\n",
" <td>1170.393732</td>\n",
" <td>1170.222485</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>541.256</td>\n",
" <td>541.255659</td>\n",
" <td>541.199238</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>518.051</td>\n",
" <td>518.051325</td>\n",
" <td>517.996363</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>589.244</td>\n",
" <td>589.243546</td>\n",
" <td>589.163026</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>764.813</td>\n",
" <td>764.812721</td>\n",
" <td>764.667763</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>688.294</td>\n",
" <td>688.293542</td>\n",
" <td>688.212347</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>323.777</td>\n",
" <td>323.777181</td>\n",
" <td>323.743779</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>502.892</td>\n",
" <td>502.891905</td>\n",
" <td>502.837470</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>353.823</td>\n",
" <td>353.822729</td>\n",
" <td>353.793908</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>588.142</td>\n",
" <td>588.141518</td>\n",
" <td>588.045040</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>654.336</td>\n",
" <td>654.335742</td>\n",
" <td>654.216984</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>480.914</td>\n",
" <td>480.914331</td>\n",
" <td>480.857175</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>176.978</td>\n",
" <td>176.977626</td>\n",
" <td>176.966982</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>182.649</td>\n",
" <td>182.648542</td>\n",
" <td>182.634739</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>642.389</td>\n",
" <td>642.389260</td>\n",
" <td>642.278555</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>470.949</td>\n",
" <td>470.949244</td>\n",
" <td>470.894658</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>561.947</td>\n",
" <td>561.946880</td>\n",
" <td>561.868481</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>581.187</td>\n",
" <td>581.186630</td>\n",
" <td>581.129337</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>446.389</td>\n",
" <td>446.389238</td>\n",
" <td>446.333580</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>397.641</td>\n",
" <td>397.641095</td>\n",
" <td>397.597217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>338.511</td>\n",
" <td>338.510536</td>\n",
" <td>338.485896</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>678.772</td>\n",
" <td>678.771523</td>\n",
" <td>678.661293</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>1442.250</td>\n",
" <td>1442.252440</td>\n",
" <td>1442.042092</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>837.956</td>\n",
" <td>837.955698</td>\n",
" <td>837.849683</td>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>869.865</td>\n",
" <td>869.865497</td>\n",
" <td>869.710633</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>951.124</td>\n",
" <td>951.123727</td>\n",
" <td>950.933339</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>619.336</td>\n",
" <td>619.336051</td>\n",
" <td>619.243131</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <td>875.805</td>\n",
" <td>875.804964</td>\n",
" <td>875.695420</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <td>233.260</td>\n",
" <td>233.260458</td>\n",
" <td>233.239514</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31</th>\n",
" <td>786.568</td>\n",
" <td>786.567760</td>\n",
" <td>786.463445</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>910.144</td>\n",
" <td>910.143753</td>\n",
" <td>909.931239</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33</th>\n",
" <td>231.582</td>\n",
" <td>231.581848</td>\n",
" <td>231.564143</td>\n",
" </tr>\n",
" <tr>\n",
" <th>34</th>\n",
" <td>907.666</td>\n",
" <td>907.665813</td>\n",
" <td>907.518937</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35</th>\n",
" <td>747.347</td>\n",
" <td>747.346732</td>\n",
" <td>747.212315</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36</th>\n",
" <td>720.660</td>\n",
" <td>720.660283</td>\n",
" <td>720.555130</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37</th>\n",
" <td>195.195</td>\n",
" <td>195.195121</td>\n",
" <td>195.176669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38</th>\n",
" <td>944.547</td>\n",
" <td>944.546821</td>\n",
" <td>944.371209</td>\n",
" </tr>\n",
" <tr>\n",
" <th>39</th>\n",
" <td>767.271</td>\n",
" <td>767.270590</td>\n",
" <td>767.143768</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40</th>\n",
" <td>723.665</td>\n",
" <td>723.664570</td>\n",
" <td>723.559907</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41</th>\n",
" <td>1077.030</td>\n",
" <td>1077.025450</td>\n",
" <td>1076.857749</td>\n",
" </tr>\n",
" <tr>\n",
" <th>42</th>\n",
" <td>470.903</td>\n",
" <td>470.902902</td>\n",
" <td>470.832818</td>\n",
" </tr>\n",
" <tr>\n",
" <th>43</th>\n",
" <td>954.924</td>\n",
" <td>954.924046</td>\n",
" <td>954.748152</td>\n",
" </tr>\n",
" <tr>\n",
" <th>44</th>\n",
" <td>497.338</td>\n",
" <td>497.338153</td>\n",
" <td>497.289064</td>\n",
" </tr>\n",
" <tr>\n",
" <th>45</th>\n",
" <td>583.539</td>\n",
" <td>583.538694</td>\n",
" <td>583.446292</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>697.204</td>\n",
" <td>697.204057</td>\n",
" <td>697.077568</td>\n",
" </tr>\n",
" <tr>\n",
" <th>47</th>\n",
" <td>98.421</td>\n",
" <td>98.421158</td>\n",
" <td>98.414273</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Target PyStan PyMC3\n",
"0 466.099 466.098810 465.994306\n",
"1 136.953 136.952532 136.941283\n",
"2 1170.390 1170.393732 1170.222485\n",
"3 541.256 541.255659 541.199238\n",
"4 518.051 518.051325 517.996363\n",
"5 589.244 589.243546 589.163026\n",
"6 764.813 764.812721 764.667763\n",
"7 688.294 688.293542 688.212347\n",
"8 323.777 323.777181 323.743779\n",
"9 502.892 502.891905 502.837470\n",
"10 353.823 353.822729 353.793908\n",
"11 588.142 588.141518 588.045040\n",
"12 654.336 654.335742 654.216984\n",
"13 480.914 480.914331 480.857175\n",
"14 176.978 176.977626 176.966982\n",
"15 182.649 182.648542 182.634739\n",
"16 642.389 642.389260 642.278555\n",
"17 470.949 470.949244 470.894658\n",
"18 561.947 561.946880 561.868481\n",
"19 581.187 581.186630 581.129337\n",
"20 446.389 446.389238 446.333580\n",
"21 397.641 397.641095 397.597217\n",
"22 338.511 338.510536 338.485896\n",
"23 678.772 678.771523 678.661293\n",
"24 1442.250 1442.252440 1442.042092\n",
"25 837.956 837.955698 837.849683\n",
"26 869.865 869.865497 869.710633\n",
"27 951.124 951.123727 950.933339\n",
"28 619.336 619.336051 619.243131\n",
"29 875.805 875.804964 875.695420\n",
"30 233.260 233.260458 233.239514\n",
"31 786.568 786.567760 786.463445\n",
"32 910.144 910.143753 909.931239\n",
"33 231.582 231.581848 231.564143\n",
"34 907.666 907.665813 907.518937\n",
"35 747.347 747.346732 747.212315\n",
"36 720.660 720.660283 720.555130\n",
"37 195.195 195.195121 195.176669\n",
"38 944.547 944.546821 944.371209\n",
"39 767.271 767.270590 767.143768\n",
"40 723.665 723.664570 723.559907\n",
"41 1077.030 1077.025450 1076.857749\n",
"42 470.903 470.902902 470.832818\n",
"43 954.924 954.924046 954.748152\n",
"44 497.338 497.338153 497.289064\n",
"45 583.539 583.538694 583.446292\n",
"46 697.204 697.204057 697.077568\n",
"47 98.421 98.421158 98.414273"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Side-by-side table: target values, PyStan results, PyMC3 results.\n",
"neff_columns = ['Target', 'PyStan', 'PyMC3']\n",
"neff_values = {'Target': n_eff, 'PyStan': ess, 'PyMC3': ess2}\n",
"df_neff = pd.DataFrame(data=neff_values, columns=neff_columns)\n",
"df_neff"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment