Skip to content

Instantly share code, notes, and snippets.

@canyon289
Last active March 3, 2021 05:41
Show Gist options
  • Save canyon289/97c94048d51ad3af80474f3adfa9c611 to your computer and use it in GitHub Desktop.
Save canyon289/97c94048d51ad3af80474f3adfa9c611 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "missing-community",
"metadata": {},
"source": [
"# Trying to Group Level and individual level out of sample predictions\n",
"But running into two questions\n",
"1. Is adding RVs after sampling, for posterior predictive sampling, an abuse of the api?\n",
"2. Why do unconditioned RVs affect the logp for the model? (and alter sampling)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "alpine-light",
"metadata": {},
"outputs": [],
"source": [
"import pymc3 as pm\n",
"import matplotlib.pyplot as plt\n",
"import arviz as az\n",
"import pandas as pd\n",
"import numpy as np\n",
"from scipy import stats\n",
"import theano"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "economic-great",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'3.11.0'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pm.__version__"
]
},
{
"cell_type": "markdown",
"id": "defensive-ozone",
"metadata": {},
"source": [
"# Make some data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "architectural-cameroon",
"metadata": {
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Finished location 0\n",
"Finished location 1\n",
"Finished location 2\n",
"Finished location 3\n",
"Finished location 4\n",
"Finished location 5\n"
]
}
],
"source": [
"def salad_generator(hyperprior_beta_mean=5, hyperprior_beta_sd=.3, sigma=50, days_per_location=[5, 3, 15, 10, 3, 5]):\n",
" \"\"\"Generate noisy salad data\"\"\"\n",
" beta_hyperprior = stats.norm(hyperprior_beta_mean, hyperprior_beta_sd)\n",
" \n",
" # Generate demands days per restaurant\n",
" df = pd.DataFrame()\n",
" for i, days in enumerate(days_per_location):\n",
" np.random.seed(0)\n",
"\n",
" num_customers = stats.randint(30, 100).rvs(days)\n",
" sales_location = beta_hyperprior.rvs()*num_customers + stats.norm(0, sigma).rvs(num_customers.shape)\n",
"\n",
" location_df = pd.DataFrame({\"customers\":num_customers, \"sales\":sales_location})\n",
" location_df[\"location\"] = i\n",
" location_df.sort_values(by=\"customers\", ascending=True)\n",
" df = pd.concat([df, location_df])\n",
"\n",
" print(f\"Finished location {i}\")\n",
" df.reset_index(inplace=True, drop=True)\n",
" return df\n",
"hierarchical_salad_df = salad_generator()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "handy-brush",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfQAAAGeCAYAAABiufO2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyDklEQVR4nO3dfZhdZX3v//dnkmESCZCQzMmBTCCWRC14hUgjYmOphR4rqIHLAOITkXLKscWn+pBgz/lV9GiFeCzgj1/RFCqBqhCJmshBxSag0goYIAwiCuEhZsZAhpBAIsl0kvn+/lj3hp1hnjNrP6z9eV3Xvmbtez3s7177Tr573eve962IwMzMzOpbU7UDMDMzswPnhG5mZlYATuhmZmYF4IRuZmZWAE7oZmZmBeCEbmZmVgBO6GaApDsk/fcKv+YHJN05RseaKOn7kp6T9O0R7HeUpF2Sxo1FHH2OPWbvbxivNUtSSBo/zO2vk/T5vOMyqyQndCu0lKi3S2qpdiw5OwuYDkyNiLOHu1NE/DYiJkXEvvxCG5qkSyT9azVjGEg1vuyZjYYTuhWWpFnAnwABLKxuNLk7GngkIvZWOxAzqw4ndCuy84C7gOuAxcPY/hhJ90h6XtJqSYcDSHqzpI7yDSU9KenP0/IlklZKul7STkkPSZpftu1MSd+R1CVpm6Sr+hzr/6RWhCcknTZQcJL+MF0t7kivsTCVfxb4e+Bdqfn8gn72PVHS+vTenpb0j6l8v6bqdPzPS/qPdKzvS5oq6Rtp31+kL0r9NnMPdjUr6UpJm9Nx7pX0J6n8rcDflcX/QCo/TNK1krZI6kxxjUvrxqXz9oykx4G3DXTe0vavk3Rf+nxuAiaUrZsi6Zb0+WxPy21p3RfIvhRelWK7arD3YlZNTuhWZOcB30iPv5A0fRjb/yVwBLAX+MoIXmshcCMwGVgDlP7jHwfcAmwCZgEz0nYlbwB+A0wDlgHXSlLfg0tqBr4P3Ab8F+DDwDckvToiPgP8A3BTaj6/tp/4rgSujIhDgWOAlYO8l3OB96dYjwF+DnwdOBx4GPjM4KdiQL8A5qXjfBP4tqQJEfHDPvEfn7a/juxzmA28DngLUPqy8FfA21P5fLJbDv2SdBDwPeCG9NrfBhaVbdKU3t/RwFHAbtLnFxH/E/gZ8KEU24cGey8jPB9mY8oJ3QpJ0pvI/oNeGRH3Ao8B7xlitxsi4pcR8Xvg/wHOGUFnsTsj4tZ0L/oGoJSUTgSOBD4VEb+PiD0RUd5RbFNE/HPabwXZl4n+vnicBEwCLo2I/4yIdWRfFN49zPh6gNmSpkXEroi4a5Btvx4Rj0XEc8APgMci4t9Sc/63yZLoiEXEv0bEtojYGxFfBlqAV/e3bfrydTrwsXTetgKXk33ZADgHuCIiNkfEs8AXB3npk4DmtH1PRNxMlpBLcW2LiFUR8UJE7AS+APzpWL0Xs0pxQreiWgzcFhHPpOffZOhm981ly5vIksC0Yb7eU2XLLwATUlP0TLKkPdC97Rf3i4gX0uKkfrY7EtgcEb19YpwxzPguAF4F/Do1m799kG2fLlve3c/z/uIbkqRPSnpYWU/8HcBhDHx+jyY7/1vSLYYdwNfIWicgnY+y7TcN8tJHAp2x/0xUL24v6RWSviZpk6TngZ8Ckwf7MjfC92JWEcP6iYdZPZE0kewKbpykUsJsIftP+viIeGCAXWeWLR9FdlX7DPB74BVlxx8HtA4znM3AUZLGH2CHtd8BMyU1lSX1o4BHhrNzRDwKvFtSE/BO4GZJUw8gHsjOC2Tn5vm0/F/72zDdY14CnAo8FBG9krYDpdsLfad93Ax0A9MGOG9bePnnNZAtwAxJKkvqR5G12gB8guzq+g0R8ZSkecD9A8U2jPdiVhW+QrciOhPYBxxLdp9zHvCHZPdCzxtkv/dJOlbSK4DPATenpvBHyK6435buZf8vsi8Iw3EPWUK5VNLBkiZIWjDyt8TdZFf+SyQ1S3oz8A72vx8/IEnvk9SavgzsSMW9g+wypIjoAjrJzts4SX9Jds+9P4eQ3Q/vAsZL+nvg0LL1TwOz0hcOImILWX+BL0s6VFKTpGMklZrCVwIfkdQmaQpw8SCh/jy99kfSuXsn2a2Q8th2AzuUdYTs20fgaeAPRvBezKrCCd2KaDHZfeDfRsRTpQdZR6f3auDBR24g64j1FFkv6I8ApHvJfwNcQ5bAfg909H+I/aUvBO8g69j127Tfu0b6hiLiP9NxTiNrNfgn4LyI+PUwD/FW4CFJu8g6yJ0bEbtHGkc//gr4FLANOA74jwG2+xHwQ7IvR5uAPezfZF4aDGebpPvS8nnAQcCvgO3AzWR9DAD+OR3zAeA+4DsDBZjO3TuBDwDPkp3/8u2vACaSnde7UpzlrgTOSj3gvzKM92JWFdr/tpKZmZnVI1+hm5mZFYATupmZWQE4oZuZmRWAE7qZmVkBOKGbAZIWSHo0jdd9ZrXj6UvSn0j6TbXjgCHHax+T6VgHe408qZ9x+wfZtmZniLPG5IRulvkccFUar/t7lXzhgZJIeVKLiJ9FRM0PLVor07GaNSIndLPM0cBDo9lxkN+114SRxlfr78fM+ueEbg1P0mNkI4F9PzUXt0g6UtIaSc9K2ijpr8q2v0TSzZL+NY39/YF+jvk2Sfen6TU3S7rkAGPc7yo+xbdK2ZSfT0j6yGDxKZs+9edpXPQtkq5SNgtZaZ+QdJGkR4FHU9kZkjak9/CYsmlOS46W9O/KpiO9TdK0tE/f6VgPl/R1Sb9LA7N8L5VP0QBTlg7jXFwi6dvp/e2U9KCkV0n6tKSt6Xy/pc+5GuiznCjpuhTDr4DX93mtAc+zWa1xQreGFxHHkI3i9o7UXNxNNqRqB9nEHmcB/yDplLLdziAbuWwy2fSsff2ebKSzyWRzdf/1WN2bT8Ojfp9slLQZZGOKf0zSXwwS3z7gb8kmEHlj2udv+hz6TLLpXI+VdCJwPdkocJOBk4Eny7Z9D3A+2WQpBwGfHCDcG8jGej8ubXt5Kh9wytJhekc69hSycdd/lI45g+z2ydfKth3ss/wM2XC1xwB/QdkEPsM8z2Y1wwndrA9JM4EFwNI03ekGsmFfy8eB/3lEfC8ievsbQjUi7oiIB9P6duBbDD4l55Hp6vnFB/CmAbZ9PdAaEZ9LU6k+TjYU6rll2+wXX0TcGxF3pek+nyRLeH3j+WJEPJvezwXAv0TEj9MxOvsMM/v1iHgkbbuSbLz8/Ug6gmyo2g9GxPY0delP0vkZ8ZSlffwsIn5UNqVrK9nUsj1kCXyWpMnD+CzPAb6Q3vdm4CtlrzGc82xWM3yvzOzljgSeTYmmZBMwv+z5oGN3S3oDcCnwWrIr2BZeGq+8P7+LiP2anCXdMcC2R5O+AJSVjSObfKbf+CS9CvhHsvfwCrJ/+/f2OW75PjOBWweJt+90sf1NqTqT7Dxu77tC2QQ4l5ONMT8lFR8iadwwO9T1ndL1mbL9Sl+wJjH0ZznYNKzDOc9mNcNX6GYv9zvgcEmHlJUdRTYxS8lQkyB8E1gDzIyIw4CvMnbTa24GnoiIyWWPQyLi9EHiuxr4NTAnIg4F/q6feMr32czAM6eNJM7DJU3uZ135lKWHkjXp009MB2qoz3KwaViHc57NaoYTulkfqen1P4AvKpvudC5ZE/RIfnN8CNmV4Z50P/o9YxjiPcBOSUtTp65xkl4r6fWD7HMI2ZzluyS9BvjrIV7jWuB8Sacqm7p0Rtpv2NIUqD8A/il1gmuWVErcQ01ZOiaG8VmuBD6d4msDPly2+2jOs1nVOKGb9e/dwCyyK7zvAp+JiH8bwf5/A3xO0k7g78kSx5hITctvJ7tv/QTZtJ/XAIcNstsnyb5U7CS7D3zTEK9xD1mnt8uB54CfkDVBj9T7gR6y1oGtwMdS+RUMPmXpWBrss/wsWTP7E2Tzr99Q2mmU59msajx9qpmZWQH4Ct3MzKwAnNDNzMwKwAndzMysAJzQzczMCsAJ3czMrACc0M3MzArACd3MzKwAnNDNzMwKwAndzMysAJzQzczMCsAJ3czMrACc0M3MzArACd3MzKwAnNDNzMwKwAndzMysAJzQzczMCsAJ3czMrADGVzuAAzFt2rSYNWtWtcOwGnDvvfc+ExGtlXgt1zsrqVS9c52zksHqXF0n9FmzZrF+/fpqh2E1QNKmSr2W652VVKreuc5ZyWB1zk3uZmZmBeCEbmZmVgBO6GZmZgXghG5mZlYATuhmZmZjYNuubh7YvINtu7qr8vp13cvdzMysFqze0MnSVe00NzXR09vLskVzWThvRkVj8BW6mZnZAdi2q5ulq9rZ09PLzu697OnpZcmq9opfqTuhm5mZHYCO7btpbto/nTY3NdGxfXdF43BCNzMzOwBtUybS09u7X1lPby9tUyZWNA4ndDMzswMwdVILyxbNZUJzE4e0jGdCcxPLFs1l6qSWisbhTnFmZmYHaOG8GSyYPY2O7btpmzKx4skcnNDNzMzGxNRJLVVJ5CVucjczMysAJ3QzM7MCcEI3MzMrACd0MzOzAnBCNzMzKwAndDMzsxozmole/LM1MzOzGjLaiV58hW5mZlYjDmSiFyd0M8tdteeJNqsXBzLRS65N7pKeBHYC+4C9ETFf0uHATcAs4EngnIjYLknAlcDpwAvAByLivjzjM7P8DdZ8uG1Xd1WHyjSrNQcy0UslrtD/LCLmRcT89PxiYG1EzAHWpucApwFz0uNC4OoKxGZmORqs+XD1hk4WXLaO911zNwsuW8eaDZ3VDtes6g5kopdqdIo7A3hzWl4B3AEsTeXXR0QAd0maLOmIiNhShRitBvlqrv6Umg/38NIVR3NTEw/97vkXE31p3ZJV7SyYPc2frTW80U70kndCD+A2SQF8LSKWA9PLkvRTwPS0PAPYXLZvRyrbL6FLupDsCp6jjjoqx9Ctlqze0MmSmx9gnJrYF7186azjh9Xrc6y43o3OQM2HEP0m+o7tu53QE9e5xjaaiV7ybnJ/U0ScQNacfpGkk8tXpqvxGMkBI2J5RMyPiPmtra1jGKrVqm27uvnEyg107w1e6NlH997g4ys3VLSDlevd6AzUfHjckYeN+j5ho3Cds5HK9Qo9IjrT362SvgucCDxdakqXdASwNW3eCcws270tlVmDe+h3z7N3///72dublZ/8Kv9HV+sGaj5ctmguS/p0lvPVudno5ZbQJR0MNEXEzrT8FuBzwBpgMXBp+rs67bIG+JCkG4E3AM/5/rllBmrEGVHjjlVRf82Ho71PaGb9y/MKfTrw3ezXaIwHvhkRP5T0C2ClpAuATcA5aftbyX6ytpHsZ2vn5xib1ZHjjjyM5nGiZ99LCbx5nDjuyMOqGJWNhdHcJzSz/uWW0CPiceD4fsq3Aaf2Ux7ARXnFY/Vr6qQWvnz28Xzq5nbGNYl9vcGXznLzrJmNXC39WmasY/FY7lYX3DxrZgdqtGOk10ssHvrV6sbUSS0cP3Oyk7mZjdiBjJFeL7E4oZuZWeEdyBjp9RKLE7qZmRXegYyRXo1YRjOhkRO6mZkV3oGMkV7pWEY7z4E7xZmZUVu9ny0ftdS5dqBYyu+vj3SeAyd0M2t4tdT72fJVS2Mf9BfLQBMaDWeeAze5m1lDq6Xez2a1Ph+62YiMpjOI2WjVUu9ns3qbD91sQG76tEqrpd7PZjD6e/2+Qm8wtXz166ZPq4Za6v1sVjKagbR8hd5Aav3q90A6g5gdiFrq/Ww2Wk7oDeJAfgpRKW76tGqqpd7PZqPhJvcGUQ8df9z0aWZ91fJtwlrjK/QGUS9Xv276NLOSWr9NOFbGalAjJ/QGUbr6XdLnH0ctJkw3fZpZPdwmHAtj+aXFCb2B+OrXzOpFI3SSHesvLU7oDcZXv2ZWD+rlNuGBGOsvLbl3ipM0TtL9km5Jz18p6W5JGyXdJOmgVN6Snm9M62flHZuZmdWmRugkO9ZfWipxhf5R4GHg0PT8MuDyiLhR0leBC4Cr09/tETFb0rlpu3dVID4zM6tBB3KbsB5mzxvrvk25JnRJbcDbgC8AH5ck4BTgPWmTFcAlZAn9jLQMcDNwlSRFROQZo5mZ1a7R3Casp97xY9m3Ke8m9yuAJfDiDYKpwI6I2JuedwClszwD2AyQ1j+Xtt+PpAslrZe0vqurK8fQzV7iemeV5jo3OvU4hPRohnntT24JXdLbga0Rce9YHjcilkfE/IiY39raOpaHNhuQ651Vmuvc6NTDIFp5yfMKfQGwUNKTwI1kTe1XApMllZr624DOtNwJzARI6w8DtuUYn5mZ1ajRjhA3UEezgw8aV/gR53JL6BHx6Yhoi4hZwLnAuoh4L3A7cFbabDGwOi2vSc9J69f5/rmZWeNZvaGTBZet433X3M2Cy9axZkPn0Dsl/fWOP+eP2nj7VXeO6nj1pBq/Q18K3Cjp88D9wLWp/FrgBkkbgWfJvgSYmVkDGYvBVso7mh180DjeftWdhR9xDiqU0CPiDuCOtPw4cGI/2+wBzq5EPGZmVpvGarCVUu/4BzbvKPyIcyWebc3MzGrGWA+20ggjzpU4oZuZWc0Y6xHiGmHEuRKP5W5mZjVlrCeSapSJqZzQzcxsVPIcXnWgEeJG+5qNMDGVE7qZmY1YNYZXrachXavB99DNzGxEqjG8aj0O6VppTuhmZjYi1RhetZGHdB0uJ3QzMxuRavwUrJF+fjZaTuhmZjYi1fgpWCP9/Gy03CnOzMxGrBo/BWuUn5+NlhO6mZmNSjV+CtYIPz8bLTe5m5mZFYATupmZWQE4oRuQ/cbzgc07/JtOM7M65Xvo5tGXzMwKwFfoDc6jL5mZFYMTeoPz6EtmZsXghN7gPPqSmVkx5JbQJU2QdI+kByQ9JOmzqfyVku6WtFHSTZIOSuUt6fnGtH5WXrHZSzz6kplZMeTZKa4bOCUidklqBu6U9APg48DlEXGjpK8CFwBXp7/bI2K2pHOBy4B35RifJR59ycz6k+d85zb2ckvoERHArvS0OT0COAV4TypfAVxCltDPSMsANwNXSVI6juXMoy+ZWTn/+qX+5HoPXdI4SRuArcCPgceAHRGxN23SAZRqyAxgM0Ba/xwwtZ9jXihpvaT1XV1deYZv9iLXO6u0atY5//qlPuWa0CNiX0TMA9qAE4HXjMExl0fE/IiY39raeqCHMxsW1zurtGrWOf/6pT5VpJd7ROwAbgfeCEyWVGrqbwM603InMBMgrT8M2FaJ+MzM7CX+9Ut9yrOXe6ukyWl5IvDfgIfJEvtZabPFwOq0vCY9J61f5/vnZmaV51+/1Kc8e7kfAayQNI7si8PKiLhF0q+AGyV9HrgfuDZtfy1wg6SNwLPAuTnGZmZmg/CvX+pPnr3c24HX9VP+ONn99L7le4Cz84rHzMxGxr9+qS8eKc7MzKwAnNDNzMwKwAndzMysAJzQzczMCsAJ3czMrACc0M3MzArACd3MzKwAnNDNzMwKwAndzMysAJzQzczMCsAJ3czMrACc0M3MzArACd3MzKwAnNDNzMwKwAndzMysAJzQzczMCsAJ3czMrACc0M3MzAogt4Quaaak2yX9StJDkj6ayg+X9GNJj6a/U1K5JH1F0kZJ7ZJOyCs2MzOzosnzCn0v8ImIOBY4CbhI0rHAxcDaiJgDrE3PAU4D5qTHhcDVOcZmZmZWKLkl9IjYEhH3peWdwMPADOAMYEXabAVwZlo+A7g+MncBkyUdkVd8ZmZmRVKRe+iSZgGvA+4GpkfElrTqKWB6Wp4BbC7brSOV9T3WhZLWS1rf1dWVX9BmZVzvrNJc52ykck/okiYBq4CPRcTz5esiIoAYyfEiYnlEzI+I+a2trWMYqdnAXO+s0lznbKRyTeiSmsmS+Tci4jup+OlSU3r6uzWVdwIzy3ZvS2VmZmY2hDx7uQu4Fng4Iv6xbNUaYHFaXgysLis/L/V2Pwl4rqxp3szMhmHbrm4e2LyDbbu6qx2KVdj4HI+9AHg/8KCkDans74BLgZWSLgA2AeekdbcCpwMbgReA83OMzcyscFZv6GTpqnaam5ro6e1l2aK5LJz3sq5IVlC5JfSIuBPQAKtP7Wf7AC7KKx4zsyLbtqubpava2dPTyx56AViyqp0Fs6cxdVJLlaOzSvBIcWZmBdCxfTfNTfv/l97c1ETH9t1VisgqzQndzKwA2qZMpKe3d7+ynt5e2qZMrFJEVmlO6GZmBTB1UgvLFs1lQnMTh7SMZ0JzE8sWzXVzewPJs1OcmZlV0MJ5M1gwexod23fTNmWik3mDcUI3MyuQqZNanMgblJvczczMCsAJ3czMrACc0M3MzArACd3MzKwAnNDNzMwKwAndzMysAJzQzczMCsAJ3czMrACc0M3MzArACd3MzKwAnNDNzMwKwAndzMysAJzQzczMCiC3hC7pXyRtlfTLsrLDJf1Y0qPp75RULklfkbRRUrukE/KKy8zMrIjyvEK/Dnhrn7KLgbURMQdYm54DnAbMSY8LgatzjMvMzKxwckvoEfFT4Nk+xWcAK9LyCuDMsvLrI3MXMFnSEXnFZmZmVjSVvoc+PSK2pOWngOlpeQawuWy7jlT2MpIulLRe0vqurq78IjUr43pnleY6ZyNVtU5xERFAjGK/5RExPyLmt7a25hCZ2cu53lmluc7ZSFU6oT9dakpPf7em8k5gZtl2banMzMzMhqHSCX0NsDgtLwZWl5Wfl3q7nwQ8V9Y0b2ZmZkMYn9eBJX0LeDMwTVIH8BngUmClpAuATcA5afNbgdOBjcALwPl5xWVmVgTbdnXTsX03bVMmMnVSS7XDsRqQW0KPiHcPsOrUfrYN4KK8YjEzK5LVGzpZuqqd5qYmenp7WbZoLgvn9duP2BqIR4ozM6sj23Z1s3RVO3t6etnZvZc9Pb0sWdXOtl3d1Q7NqswJ3cysjnRs301z0/7/dTc3NdGxfXeVIrJa4YRuZlZH2qZMpKe3d7+ynt5e2qZMrFJEViuc0M3M6sjUSS0sWzSXCc1NHNIyngnNTSxbNNcd4yy/TnFmZpaPhfNmsGD2NPdyt/04oZuZ1aGpk1qcyG0/bnI3MzMrACd0MzOzAlA2pkt9ktRFNuLcQKYBz1QonNFyjGPj6IioyAwWrncVUw8xVqTeuc5VVK3HOWCdq+uEPhRJ6yNifrXjGIxjLJ56OF+OsVjq4VzVQ4xQP3H2x03uZmZmBeCEbmZmVgBFT+jLqx3AMDjG4qmH8+UYi6UezlU9xAj1E+fLFPoeupmZWaMo+hW6mZlZQ3BCNzMzKwAndDMzswJwQjczMysAJ3QzM7MCcEI3MzMrACd0MzOzAnBCNzMzKwAndDMzswJwQjczMysAJ3QzM7MCcEI3MzMrACd0MzOzAnBCNzMzKwAndDMzswJwQjczMysAJ3QzM7MCcEI3MzMrACd0MzOzAnBCNzMzKwAndDMzswJwQjczMysAJ3QzM7MCcEI3MzMrgPHVDuBATJs2LWbNmlXtMKwG3Hvvvc9ERGslXsv1zkoqVe9c56xksDpX1wl91qxZrF+/vtphWA2QtKlSr+V6ZyWVqneuc1YyWJ1zk7uZmVkBOKFb3di2q5sHNu9g267uaodiZlZz6rrJ3RrH6g2dLF3VTnNTEz29vSxbNJeF82ZUOywzs5rhK3Sredt2dbN0VTt7enrZ2b2XPT29LFnV7it1M7MyTuhW8zq276a5af+q2tzURMf23VWKyMys9jihW81rmzKRnt7e/cp6entpmzKxShGZmdUeJ3SreVMntbBs0VwmNDdxSMt4JjQ3sWzRXKZOaql2aGZmNcOd4qwuLJw3gwWzp9GxfTdtUyY6mZuZ9eGEbnVj6qQWJ3IzswG4yd3MzKwAnNDNzMwKwAndzMysAJzQzczMCsAJ3czMrACc0M3MzArACd3MzKwAnNDNzMwKwAndzMysAJzQzczMCsAJ3czMrACc0M0sd9t2dfPA5h1s29Vd7VDMCivXyVkkPQnsBPYBeyNivqTDgZuAWcCTwDkRsV2SgCuB04EXgA9ExH15xmf1Zduubs+2VodWb+hk6ap2mpua6OntZdmiuSycNwPwZ2o2liox29qfRcQzZc8vBtZGxKWSLk7PlwKnAXPS4w3A1emv2aBJwWrXtl3dLF3Vzp6eXvbQC8CSVe0smD2NOzc+48/UbAxVo8n9DGBFWl4BnFlWfn1k7gImSzqiCvFZjSlPCju797Knp5clq9rdfFsHOrbvprlp//9mmpuaeOh3z/szNRvEaG5T5Z3QA7hN0r2SLkxl0yNiS1p+CpielmcAm8v27Uhl+5F0oaT1ktZ3dXXlFbfVkIGSQsf23RWLwfVudNqmTKSnt3e/sux5VP0zrXWuc41r9YZO/vjStbx7+V388aVrWbOhc1j75Z3Q3xQRJ5A1p18k6eTylRERZEl/2CJieUTMj4j5ra2tYxiq1Yq+30wHSgptUyZWLCbXu9GZOqmFZYvmMqG5iUNaxjOhuYlli+Zy3JGHVf0zrXWuc41p265uPrFyA917gxd69tG9N/j4yg3DulLP9R56RHSmv1slfRc4EXha0hERsSU1qW9Nm3cCM8t2b0tlVudG0vFpoHvlyxbNZUmfcneiqg8L581gwexpL6sD/kzNXu6h3z3P3v2/67K3Nys/+VWDf7HLLaFLOhhoioidafktwOeANcBi4NL0d3XaZQ3wIUk3knWGe66sad7q1Eg6sw3WgWqgpGD1Yeqklpd9Zv5MzfozUKP10I3ZeV6hTwe+m/0ajfHANyPih5J+AayUdAGwCTgnbX8r2U/WNpL9bO38HGOzChgsQff3n3fpXnlpW3jpvmopIfg//WLxZ2q2v+OOPIzmcaJn30sJvHmcOO7Iw4bcN7eEHhGPA8f3U74NOLWf8gAuyiseq7yhEnRftXCv3MysmqZOauHLZx/Pp25uZ1yT2NcbfOms4d2OqsTv0K1BjTRBlzpQ+b6qmTWy0d6OckK33IwmQfu+qpnZ6G5HOaFbrkaToH1f1cxs5JzQLXdO0GZm+fNsa2ZmZgXghG5mhqd4tfrnJncza3iezc+KwFfoZtbQPJufFYUTupk1tFqYzc8a01jf5nGTu5k1NI9QaNWQx20eX6GbWUMbaIpX/9SyvtVyJ8e8bvP4Ct3MGp5HKCyWWu/kONJ5LobLCd3MDA+AVBQjneWxGvK6zeMmdzMzK4x66OSY120eX6GbmVlh1Esnxzxu8/gK3czMCqOeOjlOndTC8TMnj1lsvkI3M7NCadROjk7oZmZWOI3YyTH3JndJ4yTdL+mW9PyVku6WtFHSTZIOSuUt6fnGtH5W3rGZmZkVRSXuoX8UeLjs+WXA5RExG9gOXJDKLwC2p/LL03Y2xmp5sAUzMxu9XBO6pDbgbcA16bmAU4Cb0yYrgDPT8hnpOWn9qWl7GyOrN3Sy4LJ1vO+au1lw2TrWbOisdkhmZjZG8r5CvwJYAi8OhzMV2BERe9PzDqA0fM8MYDNAWv9c2n4/ki6UtF7S+q6urhxDLxbPKHVgXO+s0lznbKRyS+iS3g5sjYh7x/K4EbE8IuZHxPzW1taxPHSh1cNgC7XM9c4qzXXORirPXu4LgIWSTgcmAIcCVwKTJY1PV+FtQKndtxOYCXRIGg8cBmzLMb6GUi+DLZiZ2ejkdoUeEZ+OiLaImAWcC6yLiPcCtwNnpc0WA6vT8pr0nLR+XUREXvE1mnoabMHMzEauGr9DXwrcKOnzwP3Atan8WuAGSRuBZ8m+BNgYatTBFszMGkFFEnpE3AHckZYfB07sZ5s9wNmViKeRNeJgC2ZmjcBjuZuZmRWAE7qZmVkBOKGbmZkVgBO6mZlZAYw4oUtqknRoHsHY0DwWu5mZ9WdYvdwlfRP4ILAP+AVwqKQrI+JLeQZXq7bt6q7KT79Wb+hk6ap2mpua6OntZdmiuSycN2PoHc3MrPCGe4V+bEQ8TzaRyg+AVwLvzyuoWlatCU48FruZmQ1muAm9WVIzWUJfExE9QMON4lbNpOqx2M3MbDDDTehfA54EDgZ+Kulo4Pm8gqpV1UyqHovdzMwGM6yEHhFfiYgZEXF6ZDYBf5ZzbDWnmknVY7GbmdlghtspbjrwD8CREXGapGOBN/LSOOwNoZRUl/TpmFappOqx2M3MbCDDHcv9OuDrwP9Mzx8BbqLBEjpUP6l6LHYzM+vPcO+hT4uIlUAvQJrLfF9uUdW4qZNaOH7mZCdWMzOrGcNN6L+XNJXUs13SScBzuUVlZmaWgyIPzjXcJvePA2uAYyT9O9AKnJVbVGZmZmOs6INzDSuhR8R9kv4UeDUg4Dfpt+hmZmY1r3wckT3Z3WOWrGpnwexphbl9OmhCl/TOAVa9ShIR8Z0cYjIzMxtTpXFESskcXhpHpCESOvCOQdYF4IRuZmY1rxEG5xo0oUfE+aM9sKQJwE+BlvQ6N0fEZyS9ErgRmArcC7w/Iv5TUgtwPfBHwDbgXRHx5Ghf38zMrKTa44hUwnA7xSHpbcBxwIRSWUR8bpBduoFTImJXGgf+Tkk/IOtgd3lE3Cjpq8AFwNXp7/aImC3pXOAy4F0jfkc2KtWaQc7MrFKqPY5I3oY7UtxXgVeQDfd6DVkP93sG2yciAtiVnjanRwCnAO9J5SuAS8gS+hlpGeBm4CpJSsexHBW956eZWUmRB+ca7u/Q/zgiziO7gv4s2bCvrxpqJ0njJG0AtgI/Bh4DdqSBaQA6gFLmmAFshhcHrnmOrFm+7zEvlLRe0vqurq5hhm8D8bSsw+N6Z5XmOmcjNdyEXppO7AVJRwJ7gSOG2iki9kXEPKANOBF4zWiC7HPM5RExPyLmt7a2HujhGp6nZR0e1zurNNc5G6nhJvRbJE0GlpF1ZHsC+NZwXyQidgC3k13ZT5ZUaupvAzrTcicwEyCtP4ysc9yYKPLoQAeiEXp+mpk1gkETuqTXS/qvEfG/U1KeBDwIfBu4fIh9W9OXACRNBP4b8DBZYi+NMrcYWJ2W16TnpPXrxur++eoNnSy4bB3vu+ZuFly2jjUbOofeqUF4WlYzs2IYqlPc14A/B5B0MnAp8GFgHrCcwYd/PQJYIWkc2ReHlRFxi6RfATdK+jxwPy/N2HYtcIOkjcCzwLmjekd9NMLoQAeq6D0/zcwawVAJfVxEPJuW3wUsj4hVwKrU2W1AEdEOvK6f8sfJ7qf3Ld8DnD2coEeiEUYHGgtF7vlpZtYIhrqHPq7sfvepwLqydcP+DXs1+R6xmZk1gqES+reAn0haTdbT/WcAkmZTJ9On+h6xmZk1gqGGfv2CpLVk98NvK+uk1kR2L70u+B6xmZkV3ZDN5hFxVz9lj+QTTn58j9jMzIpsuL9DNzMzsxrmhG5mZlYATuhmZmYF4IRuZmZWAE7oZmZmBeCEbmZmVgBO6GZmZgXghG5mZlYATuhmZmYF4IRuZmZWAE7oZmZmBeCEbmZmVgBO6GZmZgXghG5mZlYAuSV0STMl3S7pV5IekvTRVH64pB9LejT9nZLKJekrkjZKapd0Ql6xmZmZFU2eV+h7gU9ExLHAScBFko4FLgbWRsQcYG16DnAaMCc9LgSuzjE2MzOzQsktoUfEloi4Ly3vBB4GZgBnACvSZiuAM9PyGcD1kbkLmCzpiLziMzMzK5KK3EOXNAt4HXA3MD0itqRVTwHT0/IMYHPZbh2prO+xLpS0XtL6rq6u/II2K+N6Z5XmOmcjlXtClzQJWAV8LCKeL18XEQHESI4XEcsjYn5EzG9tbR3DSM0G5npnleY6ZyOVa0KX1EyWzL8REd9JxU+XmtLT362pvBOYWbZ7WyozMzOzIeTZy13AtcDDEfGPZavWAIvT8mJgdVn5eam3+0nAc2VN82ZmZjaI8TkeewHwfuBBSRtS2d8BlwIrJV0AbALOSetuBU4HNgIvAOfnGJuZmVmh5JbQI+JOQAOsPrWf7QO4KK94zMzMiswjxZmZmRWAE7qZmVkBOKGbmZkVgBO6mZlZATihm5mZFUDhEvq2Xd08sHkH23Z1VzsUMzOzisnzd+gVt3pDJ0tXtdPc1ERPby/LFs1l4byXDQdvZmZWOIW5Qt+2q5ulq9rZ09PLzu697OnpZcmqdl+pm5lZQyhMQu/Yvpvmpv3fTnNTEx3bd1cpIjMzs8opTEJvmzKRnt7e/cp6entpmzKxShGZmZlVTmES+tRJLSxbNJcJzU0c0jKeCc1NLFs0l6mTWqodmpmZWe4K1Slu4bwZLJg9jY7tu2mbMtHJ3MzMGkahEjpkV+pO5GZm1mgK0+RuZmbWyJzQzczMCsAJ3czMrACc0M3MzAogt4Qu6V8kbZX0y7KywyX9WNKj6e+UVC5JX5G0UVK7pBPyisvMzKyI8rxCvw54a5+yi4G1ETEHWJueA5wGzEmPC4Grc4zLzMyscHJL6BHxU+DZPsVnACvS8grgzLLy6yNzFzBZ0hF5xWZmZlY0lb6HPj0itqTlp4DpaXkGsLlsu45UZmZmZsNQtU5xERFAjHQ/SRdKWi9pfVdXVw6Rmb2c651VmuucjVSlE/rTpab09HdrKu8EZpZt15bKXiYilkfE/IiY39rammuwZiWud1ZprnM2UpVO6GuAxWl5MbC6rPy81Nv9JOC5sqZ5MzMzG0JuY7lL+hbwZmCapA7gM8ClwEpJFwCbgHPS5rcCpwMbgReA8/OKy8zMrIhyS+gR8e4BVp3az7YBXJRXLGZmZkXnkeLMzMwKwAndzMysAJzQzczMCsAJ3czMrACc0M3MzArACd3MzKwAnNDNzMwKwAndzMysAJzQzczMCkDZIG31SVIX2RCyA5kGPFOhcEbLMY6NoyOiIjNYuN5VTD3EWJF65zpXUbUe54B1rq4T+lAkrY+I+dWOYzCOsXjq4Xw5xmKph3NVDzFC/cTZHze5m5mZFYATupmZWQEUPaEvr3YAw+AYi6cezpdjLJZ6OFf1ECPUT5wvU+h76GZmZo2i6FfoZmZmDcEJ3czMrAAKldAljZN0v6Rb0vNXSrpb0kZJN0k6qAZinCzpZkm/lvSwpDdKOlzSjyU9mv5OqXKMfyvpIUm/lPQtSRNq8VzWilqvd65zxVPrdS7F5HpXYYVK6MBHgYfLnl8GXB4Rs4HtwAVViWp/VwI/jIjXAMeTxXsxsDYi5gBr0/OqkDQD+AgwPyJeC4wDzqU2z2WtqPV65zpXPLVe58D1rvIiohAPoI2sgpwC3AKIbLSf8Wn9G4EfVTnGw4AnSJ0Ry8p/AxyRlo8AflPFGGcAm4HDgfHpXP5FrZ3LWnnUer1znSveo9brXIrB9a4KjyJdoV8BLAF60/OpwI6I2Jued5B9gNX0SqAL+HpqLrtG0sHA9IjYkrZ5CpherQAjohP4P8BvgS3Ac8C91N65rBVXUNv1znWueK6gtuscuN5VRSESuqS3A1sj4t5qxzKE8cAJwNUR8Trg9/Rpcorsa2HVfkuY7mmdQfYP8kjgYOCt1YqnltVJvXOdK5A6qXPgelcVhUjowAJgoaQngRvJmqKuBCZLGp+2aQM6qxPeizqAjoi4Oz2/mazSPy3pCID0d2uV4gP4c+CJiOiKiB7gO2Tnt9bOZS2oh3rnOlcs9VDnwPWuKgqR0CPi0xHRFhGzyDo1rIuI9wK3A2elzRYDq6sUIgAR8RSwWdKrU9GpwK+ANWTxQfXj/C1wkqRXSBIvxVhT57IW1EO9c50rlnqoc+B6VzXVvok/1g/gzcAtafkPgHuAjcC3gZYaiG8esB5oB74HTCG7B7YWeBT4N+DwKsf4WeDXwC+BG4CWWjyXtfSo5XrnOlfMRy3XuRST612FHx761czMrAAK0eRuZmbW6JzQzczMCsAJ3czMrACc0M3MzArACd3MzKwAGiqhSwpJXy57/klJl4zRsa+TdNbQWx7w65ydZi66vZ91r5J0a5rJ6D5JKyWNeGhFSR+QdOTYRGxmZpXQUAkd6AbeKWlatQMpVzYq0XBcAPxVRPxZn2NMAP4v2VCLcyLiBOCfgNZRhPQBsqEQK0KZRquLZmZjqtH+E90LLAf+tu+KvlfYknalv2+W9BNJqyU9LulSSe+VdI+kByUdU3aYP5e0XtIjaczl0rzFX5L0C0ntkv5H2XF/JmkN2ehEfeN5dzr+LyVdlsr+HngTcK2kL/XZ5T3AzyPi+6WCiLgjIn6ZrrivKjv2Len1x6X3/cv0Wn+bzsF84BuSNkiaKOnUNMHCg5L+RVJLOs6Tkr6Ytlsv6QRJP5L0mKQPlr3ep8re/2dT2SxJv5F0PdmgDjP7xjL0x2lmZiUjuTIsiv8PaJe0bAT7HA/8IfAs8DhwTUScKOmjwIeBj6XtZgEnAscAt0uaDZwHPBcRr0+J8N8l3Za2PwF4bUQ8Uf5iqbn7MuCPyObjvU3SmRHxOUmnAJ+MiPV9Ynwt2UxBIzEPmBHZXMBImhwROyR9qPQa6cr/OuDUiHgkJeC/JpvxCeC3ETFP0uVpuwXABLIk/VVJbwHmpPMiYI2kk8mGXZwDLI6IuyT9Ud9YRvhezMwaWqNdoRMRzwPXk01sP1y/iIgtEdENPAaUEvKDZEm8ZGVE9EbEo2SJ/zXAW4DzJG0A7iYb+nBO2v6evsk8eT1wR2STBuwFvgGcPIJ4h+tx4A8k/b+S3go83882ryabwOCR9HxFn1jWpL8PAndHxM6I6AK6U1J+S3rcD9xHdk5K739TRNw1gljMzGwADZfQkyvI7kUfXFa2l3Q+0v3cg8rWdZct95Y972X/Vo6+4+gG2VXphyNiXnq8MiJKXwh+fyBvoo+HyK7o+/Pie0smAETEdrLWhzuADwLXjOJ1y89F3/M0nuz9f7Hs/c+OiGvTNi++/zGKxcysYTVkQo+IZ4GVZEm95EleSogLgeZRHPpsSU3pvvofAL8BfgT8taRmeLEn+sGDHYRsYoA/lTRN0jjg3cBPhtjnm8AfS3pbqUDSyZJeS/be5qXYZpI1f5M6BzZFxCrgf5HdAgDYCRySln8DzEq3DwDeP4xYyv0I+EtJk9JrzpD0X/puNEgsZmY2DI14D73ky8CHyp7/M7Ba0gPADxnd1fNvyZLxocAHI2KPpGvImuXvkySgCzhzsINExBZJF5NN4yfg/0bEoFP4RcTu1BHvCklXAD1ksxx9lOzq/QmyzncPkzV9A8wAvl7Ww/zT6e91ZPe/dwNvBM4Hvp164/8C+OqQZ+KluG6T9IfAz7O3zy7gfcC+PpsOFIuZmQ2DZ1szMzMrgIZscjczMysaJ3QzM7MCcEI3MzMrACd0MzOzAnBCNzMzKwAndDMzswJwQjczMyuA/x8fpbRvMJSbWQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 576x432 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, axes, = plt.subplots(2,3, sharex=True, sharey=True, figsize=(8,6))\n",
"\n",
"for i, ax in enumerate(axes.ravel()):\n",
" location_filter = (hierarchical_salad_df[\"location\"] == i)\n",
" hierarchical_salad_df[location_filter].plot(kind=\"scatter\", x=\"customers\", y=\"sales\", ax=ax)\n",
" ax.set_xlabel(\"\")\n",
" ax.set_ylabel(\"\")\n",
"\n",
"fig.suptitle(\"A bunch of simulated data \\n for a Hierarchical model\") \n",
"axes[1,0].set_xlabel(\"Number of Customers\")\n",
"axes[1,0].set_ylabel(\"Sales\");"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "super-aquarium",
"metadata": {},
"outputs": [],
"source": [
"customers = hierarchical_salad_df.loc[:, \"customers\"].values\n",
"sales = hierarchical_salad_df.loc[:, \"sales\"].values\n",
"location_category = pd.Categorical(hierarchical_salad_df[\"location\"])"
]
},
{
"cell_type": "markdown",
"id": "needed-feature",
"metadata": {},
"source": [
"# Sample Non Centered"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "forty-showcase",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/sampling.py:465: FutureWarning: In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n",
" warnings.warn(\n",
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [β_offset, β_σ_hyperprior, β_μ_hyperprior, σ]\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='8000' class='' max='8000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [8000/8000 00:03<00:00 Sampling 4 chains, 81 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 4 seconds.\n",
"There were 17 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 10 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 27 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 27 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"The number of effective samples is smaller than 25% for some parameters.\n"
]
}
],
"source": [
"with pm.Model() as model_hierarchical_salad_sales_predictions:\n",
" \n",
" σ = pm.HalfNormal(\"σ\", 20)\n",
" \n",
" β_μ_hyperprior = pm.Normal(\"β_μ_hyperprior\", 10, 10)\n",
" β_σ_hyperprior = pm.HalfNormal(\"β_σ_hyperprior\", 10)\n",
" β_offset = pm.Normal('β_offset', mu=0, sd=1, shape=6)\n",
" \n",
" β = pm.Deterministic(\"β\", β_μ_hyperprior + β_offset * β_σ_hyperprior)\n",
" \n",
" μ = pm.Deterministic('μ', β[location_category.codes] * hierarchical_salad_df.customers)\n",
" \n",
" sales = pm.Normal(\"sales\", mu=μ, sd=σ, observed=hierarchical_salad_df.sales)\n",
" \n",
" trace_hierarchical_salad_sales_noncentered = pm.sample(random_seed=0)\n",
" \n",
" inf_data_hierarchical_salad_sales_noncentered = az.from_pymc3(trace=trace_hierarchical_salad_sales_noncentered, \n",
" coords={\"β_dim_0\":location_category.categories})"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "focused-hampshire",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.42.3 (20191010.1750)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"419pt\" height=\"452pt\"\n",
" viewBox=\"0.00 0.00 418.54 451.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 447.86)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-447.86 414.54,-447.86 414.54,4 -4,4\"/>\n",
"<g id=\"clust1\" class=\"cluster\">\n",
"<title>cluster6</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M162.42,-232.91C162.42,-232.91 246.42,-232.91 246.42,-232.91 252.42,-232.91 258.42,-238.91 258.42,-244.91 258.42,-244.91 258.42,-423.86 258.42,-423.86 258.42,-429.86 252.42,-435.86 246.42,-435.86 246.42,-435.86 162.42,-435.86 162.42,-435.86 156.42,-435.86 150.42,-429.86 150.42,-423.86 150.42,-423.86 150.42,-244.91 150.42,-244.91 150.42,-238.91 156.42,-232.91 162.42,-232.91\"/>\n",
"<text text-anchor=\"middle\" x=\"246.92\" y=\"-240.71\" font-family=\"Times,serif\" font-size=\"14.00\">6</text>\n",
"</g>\n",
"<g id=\"clust2\" class=\"cluster\">\n",
"<title>cluster41</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M162.42,-8C162.42,-8 246.42,-8 246.42,-8 252.42,-8 258.42,-14 258.42,-20 258.42,-20 258.42,-209.93 258.42,-209.93 258.42,-215.93 252.42,-221.93 246.42,-221.93 246.42,-221.93 162.42,-221.93 162.42,-221.93 156.42,-221.93 150.42,-215.93 150.42,-209.93 150.42,-209.93 150.42,-20 150.42,-20 150.42,-14 156.42,-8 162.42,-8\"/>\n",
"<text text-anchor=\"middle\" x=\"243.42\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">41</text>\n",
"</g>\n",
"<!-- β_μ_hyperprior -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>β_μ_hyperprior</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"71.42\" cy=\"-390.38\" rx=\"71.34\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"71.42\" y=\"-401.68\" font-family=\"Times,serif\" font-size=\"14.00\">β_μ_hyperprior</text>\n",
"<text text-anchor=\"middle\" x=\"71.42\" y=\"-386.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"71.42\" y=\"-371.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- β -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>β</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"249.92,-316.91 158.92,-316.91 158.92,-263.91 249.92,-263.91 249.92,-316.91\"/>\n",
"<text text-anchor=\"middle\" x=\"204.42\" y=\"-301.71\" font-family=\"Times,serif\" font-size=\"14.00\">β</text>\n",
"<text text-anchor=\"middle\" x=\"204.42\" y=\"-286.71\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"204.42\" y=\"-271.71\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n",
"</g>\n",
"<!-- β_μ_hyperprior&#45;&gt;β -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>β_μ_hyperprior&#45;&gt;β</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M112.04,-359.45C127.66,-347.95 145.52,-334.79 161.36,-323.13\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"163.45,-325.93 169.43,-317.18 159.3,-320.29 163.45,-325.93\"/>\n",
"</g>\n",
"<!-- σ -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>σ</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"326.42\" cy=\"-187.43\" rx=\"58.88\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"326.42\" y=\"-198.73\" font-family=\"Times,serif\" font-size=\"14.00\">σ</text>\n",
"<text text-anchor=\"middle\" x=\"326.42\" y=\"-183.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"326.42\" y=\"-168.73\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- sales -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>sales</title>\n",
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"206.42\" cy=\"-76.48\" rx=\"41.94\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"206.42\" y=\"-87.78\" font-family=\"Times,serif\" font-size=\"14.00\">sales</text>\n",
"<text text-anchor=\"middle\" x=\"206.42\" y=\"-72.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"206.42\" y=\"-57.78\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- σ&#45;&gt;sales -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>σ&#45;&gt;sales</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M293.31,-156.37C277.69,-142.19 259,-125.22 242.97,-110.66\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"245.09,-107.86 235.34,-103.73 240.39,-113.05 245.09,-107.86\"/>\n",
"</g>\n",
"<!-- β_σ_hyperprior -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>β_σ_hyperprior</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"338.42\" cy=\"-390.38\" rx=\"72.25\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"338.42\" y=\"-401.68\" font-family=\"Times,serif\" font-size=\"14.00\">β_σ_hyperprior</text>\n",
"<text text-anchor=\"middle\" x=\"338.42\" y=\"-386.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"338.42\" y=\"-371.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- β_σ_hyperprior&#45;&gt;β -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>β_σ_hyperprior&#45;&gt;β</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M297.49,-359.45C281.76,-347.95 263.75,-334.79 247.8,-323.13\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"249.81,-320.26 239.67,-317.18 245.68,-325.91 249.81,-320.26\"/>\n",
"</g>\n",
"<!-- μ -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>μ</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"249.92,-213.93 158.92,-213.93 158.92,-160.93 249.92,-160.93 249.92,-213.93\"/>\n",
"<text text-anchor=\"middle\" x=\"204.42\" y=\"-198.73\" font-family=\"Times,serif\" font-size=\"14.00\">μ</text>\n",
"<text text-anchor=\"middle\" x=\"204.42\" y=\"-183.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"204.42\" y=\"-168.73\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n",
"</g>\n",
"<!-- β&#45;&gt;μ -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>β&#45;&gt;μ</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M204.42,-263.66C204.42,-251.68 204.42,-237.22 204.42,-224.19\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"207.92,-224.12 204.42,-214.12 200.92,-224.12 207.92,-224.12\"/>\n",
"</g>\n",
"<!-- β_offset -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>β_offset</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"204.42\" cy=\"-390.38\" rx=\"43.68\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"204.42\" y=\"-401.68\" font-family=\"Times,serif\" font-size=\"14.00\">β_offset</text>\n",
"<text text-anchor=\"middle\" x=\"204.42\" y=\"-386.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"204.42\" y=\"-371.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- β_offset&#45;&gt;β -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>β_offset&#45;&gt;β</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M204.42,-352.9C204.42,-344.48 204.42,-335.54 204.42,-327.16\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"207.92,-326.92 204.42,-316.92 200.92,-326.92 207.92,-326.92\"/>\n",
"</g>\n",
"<!-- μ&#45;&gt;sales -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>μ&#45;&gt;sales</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M204.89,-160.89C205.09,-149.98 205.33,-136.89 205.56,-124.35\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"209.06,-124.06 205.75,-114 202.06,-123.94 209.06,-124.06\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x7f9ec3ac18b0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pm.model_to_graphviz(model_hierarchical_salad_sales_predictions)"
]
},
{
"cell_type": "markdown",
"id": "modular-consolidation",
"metadata": {},
"source": [
"# Question 1\n",
"Is this an abuse of the API to generate posterior predictive for the group level and location level by adding vars after sampling?"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "tired-smart",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='4000' class='' max='4000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [4000/4000 00:07<00:00]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"out_of_sample_customers = 50\n",
"\n",
"with model_hierarchical_salad_sales_predictions:\n",
" β_group = pm.Normal(\"group_beta_prediction\", β_μ_hyperprior, β_σ_hyperprior)\n",
" group_level_prediction = pm.Normal(\"group_level_prediction\", β_group*out_of_sample_customers, σ)\n",
" \n",
" location_4_prediction = pm.Normal(\"location_4_prediction\", β[4]*out_of_sample_customers, σ)\n",
" \n",
" ppc = pm.sample_posterior_predictive(trace_hierarchical_salad_sales_noncentered, var_names=['group_level_prediction', 'location_4_prediction'], )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "outer-level",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(-11893.11106697)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_hierarchical_salad_sales_predictions.logp(model_hierarchical_salad_sales_predictions.test_point)"
]
},
{
"cell_type": "markdown",
"id": "sporting-float",
"metadata": {},
"source": [
"# Question 2\n",
"Sampling is worse in this model when the extra nodes are included, even though they theoretically shouldn't contribute to the likelihood. Notably the number of divergences increases dramatically. For the broad question do these results make sense?"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "acute-token",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/sampling.py:465: FutureWarning: In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n",
" warnings.warn(\n",
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [location_4_predictions, group_prediction, group_beta_prediction, β_offset, β_σ_hyperprior, β_μ_hyperprior, σ]\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='8000' class='' max='8000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [8000/8000 00:04<00:00 Sampling 4 chains, 1,008 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 5 seconds.\n",
"There were 149 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 506 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"The acceptance probability does not match the target. It is 0.3882465788481101, but should be close to 0.8. Try to increase the number of tuning steps.\n",
"There were 175 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 178 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.\n",
"The estimated number of effective samples is smaller than 200 for some parameters.\n"
]
}
],
"source": [
"# If I used shared vars I get a shape error\n",
"customers = hierarchical_salad_df.loc[:, \"customers\"].values\n",
"customers = theano.shared(customers)\n",
"out_of_sample_customers = 50\n",
"\n",
"with pm.Model() as model_hierarchical_salad_sales_extra_nodes:\n",
" \n",
" σ = pm.HalfNormal(\"σ\", 20)\n",
" \n",
" β_μ_hyperprior = pm.Normal(\"β_μ_hyperprior\", 10, 10)\n",
" β_σ_hyperprior = pm.HalfNormal(\"β_σ_hyperprior\", 10)\n",
" β_offset = pm.Normal('β_offset', mu=0, sd=1, shape=6)\n",
" \n",
" β = pm.Deterministic(\"β\", β_μ_hyperprior + β_offset * β_σ_hyperprior)\n",
" \n",
" μ = pm.Deterministic('μ', β[location_category.codes] * customers)\n",
" \n",
" sales = pm.Normal(\"sales\", mu=μ, sd=σ, observed=hierarchical_salad_df.sales)\n",
" \n",
" # Extra nodes for group and individual level predictions\n",
" β_group = pm.Normal(\"group_beta_prediction\", β_μ_hyperprior, β_σ_hyperprior)\n",
" group_prediction = pm.Normal(\"group_prediction\", β_group*out_of_sample_customers, σ)\n",
" location_4_predictions = pm.Normal(\"location_4_predictions\", β[4]*out_of_sample_customers, σ)\n",
" \n",
" trace_hierarchical_salad_sales_noncentered = pm.sample(random_seed=0)\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "rocky-mistress",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(-11893.11106697)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_hierarchical_salad_sales_extra_nodes.logp(model_hierarchical_salad_sales_extra_nodes.test_point)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "electric-concentrate",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.42.3 (20191010.1750)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"611pt\" height=\"471pt\"\n",
" viewBox=\"0.00 0.00 611.29 470.81\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 466.81)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-466.81 607.29,-466.81 607.29,4 -4,4\"/>\n",
"<g id=\"clust1\" class=\"cluster\">\n",
"<title>cluster6</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M353.29,-240.88C353.29,-240.88 437.29,-240.88 437.29,-240.88 443.29,-240.88 449.29,-246.88 449.29,-252.88 449.29,-252.88 449.29,-442.81 449.29,-442.81 449.29,-448.81 443.29,-454.81 437.29,-454.81 437.29,-454.81 353.29,-454.81 353.29,-454.81 347.29,-454.81 341.29,-448.81 341.29,-442.81 341.29,-442.81 341.29,-252.88 341.29,-252.88 341.29,-246.88 347.29,-240.88 353.29,-240.88\"/>\n",
"<text text-anchor=\"middle\" x=\"437.79\" y=\"-248.68\" font-family=\"Times,serif\" font-size=\"14.00\">6</text>\n",
"</g>\n",
"<g id=\"clust2\" class=\"cluster\">\n",
"<title>cluster41</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M499.29,-8C499.29,-8 583.29,-8 583.29,-8 589.29,-8 595.29,-14 595.29,-20 595.29,-20 595.29,-209.93 595.29,-209.93 595.29,-215.93 589.29,-221.93 583.29,-221.93 583.29,-221.93 499.29,-221.93 499.29,-221.93 493.29,-221.93 487.29,-215.93 487.29,-209.93 487.29,-209.93 487.29,-20 487.29,-20 487.29,-14 493.29,-8 499.29,-8\"/>\n",
"<text text-anchor=\"middle\" x=\"580.29\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">41</text>\n",
"</g>\n",
"<!-- β_μ_hyperprior -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>β_μ_hyperprior</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"262.29\" cy=\"-409.34\" rx=\"71.34\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"262.29\" y=\"-420.64\" font-family=\"Times,serif\" font-size=\"14.00\">β_μ_hyperprior</text>\n",
"<text text-anchor=\"middle\" x=\"262.29\" y=\"-405.64\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"262.29\" y=\"-390.64\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- group_beta_prediction -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>group_beta_prediction</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"98.29\" cy=\"-298.38\" rx=\"98.08\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"98.29\" y=\"-309.68\" font-family=\"Times,serif\" font-size=\"14.00\">group_beta_prediction</text>\n",
"<text text-anchor=\"middle\" x=\"98.29\" y=\"-294.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"98.29\" y=\"-279.68\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- β_μ_hyperprior&#45;&gt;group_beta_prediction -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>β_μ_hyperprior&#45;&gt;group_beta_prediction</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M218.77,-379.43C199.17,-366.4 175.78,-350.86 154.84,-336.95\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"156.61,-333.93 146.34,-331.31 152.74,-339.76 156.61,-333.93\"/>\n",
"</g>\n",
"<!-- β -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>β</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"440.79,-324.88 349.79,-324.88 349.79,-271.88 440.79,-271.88 440.79,-324.88\"/>\n",
"<text text-anchor=\"middle\" x=\"395.29\" y=\"-309.68\" font-family=\"Times,serif\" font-size=\"14.00\">β</text>\n",
"<text text-anchor=\"middle\" x=\"395.29\" y=\"-294.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"395.29\" y=\"-279.68\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n",
"</g>\n",
"<!-- β_μ_hyperprior&#45;&gt;β -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>β_μ_hyperprior&#45;&gt;β</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M300.04,-377.41C317.49,-363.11 338.21,-346.14 355.86,-331.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"358.21,-334.28 363.73,-325.23 353.78,-328.86 358.21,-334.28\"/>\n",
"</g>\n",
"<!-- σ -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>σ</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"273.29\" cy=\"-298.38\" rx=\"58.88\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"273.29\" y=\"-309.68\" font-family=\"Times,serif\" font-size=\"14.00\">σ</text>\n",
"<text text-anchor=\"middle\" x=\"273.29\" y=\"-294.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"273.29\" y=\"-279.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- group_prediction -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>group_prediction</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"114.29\" cy=\"-187.43\" rx=\"77.56\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"114.29\" y=\"-198.73\" font-family=\"Times,serif\" font-size=\"14.00\">group_prediction</text>\n",
"<text text-anchor=\"middle\" x=\"114.29\" y=\"-183.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"114.29\" y=\"-168.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- σ&#45;&gt;group_prediction -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>σ&#45;&gt;group_prediction</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M239.59,-267.46C228.92,-258.52 216.88,-248.95 205.29,-240.88 195.47,-234.05 184.7,-227.27 174.12,-220.97\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"175.7,-217.83 165.3,-215.8 172.16,-223.87 175.7,-217.83\"/>\n",
"</g>\n",
"<!-- location_4_predictions -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>location_4_predictions</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"379.29\" cy=\"-187.43\" rx=\"98.99\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"379.29\" y=\"-198.73\" font-family=\"Times,serif\" font-size=\"14.00\">location_4_predictions</text>\n",
"<text text-anchor=\"middle\" x=\"379.29\" y=\"-183.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"379.29\" y=\"-168.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- σ&#45;&gt;location_4_predictions -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>σ&#45;&gt;location_4_predictions</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M303.66,-266.16C314.5,-255.02 326.9,-242.28 338.48,-230.38\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"341.09,-232.71 345.56,-223.1 336.08,-227.83 341.09,-232.71\"/>\n",
"</g>\n",
"<!-- sales -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>sales</title>\n",
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"539.29\" cy=\"-76.48\" rx=\"41.94\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"539.29\" y=\"-87.78\" font-family=\"Times,serif\" font-size=\"14.00\">sales</text>\n",
"<text text-anchor=\"middle\" x=\"539.29\" y=\"-72.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"539.29\" y=\"-57.78\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- σ&#45;&gt;sales -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>σ&#45;&gt;sales</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M261.42,-261.67C253.03,-229.17 247.09,-181.56 271.29,-149.95 297.47,-115.75 417.7,-93.98 487.91,-83.95\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"488.47,-87.41 497.89,-82.56 487.5,-80.48 488.47,-87.41\"/>\n",
"</g>\n",
"<!-- group_beta_prediction&#45;&gt;group_prediction -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>group_beta_prediction&#45;&gt;group_prediction</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M103.66,-260.8C104.87,-252.54 106.18,-243.65 107.44,-235.04\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"110.91,-235.5 108.9,-225.1 103.99,-234.49 110.91,-235.5\"/>\n",
"</g>\n",
"<!-- β_σ_hyperprior -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>β_σ_hyperprior</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"99.29\" cy=\"-409.34\" rx=\"72.25\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"99.29\" y=\"-420.64\" font-family=\"Times,serif\" font-size=\"14.00\">β_σ_hyperprior</text>\n",
"<text text-anchor=\"middle\" x=\"99.29\" y=\"-405.64\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"99.29\" y=\"-390.64\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- β_σ_hyperprior&#45;&gt;group_beta_prediction -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>β_σ_hyperprior&#45;&gt;group_beta_prediction</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M98.95,-371.75C98.88,-363.58 98.8,-354.8 98.72,-346.27\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"102.22,-346.02 98.62,-336.06 95.22,-346.09 102.22,-346.02\"/>\n",
"</g>\n",
"<!-- β_σ_hyperprior&#45;&gt;β -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>β_σ_hyperprior&#45;&gt;β</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M151.83,-383.45C161.81,-379.23 172.27,-375.16 182.29,-371.86 251.1,-349.18 274.79,-364.63 341.29,-335.86 345.14,-334.19 349.01,-332.24 352.82,-330.12\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"354.73,-333.06 361.53,-324.94 351.15,-327.05 354.73,-333.06\"/>\n",
"</g>\n",
"<!-- β&#45;&gt;location_4_predictions -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>β&#45;&gt;location_4_predictions</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M391.54,-271.84C389.92,-260.82 387.97,-247.58 386.12,-234.92\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"389.57,-234.34 384.65,-224.95 382.64,-235.36 389.57,-234.34\"/>\n",
"</g>\n",
"<!-- μ -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>μ</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"586.79,-213.93 495.79,-213.93 495.79,-160.93 586.79,-160.93 586.79,-213.93\"/>\n",
"<text text-anchor=\"middle\" x=\"541.29\" y=\"-198.73\" font-family=\"Times,serif\" font-size=\"14.00\">μ</text>\n",
"<text text-anchor=\"middle\" x=\"541.29\" y=\"-183.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"541.29\" y=\"-168.73\" font-family=\"Times,serif\" font-size=\"14.00\">Deterministic</text>\n",
"</g>\n",
"<!-- β&#45;&gt;μ -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>β&#45;&gt;μ</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M429.52,-271.84C450.22,-256.39 476.78,-236.57 498.82,-220.12\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"501.1,-222.79 507.02,-214 496.91,-217.18 501.1,-222.79\"/>\n",
"</g>\n",
"<!-- β_offset -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>β_offset</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"395.29\" cy=\"-409.34\" rx=\"43.68\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"395.29\" y=\"-420.64\" font-family=\"Times,serif\" font-size=\"14.00\">β_offset</text>\n",
"<text text-anchor=\"middle\" x=\"395.29\" y=\"-405.64\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"395.29\" y=\"-390.64\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- β_offset&#45;&gt;β -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>β_offset&#45;&gt;β</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M395.29,-371.75C395.29,-360.02 395.29,-347.03 395.29,-335.34\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"398.79,-335.04 395.29,-325.04 391.79,-335.04 398.79,-335.04\"/>\n",
"</g>\n",
"<!-- μ&#45;&gt;sales -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>μ&#45;&gt;sales</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M540.82,-160.89C540.62,-149.98 540.38,-136.89 540.15,-124.35\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"543.64,-123.94 539.96,-114 536.64,-124.06 543.64,-123.94\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x7f9ebd6f7b80>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pm.model_to_graphviz(model_hierarchical_salad_sales_extra_nodes)"
]
},
{
"cell_type": "markdown",
"id": "flying-volunteer",
"metadata": {},
"source": [
"# The actual right way to do it\n",
"I know this is the right way, just asking about the other 2"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "disturbed-works",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/sampling.py:465: FutureWarning: In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n",
" warnings.warn(\n",
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [β_offset, β_σ_hyperprior, β_μ_hyperprior, σ]\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='8000' class='' max='8000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [8000/8000 00:03<00:00 Sampling 4 chains, 81 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"/home/canyon/miniconda3/envs/cargo/lib/python3.9/site-packages/pymc3/math.py:246: RuntimeWarning: divide by zero encountered in log1p\n",
" return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))\n",
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.\n",
"There were 17 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 10 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 27 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 27 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"The number of effective samples is smaller than 25% for some parameters.\n"
]
}
],
"source": [
"# If I used shared vars I get a shape error\n",
"customers = hierarchical_salad_df.loc[:, \"customers\"].values\n",
"customers = theano.shared(customers)\n",
"\n",
"with pm.Model() as model_hierarchical_salad_sales_extra_nodes:\n",
" \n",
" σ = pm.HalfNormal(\"σ\", 20)\n",
" \n",
" β_μ_hyperprior = pm.Normal(\"β_μ_hyperprior\", 10, 10)\n",
" β_σ_hyperprior = pm.HalfNormal(\"β_σ_hyperprior\", 10)\n",
" β_offset = pm.Normal('β_offset', mu=0, sd=1, shape=6)\n",
" \n",
" β = pm.Deterministic(\"β\", β_μ_hyperprior + β_offset * β_σ_hyperprior)\n",
" \n",
" μ = pm.Deterministic('μ', β[location_category.codes] * customers)\n",
" \n",
" sales = pm.Normal(\"sales\", mu=μ, sd=σ, observed=hierarchical_salad_df.sales)\n",
" trace_hierarchical_salad_sales_noncentered = pm.sample(random_seed=0)\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "impaired-bleeding",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='4000' class='' max='4000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [4000/4000 00:00<00:00]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"customers.set_value([50])\n",
"with pm.Model() as model_hierarchical_salad_sales_extra_nodes:\n",
" ppc = pm.sample_posterior_predictive(trace_hierarchical_salad_sales_noncentered)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment