Skip to content

Instantly share code, notes, and snippets.

@AustinRochford
Last active February 9, 2017 14:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AustinRochford/256d044ae8cf9f67d1fc92897fafc48d to your computer and use it in GitHub Desktop.
Save AustinRochford/256d044ae8cf9f67d1fc92897fafc48d to your computer and use it in GitHub Desktop.
Random Effects ADVI Two Observables
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Draw matplotlib figures directly below the cells that create them.\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"import pymc3 as pm\n",
"import scipy as sp\n",
"# `import scipy` is not guaranteed to load the `scipy.special` subpackage\n",
"# (older SciPy releases do not), so import it explicitly; `sp.special.expit`\n",
"# is used below and must not rely on another library importing it first.\n",
"import scipy.special\n",
"import seaborn as sns\n",
"from theano import tensor as tt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Fix numpy's global RNG so the simulated data below is reproducible.\n",
"SEED = 1234567890\n",
"np.random.seed(SEED)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"N = 100  # number of simulated observations\n",
"\n",
"# Per-observation random intercept ~ Normal(RE_MEAN, RE_SCALE)\n",
"RE_MEAN = 0.0\n",
"RE_SCALE = 0.5\n",
"\n",
"# True slope on x inside the logit of p\n",
"BETA = 1.0"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Simulate: one random intercept per observation and a uniform covariate,\n",
"# then two observables (y and z) that both depend on the same p.\n",
"random_effect = np.random.normal(RE_MEAN, RE_SCALE, size=N)\n",
"x = np.random.uniform(-5, 5, size=N)\n",
"\n",
"logit_p = random_effect + BETA * x\n",
"p = sp.special.expit(logit_p)\n",
"\n",
"y = np.random.binomial(1, p, size=N)  # binary observable\n",
"z = 8 * p + np.random.normal(size=N)  # Gaussian observable: slope 8 on p, unit-variance noise"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.5/site-packages/matplotlib/font_manager.py:1297: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans\n",
" (prop.get_family(), self.defaultFamily[fontext]))\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAFlCAYAAADYnoD9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3X94VOWd9/FPMmMsMSQFl0x+DREr7bZC/XE9Unk02EYC\nG1IXBXyWxXUv2/LEcrVbrYvEQh6YYoLFRWuea3cp2RTsatm2uq1cJaaCWEtaLN3Hrhe1em1LJTIh\nyfij0RCGipPk+cOdkB9nfmXOzDlz5v36q8k5k7nn28g3933u+/vNGRkZGREAALBMrtUDAAAg25GM\nAQCwGMkYAACLkYwBALAYyRgAAIuRjAEAsJjbqjd+883TVr21qWbMyFd/f9DqYTgKMTUfMTUfMTWf\n02M6a9b0iNeYGSfJ7XZZPQTHIabmI6bmI6bmy+aYkowBALAYyRgAAIuRjAEAsBjJGAAAi5GMAQCw\nGMkYAACLkYwBALAYyRgAAIuRjAEAsBjJGAAAi5GMAQCwWMxk/LWvfU0LFy7UZz/7WcPrIyMjampq\nUk1NjW666Sb99re/NX2QAAA4WcxkvGLFCrW1tUW8fvjwYXV1denAgQO6//775fP5zBwfAACmCAaD\nevXV3+rVV19RMGiv7lAxk/E111yjoqKiiNcPHTqkm2++WTk5Obryyis1MDCgN954w9RBAgAwVaFQ\nSBs23K1PfOJS3XDDQt1ww7WaN+8ybdx4r0KhkKQPEvWJE6+NJumJX6da0v2MA4GASkpKRr8uKSlR\nIBBQcXFxsj8aAJDBgsGgAoE+eTwlys/PT9trxxoYGFBtbbV+//vfjfv+4OCg2tp2SZJyc3PV0dGu\nU6e6VV5erqKiD+udd95RT88plZdXqLa2Tj5fs9zupFNmREn/5JGRkUnfy8nJifm6GTPyHdO7MlrD\naEwNMTUfMTUfMTUWCoW0fv167du3TydPntTs2bO1fPly7dixI2ZCmzFj2pRfGwwG1dvbq9LSUuXl\n5Wn9+vX69re/rcHBwYiv+f739+r06dOjX/v9fvn9/jFfn1Rr605Nm5anRx55JM4IJC7pZFxSUqK+\nvr7Rr/v6+uKaFff322u9fqpmzZquN988HftGxI2Ymo+Yms/uMQ3PLAsLCzUwMJD0DDMRjY0Nam3d\nOfp1V1eXWlpadPbsOTU1bY/4ulmzputLX/pKwq8NhULy+TaNmd1WqKioSC+//JuYYx2biKP54Q+f\n0le/+rWkYhjtj7ekjzZVV1frqaee0sjIiF566SVNnz6dJWoAsEgoFFJjY4Oqqq7Rpz51pebNm6tP\nfepKXX/9Nbrvvnt0/PjvU/ocNBgMqqOj3fBaR8fTk9577LPZRF8b5vNtUmvrTvn9JzU8PCy//2Rc\niTgRPT3dCgT6Yt84RTFnxvfcc49+9atfqb+/X4sWLdLf/d3fjT7w/uu//mvdcMMN+tnPfqaamhpN\nmzZN27ZtS9lgAQDRhRNT2NDQkCSpu9uv3bvbtHt3m7ze2aY9B534bDcQ6NOpU92G94YT2pw5l2pg\nYECbNm3QL37ROfps9sYbq9Xd7Y/52onvHymBx6OgoCDqMnZYaWmZPJ6SmPdNVcz/Fx5++OGo13Ny\ncrRlyxbTBgQAmJp4E1P4OaikqMvG0RgtDdfW1mnDhk0qL6+Q339y0mvKyip08cV/psbGBu3d+9i4\nJOj3n9Sjjz4aMTmWlVUYJsNoyT8at9utO+74giSNbuSK5rrrqlK6zJ+6rWEAgLRKNDF1dDytjRu3\nTCnJTJyBj03wtbV1466F1dYu04MPNhteO894A3Bt7TLDcXo8JRGTfzTPPPO85s//pEKhkHJzc/X0\n0/sjzsoLCqarufnBhH5+oiiHCQApkuhZ1WTPtoYTU7ym+hw01rPdDRs2qb5+nbzeSrlcLnm9laqv\nX6cNGzbFnLmfPRvU6tVrJr3W52s2vD8/P1+1
tXWG19zuCwy/7/VW6iMfuey/73GrqWm7fv7z/9Dq\n1WsM71+z5m9UWFgYddzJYmYMACaLtIQb6RltovdHEk5M0Wee55WUTO05aKznwm+//ZaamrZr48Yt\n454nnzjxWsyZe1lZhb7xjYdH3yeeXeDhRN3R8bR6erpVVlah2tplGh4eNlyCNppl5+fn6+GH/1GF\nhUWTfk6kPwTMlDNidFA4Dex8JCARdj/ekImIqfmIqfmixXTi0Z6w+vp1hs9oE70/mnBib2//cczE\nV1BQoDVrbk846QeDQVVVLTBcGvZ6K9XZedQwgUZ7XdhUPvPYnz82gZ//I2dyco32ec0qODJRSo82\nAQDOm8rRnqkc54kkvOy6d++Tys2N/k/84OCgWlt3yufblNB7RFsajvRsN9brCgqm66677kpqFpqf\nn685cy4dff9wLDo7j+rIkRfV2XlUTU3bY/7hMfHnpAPJGABMFM/RnmTuj1dl5SVxPz+eStL3+ZoN\nnwvHSqYTX1de7tXq1Wv00kuv6pFHHklJyUkrkmuiWKZOEst/5iOm5iOm5osU00SXcKe65BuPSMvf\nE7lcLh058uKkM7zxmOqSrtHrnP57yjI1AKRJoku4U13yjcfYWWhubq5cLuN+AJHO8MZjqrPOTJit\nppPLZ1ED4mDwnBVva7qLLrrQMZ/FLoip+Yip+aLFdNGiz+j06QG98cabOnNmUBUVs7V69Rr5fM2G\nz3ETvT9eubm5qq6u0e2336HVq2/T+++H9NJLv5503+rVa7R06bIpv49ZnP57etFFF0a8xjJ1kpy+\nrGIFYmo+Ymq+eGI6cSk21pJuqnbxhk11d3G6xuj039Noy9Qk4yQ5/ZfHCsTUfMTUfInE1KxzxGaZ\nSkJNx2dw+u9ptGRM0Q8AiMKMmWC00pFTPVObjPDz2kTY7TM4DRu4AMDA+VaEC7Rw4dWqqlqgxsaG\n0a518TL7HLEVnPAZ7I5kDAAGjHrkTqVARqrOEaeTEz6D3ZGMAUDmNLk3Eq15Q6QjRck2jDDbVD4D\nEkMyBpDVjJaj77vvHtNmgomcIzZradxsqTwLjQ+wgQtAVjPamPS97+1NuMl99Pcw7io0sXSknTdJ\nxfsZMDUcbUqS07fiW4GYmo+YGotWirKgYLoGByfHLNxV6KKLXHr55d8ntMs62s7sVJbFNBPnjKeO\ncpgAYCDaxqRg8Iz+6q8mN7lvbPy6GhsbdPnllye8lBytBGS0sZw65dfrr59I7MOlCGUsU4NlagBZ\nK7wxyWg2Wl7u1fbtk5vcT2y+YLSUPJXZY7SxDA8P67bb/peWLfusZYVCkFrMjAFkrXg2Jo2dCcba\nZT0wMDDlDVjRxiJJ3d3+KR2tQmagUUSSnF7Y3ArE1HzENLJEmjR0d/v1zW/+g4y22pw5M6hAoFff\n+c4eDQy8q5GREQ0MvKsXX/x/On16QNXVNXGPpa8voNOnBwzveeONN3X77XfoggsumNoHtjGn/57S\nKCKFnL7hwArE1HzENLZ4lpajbbKqqPBqZGTE8LlvvBuwwmN45513VFtbreHh4Un3JNN72O6c/nvK\nBi4AiCGejUnRlpKvu65Kvb09htdinU2eeL74C1+4PeI4KLLhTOwCAIAEhM/VHjjQIb/fP3redsOG\nTTpy5OeGs+ZYCXTi+eLubn/Eeymy4UwkYwBIgNvtVlPTdn3zmzsmnTOura0bl1TDoiXQaJvCCgqm\nq6jow+rr66HIhsORjAFkvFQWoojEqA3hVKpURTtffPZsUO3tB/WhD30orZ8N6UcyBpCx0tHwPhHh\nWfPGjVvi/uMg2vnisrIKVVZeQhLOAmzgApCxzGpzaLZEqlTRhAESyRhAhnJSw3ufr1n19esmld7k\n+XD2YJkaQEaKp+F9ppzFncryNpyFmTGAjOTEhvc0YcheJGMAGYlnrXASlqkBZCwa3sMpqE2dJKfX\nUrUCMTWf
02NqxTljp8fUCk6PabTa1MyMAWQ8owIcQCbhmTEAABYjGQMAYDGSMQAAFiMZA8gYwWBQ\nJ068llHVtYB4kIwB2F4oFFJjY4OqqhZo4cKrVVW1QI2NDQqFQlYPDTAFu6kB2F64IURYuCGEJDU1\nbbdqWIBpmBkDsDU7N4Rg2RxmIRkDsLV4GkKkWygU0t13382yOUzDMjUAWws3hPD7T066ZlVDCJbN\nYTZmxgBszW4NIey8bI7MxcwYgO3ZqSGEk/oowz5IxgBsz+12q6lpuzZu3JL2hhAT2XHZfCIrGmcg\nOSxTA8gY4YYQViYYuy2bj8V57MzFzBhAxjKaAaZjVujzNWvatDz98IdPWb5sPn5cbCzLVHH1Mz58\n+LCam5s1PDysW2+9VfX19eOu9/T0qKGhQadPn9bQ0JDWr1+vG264IerPdErPSqf337QCMTWf02Ia\nCoXk821SR0e7Tp3qVnl5hZYurZUkPfNMx+j3amvr5PM1y+02f94xa9Z0vf56wDbLwcFgUFVVCwyX\nz73eSnV2HrV8jLE47fd0oqT6GQ8NDWnr1q3as2ePPB6PVq1aperqal122WWj9+zcuVO1tbVas2aN\njh8/rvr6ej333HPmjB4AJjCaAba17Rp3TzpmhXbqo8zGsswW85nxsWPHVFlZKa/Xq7y8PNXV1enQ\noUPj7snJydHg4KAk6fTp0youLk7NaAFkvWhHi4xky3Gj8MYyI3bZWIbIYs6MA4GASkrO/5/o8Xh0\n7Nixcfd8+ctf1he+8AU9/vjjOnv2rPbs2RPzjWfMyJfb7ZrCkO0n2tIDpoaYms8pMf3DH96IOAM0\n0tPTrVBoULNmeUwfi71iOl0rVtyilpaWSVdWrLhZlZXmf/5UsFdM0ydmMjZ6pJyTkzPu6/b2dt1y\nyy36/Oc/r//8z//Uhg0btH//fuXmRp549/c74y9Vpz/jsAIxNZ+TYup2F0Q8WmSkrKxCbneB6Z/f\njjFtaNiis2fPTTqP3dCwxXZjNWLHmJop2h8aMZepS0pK1Nd3vvZrIBCYtAz95JNPqrb2g80TV111\nld577z319/dPdbwAEFG0o0VGrD5ulE7h89idnUd15MiL6uw8qqam7SnZwAZzxUzG8+fPV1dXl/x+\nv86dO6f29nZVV1ePu6e0tFQvvPCCJOkPf/iD3nvvPc2cOTM1IwaQ9Xy+ZtXXr5PXWymXyyWvt1Jr\n196ptWvvHPe9+vp1lh83soIdzmMjMXEdbfrZz36mbdu2aWhoSCtXrtS6devU0tKiefPm6cYbb9Tx\n48fV2NioYDConJwc3Xvvvbr++uuj/kynLEU4fVnFCsTUfE6NqVXnjCXnxtRKTo9ptGXquJJxKjgl\n4E7/5bECMTUfMTUfMTWf02Oa1DNjAACQWiRjAAAsRjIGAMBiJGMAACxGMgYAwGIkYwAALEYyBuBo\nwWBQJ068lhXNIpC5SMYAHCkUCqmxsUFVVQu0cOHVqqpaoMbGBoVCIauHBkxCwVIAjmTU8zjV/Y2B\nqWJmDMBxovU8zpb+xsgsJGMAjhMI9EXsedzT061AoM/wGmAVkjHgQFZsWrLTRimPp0Tl5RWG18rK\nKuTxlKR5REB0JGPAQazYtGT0nvfdd4+OH/+9ZYk5Ws/jbOpvjMxB16YkOb3LiBWI6dQ1NjaM27QU\ndtddd2nTpvvT+p6S5PXOVm1tnXy+5rQ3uA+FQvL5Nqmj42n19HSrrKxCtbXLTBsLv6fmc3pMaaGY\nQk7/5bECMZ2aYDCoqqoF8vtPTrp2ySWX6Pnnf2n6jDDae45VX7/Osh3MqepvzO+p+ZweU1ooAlkg\n2qYlv9+fkk1L0d5zLCt3MOfn52vOnEtZmoatkYwBh4i2acnr9aZk01K09xyLHcxAdCRjwCGibVpa\nvnx5SmaG0d5zLHYwA9FRgQtwEJ+vWZImbVrasWOH+vvPpvw9/f7XDe9hBz
MQHRu4kuT0DQdWIKbJ\nm7hpKR0xDQaDOnWqW21tu/TsswdSsoPZTvg9NZ/TYxptA5ez/usAIOn8pqV0v+fcuR/V9u0PpWwH\nM+BUPDMGYDo77mC2U4UwYCKSMQBHo5UiMgHL1AAcjVaKyATMjAE4Fq0UkSlIxgAci1aKyBQkYwCO\nRStFZAqSMQDHopUiMgUbuAA4WqSqZOHvA3ZABa4kOb1ijBWIqfmIqfmtFImp+ZweUypwAch6VlQl\nA+LFM2MAACxGMgYAwGIkYwAALEYyBgDAYiRjAAAsRjIGAMBiJGMAaUVfYWAykjGAtKCvMBAZRT8A\npAV9hYHImBkDSDn6CgPRkYwBpBx9hYHoSMYAUo6+wkB0JGMAKUdfYSA6NnABSAv6CgOR0c84SU7v\nv2kFYmo+O8U00b7CZvchNoudYuoUTo9ptH7GLFMDSKtwX+FYiZVzycgmLFMDsCXOJSObMDMGYDuc\nS0a2iSsZHz58WEuXLlVNTY1aW1sN73n66ae1bNky1dXV6e///u9NHSSA7MK5ZGSbmMvUQ0ND2rp1\nq/bs2SOPx6NVq1apurpal1122eg9XV1dam1t1b/927+pqKhIb7/9dkoHDcDZwueS/f6Tk65xLhlO\nFHNmfOzYMVVWVsrr9SovL091dXU6dOjQuHt+8IMf6LbbblNRUZEk6eKLL07NaAFkBc4lI9vEnBkH\nAgGVlJz/K9Tj8ejYsWPj7unq6pIkrV69WsPDw/ryl7+sRYsWmTtSAFmFc8nIJjGTsdEx5JycnHFf\nDw0N6fXXX9djjz2mvr4+3Xbbbdq/f78KCwsj/twZM/LldrumMGT7iXZ2DFNDTM2XiTHdteufFQwG\n1dvbq9LSUtvNiDMxpnaXrTGNmYxLSkrU13d+s0QgEFBxcfG4ezwej6688kpdcMEF8nq9mjNnjrq6\nuvTJT34y4s/t73fGbkinH1K3AjE1X6bHtLCwWGfODOnMGft8hkyPqR05PaZJFf2YP3++urq65Pf7\nde7cObW3t6u6unrcPYsXL9bRo0clSX/84x/V1dUlr9eb5LABAMgOMWfGbrdbmzdv1tq1azU0NKSV\nK1dq7ty5amlp0bx583TjjTeqqqpKv/jFL7Rs2TK5XC5t2LBBM2bMSMf4AQDIeNSmTpLTl1WsQEzN\nR0zNR0zN5/SYUpsaAAAbIxkDAGAxkjEAABYjGQMAYDGSMQAAFiMZAwBgMZIxAAAWIxkDAGAxkjEA\nABYjGQMAYDGSMQAAFiMZAwBgMZIxkGWCwaBOnHhNwaAzeooDTkAyBrJEKBRSY2ODqqoWaOHCq1VV\ntUCNjQ0KhUJWDw3IejH7GQNwBp9vk1pbd45+7fefHP26qWm7VcMCIGbGQFYIBoPq6Gg3vNbR8TRL\n1oDFSMZAFujt7dWpU92G13p6uhUI9KV5RADGIhkDWaC0tFTl5RWG18rKKuTxlKR5RADGIhkDWSA/\nP1+1tXWG12prlyk/Pz/NIwIwFhu4gCzh8zVL+uAZcU9Pt8rKKlRbu2z0+wCskzMyMjJixRu/+eZp\nK97WdLNmTXfMZ7ELYmq+sTENBoMKBPrk8ZQwI04Cv6fmc3pMZ82aHvEaM2Mgy+Tn52vOnEutHgaA\nMXhmDACAxUjGAABYjGQMZBlqUwP2QzIGsgS1qQH7YgMXYHNm7X6mNjVgX8yMAZsycyZLbWrA3pgZ\nAzZl5kw2ntrUHHcCrMPMGLAhs2ey1KYG7I1kDNhQINBnapclalMD9sYyNWBDHk+Jyssr5PefnHRt\nqjNZalMD9kUyBmwoPJMd+8w4bKozWbfbraam7dq4cQu1qQGbIRkDNpWqmSy1qQH7oWtTkpzeZcQK\nxHQ8M84ZE1PzEVPzOT2mdG0CMhgzWcD52E0NAIDFSMYAAFiMZAxkCLotAc5FMgZsjm5LgPOxgQuw\nObotAc7HzBiwMbotAdmBZAzYmNk1qg
HYE8kYsLFwjWojdFsCnINkDNgY3ZaA7MAGLsDm6LYEOB+1\nqZPk9FqqViCmxpKpUU1MzUdMzef0mFKbGnAAalQDzsUzYwAALEYyBgDAYnEl48OHD2vp0qWqqalR\na2trxPt+8pOf6GMf+5h+85vfmDZAAACcLmYyHhoa0tatW9XW1qb29nbt379fx48fn3Tf4OCgHnvs\nMV1xxRUpGSgAAE4VMxkfO3ZMlZWV8nq9ysvLU11dnQ4dOjTpvpaWFq1du1YXXnhhSgYKAIBTxdxN\nHQgEVFJyvsqPx+PRsWPHxt3zyiuvqK+vT5/5zGe0e/fuuN54xox8ud2uBIdrT9G2q2NqiKn5iKn5\niKn5sjWmMZOx0THknJyc0f89PDysBx54QA888EBCb9zf74wC904/F2cFJ8c0mbPCyXByTK1CTM3n\n9JhG+0Mj5jJ1SUmJ+vrOF6MPBAIqLi4e/frMmTP63e9+p7/9279VdXW1XnrpJa1bt45NXMAY9CQG\nEE3MmfH8+fPV1dUlv98vj8ej9vZ2PfTQQ6PXp0+frqNHj45+ffvtt2vDhg2aP39+akYMZCB6EgOI\nJubM2O12a/PmzVq7dq2WLVum2tpazZ07Vy0tLYYbuQCMR09iALFQmzpJTn/GYQWnxfTEide0cOHV\nGh4ennTN5XLpyJEXU17m0mkxtQNiaj6nxzSpZ8YAkkNPYgCxkIyBFIunJ3EwGNSJE6+xZA1kKbo2\nAWkQqSdxY+PX1djYoI6Odp061a3y8grV1tbJ52uW281/nkC24Jlxkpz+jMMKTo7pxHPGjY0N43ZZ\nh9XXrzN1l7WTY2oVYmo+p8eUZ8aADaV6lzVL30DmYB0MSINQKCSfb9O45ejrrrtep051G97f09Ot\nQKBvSrusjd5rxYpb1NCwhaVvwKb4LxNIA6OiH9/73l4VFBRocHBw0v1jd1knWkLT6L1aWlp09uw5\nCowANsUyNZBi0Zaj//Sn9wy/X1u7THl5eQmX0KTACJCZmBkDKRYI9EVcjg6F3p/0vXnz5svna55S\nCc1o75XM0jeA1GJmDKRYtKIfRt59d0DvvvvOlGa4FBgBMhPJGEixaEU/jPT0dOuVV34bc4ab6HuF\nC4wAsB+WqYE0mFj0o7S0TO+80x9x89YnPnG5yssr5PefNLwebYZrVGBkxYqb1dCwxaRPA8BszIyB\nNHC73Wpq2q7OzqM6cuRF/fzn/6E1a243vLe2dpkuvvjPIs5wFy9eEnWGO/G9OjuP6pFHHuFYE2Bj\nLp/P57PijYPBc1a8rekuuuhCx3wWu3ByTC+44ALNmDFDF1xwgRYt+oxOnx7QG2+8qTNnBlVRMVur\nV6+Rz9es3NzcMdff0MDAgFwul0ZGRvTWW2/K7z+pRYs+o9zcyH9Pj30vJ8fUKsTUfE6P6UUXXRjx\nGuUwk+T08m1WyLaYxjpHfN9992j37rZJ30+kZGa2xTQdiKn5nB5TymECNpafn685cy41TMTBYFAH\nDx4wfB3nhgHnIBkDNhbPuWEAmY9kDNhYtHPDxcUeFRYWpnlEAFKBZAzYWLRzw729PVqy5NMxS2QC\nsD/OOgA2N/bcsN//+rhr8ZTIBGB/zIwBmwufGz5w4KcqLS0zvIfNXEBmIxkDGWJgYCDihi02cwGZ\njWQMZAiaQADORTIGMgRNIADnYgMXkEGMmkDU1i4b/T6AzEQ5zCQ5vXybFYhpbLFKaE5ETM1HTM3n\n9JhGK4fJzBjIQOESmgCcgWfGAABYjGSMrBIMBnXixGucyQVgKyRjZIVQKKTGxgZVVS3QwoVXq6pq\nQcQykiRsAOlGMkZW8Pk2qbV1p/z+kxoeHh4tI+nzbRq9J5GEDQBmIhnD8YLBoDo62g2vjS0jGU/C\nBoBUIBnD8eLpCRxvwgaAVCAZw/HiKSMZT8IGgFQhGcPx4ikj6fGUqKys3PAe6j4DSDWKfiArRCsj\nGQ
qFtG3b1/XOO/2Gr6XuM4BUoxxmkpxevs0KqYxpMBjU6693SRpRZeUc5efnq7GxQa2tOyfdW1Aw\nXWvW/I18vma53Zn9dyu/p+YjpuZzekwphwlIozPgjo52nTrVrfLyCi1evFQHD/7E8P6iog9r48Yt\nGZ+IAdgf/8oga4SPLoX5/Se1Z8+/RLy/r69HgUAfNaABpBwbuJAVoh1dcrlcht9n4xaAdCEZIytE\nO7o0NDRs+H02bgFIF5IxskK0s8YFBRfpjjs+L6+3Ui6XS15vperr143uwE4Gda4BxINnxsgK4bPG\nRrumBwcHlZd3oTo7jyoQ6JPHU5L0jDgUCsnn2zRus1htbV1KdmYHg0HTxg3AGsyMkTU2bNikgoIC\nw2sdHU9LkubMudSUhJaOOtc0tgCcg2SMrPH2229FXC42s+Rluupc09gCcA6SMbJGPDWqzZCOOtc0\ntgCchWSMrBGrRrUkUzZbpSPp09gCcBaSMbKKz9es+vp143ZOr117p4aHh0179hpPY4pkpWuWDyA9\nSMbIKm63W01N29XZeVRHjryozs6jys3NVVvbLlOfvRolfbOOS0npSfgA0ieuRhGHDx9Wc3OzhoeH\ndeutt6q+vn7c9T179uiJJ56Qy+XSzJkztW3bNpWXG7ejC3NKMXCnFza3QjpiGj4OVFhYqCVLPi2/\n/+Ske7zeSnV2Hk0qsaXy2NH541OTO1FNPD7F76n5iKn5nB7TaI0iYibjoaEhLV26VHv27JHH49Gq\nVav08MMP67LLLhu955e//KWuuOIKTZs2TXv37tWvfvUrPfLII1EH5ZSAO/2XxwqpjOnE878eT4l6\ne3sM73W5XDpy5EXb16aOJ+Hze2o+Ymo+p8c0WjKOuUx97NgxVVZWyuv1Ki8vT3V1dTp06NC4e669\n9lpNmzZNknTllVeqr4/NI7CniceBIiViKXOevebn55t2PhqANWIm40AgoJKS8/8geTweBQKBiPc/\n+eSTWrRokTmjgyPYpSRktONARnj2CiBdYtblM1rFzsnJMbx33759evnll/X444/HfOMZM/Lldht3\ny8k00ZYeslkoFNL69eu1b98+nTx5UrNnz9by5cu1Y8eOmCUhUxHTP/zhjYjHgXJyclRWVqa+vj55\nvd64x5lJ+D01HzE1X7bGNOa/NCUlJeOWnQOBgIqLiyfdd+TIEX3rW9/S448/rry8vJhv3N/vjKIE\nTn/GkYzGxoZxtaC7urrU0tKis2fPqalpe8TXpSqmbneByssrDDdrVVTM1oEDP9XAwMDos9f+/rOm\nj8Eq/J7Pk+9JAAAPaklEQVSaj5iaz+kxTeqZ8fz589XV1SW/369z586pvb1d1dXV4+555ZVXtHnz\nZu3cuVMXX3xx8iNGxrNjhahYx4EuvvjPePYKwBIxZ8Zut1ubN2/W2rVrNTQ0pJUrV2ru3LlqaWnR\nvHnzdOONN+rBBx9UMBjUXXfdJUkqLS3Vt771rZQPHvYVT4UoK3Yph8/5Gh0HAgCrxHXOOBWcshTh\n9GWVqQoGg6qqWjCl87vpPGecLW0H+T01HzE1n9NjmtQyNTAVZlSISuUubI4DAbAT52wVhe1MdUk4\n3Kc3XJijvLxCtbV1hpWlAMAJWKZOktOXVcyQ6JJwc/P/UUtLy6Tv19evi7oLG5Hxe2o+Ymo+p8eU\nZWpYKpEl4WAwqKeeesrwGn16ATgVyRi2Egj0ye/3G16jTy8ApyIZw1Y8nhLNnj3b8Fq0WtF2KbkJ\nAFNBMs5Sdk1e+fn5Wr58ueE1o13Y4c1eVVULtHDh1aqqWqDGxgaFQqF0DBcATMHW1CwzsYWgHXcq\n79ixQ2fPnotrF3a4C1OY339y9Gs2ewHIFOymTlKm7f6bWC86zE47lcMxjbULO5nCItkm035PMwEx\nNZ/TY8puakiyZ73oaGLtwo6n5CYAZAKScRZxWvLyeEpUXl4R8fqu
Xf/Is2MAGYFknEWiJa9oO5Xt\nKlrJzaGhIe3e3Safb1OaRwUAiSMZZxEz6kXbjc/XrM997n/L5XIZXrfj8jsATEQyzjI+X7Pq69fJ\n662Uy+WS11up+vp1UetF2/UYlPRBi88vfvFLGh4eNrye6PK7nT8rAOeyx1kWpI3b7VZT03Zt3Lgl\nZr3oTDgGJX2w/F5R4TXcVR3v8numfFYAzsTMOEvFUy86fIbX7z+p4eHh0TO8dnsOa8bye6Z8VgDO\nRDKGoUw7BjWV5fewTPusAJyH9TcYiucY1Jw5l6Z5VJElsvw+UaZ9VgDOw8wYhqIdgyou9qiwsDDN\nI4pPIu0aw5x25AtA5iEZw1C057C9vT1avHhR2hoypHqHsxOPfAHILCxTI6Lw89b29v06dWp8j+FT\np7rV2rpTw8PD2rbtH1Ly/unc4Rz+rPE0pwAAs5GMMc7Y5gx5eXn//V3jM7yS9L3v7VVj49dTMntM\nZ0emZJ45A0CySMaQZDwLLSoq0ssv/ybq6wYHT+v110/o4x+/3NTxxNrhvHHjlpQky/AzZwBIJ5Ix\nJBnPQv3+KC8YJ8f08bDDGUA2YQMXos5CYykomK7KykvMHZDY4Qwgu5CMEXUWGsvq1WtStlzMDmcA\n2YJlaozOQo1qO0dSUeHVsmWfTeluY3Y4A8gWJOMsNXbXdHgWOvaZcdi8efP17rsDo8lw8eIlWrv2\nTpWXV6R8dsoOZwDZgmScZSKd3W1s/Lok41nouXPnDJPhxISeKuxwBuB0JOMsE+vs7thZaPi6x1My\nLhnSbhAAzMUGriwST3ei/Px8eb2ztW3b11VVtUALF16tqqoF40pf0m4QAMxFMs4i8ZzdlaInWzPa\nDaa61jQAZBqScRaJ5+xurGT7+usn4kroRkKhkBobG1RVtUDXXnuVFi68Svfdd09amk0AgJ2RjLNI\nPGd3Y82epZwpF+MYO+MeGRlRb2+vdu9u05IlN5CQAWQ1krFFrFqq9fmaVV+/Tl5vpVwul7zeStXX\nrxs9uxtr9lxZecmUinFEm3G//PJvtGlTwxQ+DQA4A8k4zcYu1Rptjkq18Nndzs6jOnLkRXV2HlVT\n0/bRXdDxzJ5jJXQjgUCfursjF7v+yU/aeYYMIGvljIyMjFjxxm++edqKtzXdrFnTE/osjY0NhsU1\n6uvXmd4WcCqCwaBOnepWW9suPfvsgUlnjsceXUrknHEwGNTChVept7fX8LrL5dKRIy9qzpxLE44p\nYiOm5iOm5nN6TGfNmh7xGjPjFJq4FG3GTuR43ytRY2fsVVUL9Oyzz6impkadnb+aNHsOCxfjiKfg\nR7QZt0TzBwDZjWScApGWok+d6o57J3K8ydWsZW+j40y7d7fp0UfbTKuu1dT0oObNm294jeYPALIZ\nyTgFIp3TbWvbFXMncqLJ1YwCHKmcsY/ldrt14MDP9LnP/W+VlpbF/bwZAJyOZGyyaInt2WcPqKZm\nieG185uj4k+uZiXReIuBmMHtdmv79of0wgu/NtxABgDZiGScoFjLx7ES29q16yLuRE40uZqVROMp\nBmK2RJ43A4DTkYzjFM/ycTAY1J/+9CeVlpYZ/oyysgqVlZVHPFqUaHI1K4nGc5wJAJA6rA3GKVK3\no2nT8tTQsGVcF6NIyWtsYjNqCxg+IjQ4ODjptdOm5U9KrtH6ECeaRMPPbI1aKAIAUivrkvFUevBG\nWz7et2+fBgaC2rPnX0a/F06mBQXTdfZsMGJiM6MfsFlJNFwMZGwLRWbEAJAejin6ESuxDQwMaNOm\nDfrFLzrV03NK5eUVqqlZottuu0Nut0uVlXMiJp9XX/2tPv3p/ymjUOXm5srjKVFvb8+kaxUVXn33\nuz+Y9LMj9QO+4461uu66/xHxfV544deTZtPxfv5M4vSD/1YgpuYjpuZzekyjFf3I+JlxrEb34et7\n9z42bvk3fI529+42SVJBQYFW
r75NW7c+MLqzN/zap5/eb5ggJam0tDRiVane3h596EPTJEknTrw2\nmigjLXmHQu+rosIrv//kpJ9VXu6N+gzYaNkbAJAZMj4ZR0psktTUtH3S9UgGBwfV1rZLubm5o2Up\n43nt8uXL9eMf7zdMoGVl5dq16x918OCB0T8UFi9eqoMHf2L4sw4ePKjFi5eOW/IOYyMVADhXRu+m\n/uBZ7n7Dax0d7Xr77bciPuuN5LvffUwDAwNRnxNLH8xU6+vXqaWlJeJO5KKiIu3e3TbuzPCePf8S\nsWHCB0ef7ky4CQMAILNl9Mw4EOiT32+c2Px+v1555bcRjwpFcubMoBobN+irX90Q8bW5uS7t3fuE\nPv7xT8jtdhtuolq8eEnEGbDL5dLQ0NCk75eVVai8vIKNVACQZeKaGR8+fFhLly5VTU2NWltbJ10/\nd+6c7r77btXU1OjWW29Vd3diCXCqCgsL5XK5DK+5XLmaM+fSiOdwo+ns7FRhYWHE15aXf9DXN8yo\nLeEXv/gl9fScMnz90NCw4feNjj6RiAHA+WIm46GhIW3dulVtbW1qb2/X/v37dfz48XH3PPHEEyos\nLNTBgwd1xx13aMeOHSkb8FgDAwOGM0zpg3G///77UTsFRdLX16OBgYGEC2GMTaDRCnJ4vV59/vNr\nWYoGAEiKIxkfO3ZMlZWV8nq9ysvLU11dnQ4dOjTunueee0633HKLJGnp0qV64YUXIu4+NpPHU6KK\nCq/htYqK2fJ4SuTzNY97BuvxlMb8ueHqVRNfm0jSjF7Vqk7f+MbDhlW4AADZJ+a//oFAQCUl54/U\neDweHTt2bNI9paUfJDm3263p06erv79fM2fOjPhzZ8zIl9ttvMQcv+lauXKFWlpaJl1ZufIWVVZ6\nJEm7dv2zgsGgent7VVRUpGuuuUZdXV0Rf+qKFTcbvra0tNRwRhzp7Ng//dP/1bRpedq3b5/8fr+8\nXq+WL1+uHTt2/HfinT76Phgv2nk8TA0xNR8xNV+2xjRmMjaa4ebk5CR8z0T9/ea05Wto2KKzZ89N\nqkDV0LBl0uHxwsJijYxIS5bUGh5ZKiiYrjVr/ibia8+cGdKZM+O/H+uQ+qZN9+urX/3auM1Y/f1n\nk/jEzuf0g/9WIKbmI6bmc3pMkyr6UVJSor6+8w0KAoGAiouLJ93T29urkpIP+vGePn1aH/7wh5MY\ncvymUsZx4u7n0tIyXXddlZqbH1RhYaHpY6QgBwAgmpjJeP78+erq6pLf75fH41F7e7seeuihcfdU\nV1frRz/6ka666io988wzuvbaa2POjM2WSMKjDjMAwE5iJmO3263Nmzdr7dq1Ghoa0sqVKzV37ly1\ntLRo3rx5uvHGG7Vq1Srde++9qqmpUVFRkb75zW+mY+xJY8YKALADxzSKsIrTn3FYgZiaj5iaj5ia\nz+kxjfbMOKPLYQIA4AQkYwAALEYyBgDAYiRjAAAsRjIGAMBiJGMAACxGMgYAwGIkYwAALEYyBgDA\nYiRjAAAsRjIGAMBiltWmBgAAH2BmDACAxUjGAABYjGQMAIDFSMYAAFiMZAwAgMVIxgAAWIxkbKJv\nf/vb+tjHPqY//vGPVg8l423fvl1/8Rd/oZtuuklf+tKXNDAwYPWQMtbhw4e1dOlS1dTUqLW11erh\nZLze3l7dfvvtqq2tVV1dnb7zne9YPSTHGBoa0s0336w777zT6qGkHcnYJL29vTpy5IjKysqsHooj\nXHfdddq/f79+/OMf65JLLtGuXbusHlJGGhoa0tatW9XW1qb29nbt379fx48ft3pYGc3lcum+++5T\nR0eHvv/972vv3r3E1CT/+q//qo985CNWD8MSJGOTPPDAA7r33nuVk5Nj9VAc4frrr5fb7ZYkXXnl\nlerr67N4RJnp2LFjqqyslNfrVV5enurq6nTo0CGrh5XRiouLdfnll0uSCgoKdOmllyoQCFg8qs
zX\n19en559/XqtWrbJ6KJYgGZvg0KFDKi4u1p//+Z9bPRRH+vd//3ctWrTI6mFkpEAgoJKSktGvPR4P\nicNE3d3devXVV3XFFVdYPZSMt23bNt17773Kzc3OtOS2egCZ4o477tBbb7016ft33323du3apd27\nd1swqswWLaaLFy+WJO3cuVMul0t/+Zd/me7hOYJRtVtWb8xx5swZfeUrX9HGjRtVUFBg9XAy2k9/\n+lPNnDlT8+bN09GjR60ejiVIxnF69NFHDb//X//1X+ru7tby5cslfbDUsmLFCj3xxBOaNWtWGkeY\neSLFNOxHP/qRnn/+eT366KMkkCkqKSkZt8QfCARUXFxs4Yic4f3339dXvvIV3XTTTVqyZInVw8l4\nv/71r/Xcc8/p8OHDeu+99zQ4OKj169drx44dVg8tbWgUYbLq6mo9+eSTmjlzptVDyWiHDx/WN77x\nDT3++OPEMgmhUEhLly7Vo48+Ko/Ho1WrVumhhx7S3LlzrR5axhoZGVFDQ4OKioq0adMmq4fjOEeP\nHtXu3buzbtMmM2PY0v33369z587pc5/7nCTpiiuu0NatWy0eVeZxu93avHmz1q5dq6GhIa1cuZJE\nnKQXX3xR+/bt00c/+tHRFbF77rlHN9xwg8UjQyZjZgwAgMWyc9saAAA2QjIGAMBiJGMAACxGMgYA\nwGIkYwAALEYyBgDAYiRjAAAsRjIGAMBi/x+crcz0WAKXSAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fbc4a2e4400>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(figsize=(8, 6))\n",
"\n",
"# True success probability of each observation against its covariate;\n",
"# scatter around the logistic curve comes from the random intercepts.\n",
"ax.scatter(x, p, c='k')\n",
"ax.set_xlabel('$x$')\n",
"ax.set_ylabel('$p$')\n",
"ax.set_title('Simulated success probability $p$ vs. covariate $x$');"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"with pm.Model() as model:\n",
"    # Non-centered random intercepts: alpha = mu_alpha + sigma_alpha * delta_alpha\n",
"    mu_alpha = pm.Normal('mu_alpha', mu=0., sd=1.)\n",
"    sigma_alpha = pm.HalfCauchy('sigma_alpha', beta=5.)\n",
"    delta_alpha = pm.Normal('delta_alpha', mu=0., sd=1., shape=N)\n",
"    alpha = pm.Deterministic('alpha', mu_alpha + delta_alpha * sigma_alpha)\n",
"\n",
"    # Shared slope on x; p_ feeds both likelihoods below.\n",
"    beta = pm.Normal('beta', mu=0., sd=1.)\n",
"    p_ = pm.Deterministic('p', tt.nnet.sigmoid(alpha + beta * x))\n",
"\n",
"    # Observable 1: binary outcome.\n",
"    y_obs = pm.Bernoulli('y_obs', p=p_, observed=y)\n",
"\n",
"    # Observable 2: noisy linear function of p.\n",
"    # NOTE(review): z was simulated with slope 8 on p, i.e. 8 prior sds away\n",
"    # from this Normal(0, 1) prior on gamma -- consider widening it.\n",
"    gamma = pm.Normal('gamma', mu=0., sd=1.)\n",
"    err = pm.HalfCauchy('err', beta=5.)\n",
"\n",
"    z_obs = pm.Normal('z_obs', mu=gamma * p_, sd=err, observed=z)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def random_init(model):\n",
"    \"\"\"Draw a random starting point for every free variable of `model`.\n",
"\n",
"    Untransformed free variables are sampled directly from their own\n",
"    distribution. Transformed free variables (e.g. `sigma_alpha_log_` and\n",
"    `err_log_` for the HalfCauchy priors in this notebook) are sampled from\n",
"    the original, constrained distribution and then mapped into the\n",
"    unconstrained space with the transform's `forward` method.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    model : pm.Model\n",
"        Model whose free variables (`model.vars`) should be initialized.\n",
"\n",
"    Returns\n",
"    -------\n",
"    dict\n",
"        Maps each free variable's name to a sampled numpy value, usable as\n",
"        the `start` argument of `pm.advi`.\n",
"    \"\"\"\n",
"    def is_transformed(v):\n",
"        # Detect transforms by the attribute trans_sample relies on, not by\n",
"        # the trailing-underscore naming convention.\n",
"        return getattr(v.distribution, 'transform_used', None) is not None\n",
"\n",
"    def trans_sample(v):\n",
"        # `forward` maps constrained -> unconstrained (log for a Log transform).\n",
"        # `backward` is its inverse; applying it to a prior draw would\n",
"        # exponentiate an already-positive value instead of taking its log.\n",
"        transform = v.distribution.transform_used\n",
"        return transform.forward(v.distribution.dist.random()).eval()\n",
"\n",
"    # Untransformed variables are drawn before any transformed ones.\n",
"    var_dict = {v.name: v.distribution.random() for v in model.vars\n",
"                if not is_transformed(v)}\n",
"    var_dict.update({v.name: trans_sample(v) for v in model.vars\n",
"                     if is_transformed(v)})\n",
"\n",
"    return var_dict"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Random starting values for every free variable of `model`.\n",
"start = random_init(model=model)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/5000 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'err_log_': array(2.522191010632724), 'gamma': array(-0.12380656361410802), 'beta': array(-0.07410631818436134), 'sigma_alpha_log_': array(2.244999072752208), 'delta_alpha': array([-0.12287236, 1.46271563, 0.12381445, -0.62970604, -0.58242381,\n",
" 0.52726837, 1.20746729, -0.02682073, 0.33071269, 1.33914336,\n",
" -0.17260608, -0.20703458, 1.46545484, 1.38079962, -1.3917848 ,\n",
" 0.06914354, -1.58789709, 0.50223894, -2.04124726, 0.85020358,\n",
" 0.05099627, 0.26315539, -0.07137581, -1.52361391, 1.65414547,\n",
" 0.56135911, -0.11590509, 0.19275929, 0.27496712, -1.90173521,\n",
" 0.37577714, -0.76327879, 0.74859593, -0.264977 , -1.59406705,\n",
" 0.32649373, 1.41511622, 1.37734621, 0.5713924 , -0.09579313,\n",
" 0.86376445, 1.73045437, 0.7617386 , -2.78447456, -0.26491465,\n",
" 0.72282762, 0.69201938, 1.42552949, 0.07620791, 0.81409896,\n",
" 1.17563015, -0.25287858, -0.18538296, 1.90986785, -0.8079035 ,\n",
" -0.69182554, 0.20401296, 0.42092393, 1.20496163, -0.97012758,\n",
" -0.02364161, 0.92291018, -0.23242475, 1.23002823, -1.10833106,\n",
" -0.46669149, 0.12525846, 0.67000778, 0.27925912, 1.19090099,\n",
" 0.35150114, -0.04048651, 1.21721614, -0.05467444, -1.03115495,\n",
" 1.56030819, -1.71277029, -0.59345612, -0.57164705, 0.01194078,\n",
" 0.54790912, 0.54498176, -0.39754642, -0.12664766, -0.713671 ,\n",
" -1.40802654, -0.57874976, 1.53061462, 1.35733744, 0.82947721,\n",
" -0.18843756, 0.4113656 , -0.52784237, 0.74741859, 1.11813511,\n",
" -0.41785859, -1.42315816, 0.35994175, 1.02619037, 0.03048772]), 'mu_alpha': array(0.38310875436377867)}\n"
]
},
{
"ename": "FloatingPointError",
"evalue": "NaN occurred in ADVI optimization.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-10-dadacab31cee>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0madvi_fit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/home/jovyan/pymc3/pymc3/variational/advi.py\u001b[0m in \u001b[0;36madvi\u001b[0;34m(vars, start, model, n, accurate_elbo, optimizer, learning_rate, epsilon, mode, random_seed)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 154\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'NaN occurred in ADVI optimization.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 155\u001b[0m \u001b[0melbos\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFloatingPointError\u001b[0m: NaN occurred in ADVI optimization."
]
}
],
"source": [
"with model:\n",
"    # Fit with ADVI from the random start above. In the recorded run this\n",
"    # raised `FloatingPointError: NaN occurred in ADVI optimization.`\n",
"    advi_fit = pm.advi(start=start)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"widgets": {
"state": {},
"version": "1.1.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment