Skip to content

Instantly share code, notes, and snippets.

@mamacneil
Created January 28, 2021 00:53
Show Gist options
  • Save mamacneil/676351ea5ada3283c3b9aef31124daeb to your computer and use it in GitHub Desktop.
Save mamacneil/676351ea5ada3283c3b9aef31124daeb to your computer and use it in GitHub Desktop.
Dirichlet-multinomial model implementation from Isaac Slavitt
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Isaac Slavitt's compositional example\n",
"\n",
"Compositional data analysis example from Isaac Slavitt https://www.isaacslavitt.com/posts/dirichlet-multinomial-for-skittle-proportions/"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import python packages\n",
"%matplotlib inline\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pymc3 as pm\n",
"import theano as T\n",
"import theano.tensor as tt\n",
"from pymc3.backends import SQLite\n",
"import seaborn as sns\n",
"import scipy as sp\n",
"import pdb\n",
"\n",
"# Helper functions\n",
"def indexall(L):\n",
"    \"\"\"Encode the elements of L as integer codes.\n",
"\n",
"    Returns (levels, codes): levels lists the distinct values of L in order of\n",
"    first appearance; codes is a numpy array holding, for each element of L,\n",
"    its position in levels. Membership uses list comparison (==), so values\n",
"    need not be hashable. L is iterated twice, so pass a sequence, not a\n",
"    one-shot iterator.\n",
"    \"\"\"\n",
"    levels = []\n",
"    for item in L:\n",
"        if item in levels:\n",
"            continue\n",
"        levels.append(item)\n",
"    codes = np.array([levels.index(item) for item in L])\n",
"    return levels, codes\n",
"\n",
"# Variant of indexall that puts a chosen level first\n",
"def indexall_B(L, B):\n",
"    \"\"\"Like indexall, but reorder the levels so that B comes first (code 0).\n",
"\n",
"    B swaps places with whichever level was seen first, and the integer codes\n",
"    are remapped to match, so levels[codes[i]] == L[i] still holds.\n",
"    Presumably used to make B the reference/baseline category -- confirm.\n",
"    Raises ValueError if B does not occur in L.\n",
"    \"\"\"\n",
"    levels, codes = indexall(L)\n",
"    a = levels.index(B)\n",
"    levels[0], levels[a] = levels[a], levels[0]\n",
"    # Swap codes 0 <-> a in one vectorized pass (no sentinel value needed)\n",
"    codes = np.where(codes == 0, a, np.where(codes == a, 0, codes))\n",
"    return levels, codes\n",
"\n",
"def subindexall(short, long):\n",
"    \"\"\"Apply indexall to the `short` value paired with each distinct `long` value.\n",
"\n",
"    short and long are walked in parallel with zip (the shorter one sets the\n",
"    length). For every distinct value of long, in order of first appearance,\n",
"    the short value from that first pairing is kept; the result is\n",
"    indexall() of those kept values.\n",
"    \"\"\"\n",
"    seen = []\n",
"    firsts = []\n",
"    for s_val, l_val in zip(short, long):\n",
"        if l_val in seen:\n",
"            continue\n",
"        seen.append(l_val)\n",
"        firsts.append(s_val)\n",
"    return indexall(firsts)\n",
"\n",
"def match(a, b):\n",
"    \"\"\"Position of each element of a within list b (None where absent), as a numpy array.\"\"\"\n",
"    positions = []\n",
"    for x in a:\n",
"        positions.append(b.index(x) if x in b else None)\n",
"    return np.array(positions)\n",
"\n",
"def grep(s, l):\n",
"    \"\"\"Numpy array of the elements of l that contain s (tested with `s in item`).\"\"\"\n",
"    return np.array(list(filter(lambda item: s in item, l)))\n",
"\n",
"def stdize(x):\n",
"    \"\"\"Standardize a covariate: centre on its mean, then divide by TWO standard deviations.\n",
"\n",
"    The 2-SD scaling follows Gelman's (2008) convention, which puts continuous\n",
"    inputs on a scale roughly comparable to binary ones -- presumably the intent\n",
"    here. np.std is the population (ddof=0) standard deviation.\n",
"    \"\"\"\n",
"    centred = x - np.mean(x)\n",
"    return centred / (2 * np.std(x))\n",
"\n",
"def cv(x):\n",
"    \"\"\"Coefficient of variation of x: standard deviation divided by the mean.\n",
"\n",
"    Uses the population (ddof=0) standard deviation. Note that np.var(x) / np.mean(x)\n",
"    is the variance-to-mean ratio (index of dispersion), not the CV.\n",
"    \"\"\"\n",
"    return np.std(x) / np.mean(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Skittles in the bag: modeling with the Dirichlet-multinomial model\n",
"\n",
"The data for this example consist of count data from 468 packs of Skittles. For each bag, the author counted up how many of each color were present ([more here](https://possiblywrong.wordpress.com/2019/04/06/follow-up-i-found-two-identical-packs-of-skittles-among-468-packs-with-a-total-of-27740-skittles/))."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Strawberry</th>\n",
" <th>Orange</th>\n",
" <th>Lemon</th>\n",
" <th>Apple</th>\n",
" <th>Grape</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>10</td>\n",
" <td>15</td>\n",
" <td>11</td>\n",
" <td>7</td>\n",
" <td>18</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>12</td>\n",
" <td>17</td>\n",
" <td>15</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>16</td>\n",
" <td>11</td>\n",
" <td>15</td>\n",
" <td>11</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>15</td>\n",
" <td>8</td>\n",
" <td>13</td>\n",
" <td>16</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>11</td>\n",
" <td>14</td>\n",
" <td>20</td>\n",
" <td>8</td>\n",
" <td>7</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Strawberry Orange Lemon Apple Grape\n",
"0 10 15 11 7 18\n",
"1 5 12 17 15 10\n",
"2 16 11 15 11 9\n",
"3 15 8 13 16 7\n",
"4 11 14 20 8 7"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One plotting colour per flavour, in column order:\n",
"# Strawberry, Orange, Lemon, Apple, Grape\n",
"skittles_palette = [\"red\", \"orange\", \"yellow\", \"green\", \"purple\"]\n",
"\n",
"# Skittles data: one row per bag, one column of counts per flavour.\n",
"# The 'Uncounted' column is dropped when present.\n",
"sdata = (\n",
"    pd.read_csv(\"skittles.txt\", sep=\"\\t\")\n",
"    .drop(\"Uncounted\", axis=1, errors=\"ignore\")\n",
")\n",
"sdata.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(468, 5)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sdata.shape  # (number of bags, number of flavours)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that because individual Skittles vary in size, there is no guarantee of how many Skittles you get per bag:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Total number of Skittles in each bag (row sums)\n",
"counts = sdata.sum(axis=\"columns\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOGklEQVR4nO3cX4xc5X3G8e9TnCCFEBXKglzb7QJyo8JFTbSirZAiWqRAIIrJBZWRGlkVkrkwUlBTtYabcGOJViGpKhUk86ex2gTqJkFYJUpDaaQ0N5A1RYBxEC44sNi1l9IW0gsqzK8Xe9xM7Rnvn9nxMG++H2k157znPXN+r1758dl3Z06qCklSW35h3AVIklaf4S5JDTLcJalBhrskNchwl6QGrRl3AQAXXHBBTU9Pj7sMSZoo+/bte7Oqpvod+0CE+/T0NLOzs+MuQ5ImSpKfDDrmsowkNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXoA/ENVemDbHrH42O57qG7bxjLddUG79wlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktSgRcM9yYYk309yIMn+JF/o2u9K8kaSZ7uf63vOuSPJwSQvJbl2lAOQJJ1qzRL6vAd8saqeSXIusC/JE92xr1bVl3s7J7kM2AJcDvwy8I9Jfq2qjq9m4ZKkwRa9c6+qI1X1TLf9DnAAWHeaUzYDj1TVu1X1KnAQuHI1ipUkLc2y1tyTTANXAE91TbcleS7JQ0nO69rWAa/3nDZHn/8MkmxLMptkdn5+ftmFS5IGW3K4J/ko8C3g9qp6G7gPuBTYBBwB7jnRtc/pdUpD1a6qmqmqmampqWUXLkkabEnhnuRDLAT716vq2wBVdbSqjlfV+8D9/GzpZQ7Y0HP6euDw6pUsSVrMUj4tE+BB4EBVfaWnfW1Pt88BL3Tbe4EtSc5OcjGwEXh69UqWJC1mKZ+WuQr4PPB8kme7tjuBm5NsYmHJ5RBwK0BV7U+yB3iRhU/abPeTMpJ0Zi0a7lX1Q/qvo3/nNOfsBHYOUZckaQh+Q1WSGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJatCi4Z5kQ5LvJzmQZH+SL3Tt5yd5IsnL3et5PefckeRgkpeSXDvKAUiSTrWUO/f3gC9W1a8DvwVsT3IZsAN4sqo2Ak92+3THtgCXA9cB9yY5axTFS5L6WzTcq+pIVT3Tbb8DHADWAZuB3V233cCN3fZm4JGqereqXgUOAleuduGSpMGWteaeZBq4AngKuKiqjsDCfwDAhV23dcDrPafNdW0nv9e2JLNJZufn55dfuSRpoCWHe5KPAt8Cbq+qt0/XtU9bndJQtauqZqpqZmpqaqllSJKWYEnhnuRDLAT716vq213z0SRru+NrgWNd+xywoef09cDh1SlXkrQUS/m0TIAHgQNV9ZWeQ3uBrd32VuCxnvYtSc5OcjGwEXh69UqWJC1mzRL6XAV8Hng+ybNd253A3cCeJLcArwE3AVTV/iR7gBdZ+KTN9qo6vuqVS5IGWjTcq+qH9F9HB7hmwDk7gZ1D1CVJGoLfUJWkBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDFg33JA
8lOZbkhZ62u5K8keTZ7uf6nmN3JDmY5KUk146qcEnSYEu5c/8acF2f9q9W1abu5zsASS4DtgCXd+fcm+Ss1SpWkrQ0i4Z7Vf0AeGuJ77cZeKSq3q2qV4GDwJVD1CdJWoFh1txvS/Jct2xzXte2Dni9p89c13aKJNuSzCaZnZ+fH6IMSdLJVhru9wGXApuAI8A9XXv69K1+b1BVu6pqpqpmpqamVliGJKmfFYV7VR2tquNV9T5wPz9bepkDNvR0XQ8cHq5ESdJyrSjck6zt2f0ccOKTNHuBLUnOTnIxsBF4ergSJUnLtWaxDkkeBq4GLkgyB3wJuDrJJhaWXA4BtwJU1f4ke4AXgfeA7VV1fDSlS5IGWTTcq+rmPs0Pnqb/TmDnMEVJkobjN1QlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBi35DVfogmN7x+LhLkCaKd+6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUoEXDPclDSY4leaGn7fwkTyR5uXs9r+fYHUkOJnkpybWjKlySNNhS7ty/Blx3UtsO4Mmq2gg82e2T5DJgC3B5d869Sc5atWolSUuyaLhX1Q+At05q3gzs7rZ3Azf2tD9SVe9W1avAQeDKVapVkrREK11zv6iqjgB0rxd27euA13v6zXVtkqQzaLX/oJo+bdW3Y7ItyWyS2fn5+VUuQ5J+vq003I8mWQvQvR7r2ueADT391gOH+71BVe2qqpmqmpmamlphGZKkflYa7nuBrd32VuCxnvYtSc5OcjGwEXh6uBIlScu1ZrEOSR4GrgYuSDIHfAm4G9iT5BbgNeAmgKran2QP8CLwHrC9qo6PqHZJ0gCLhntV3Tzg0DUD+u8Edg5TlCRpOH5DVZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ1aM8zJSQ4B7wDHgfeqaibJ+cDfAtPAIeD3quo/hitTkrQcq3Hn/jtVtamqZrr9HcCTVbUReLLblySdQaNYltkM7O62dwM3juAakqTTGDbcC/hekn1JtnVtF1XVEYDu9cJ+JybZlmQ2yez8/PyQZUiSeg215g5cVVWHk1wIPJHkx0s9sap2AbsAZmZmasg6JEk9hrpzr6rD3esx4FHgSuBokrUA3euxYYuUJC3PisM9yTlJzj2xDXwKeAHYC2ztum0FHhu2SEnS8gyzLHMR8GiSE+/zjar6bpIfAXuS3AK8Btw0fJmSpOVYcbhX1SvAb/Rp/3fgmmGKkiQNx2+oSlKDDHdJapDhLkkNMtwlqUGGuyQ1aNhvqEoakekdj4/t2ofuvmFs19bq8M5dkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ3ywWFalnE+zErS0nnnLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDRpZuCe5LslLSQ4m2TGq60iSTjWScE9yFvCXwKeBy4Cbk1w2imtJkk41qmfLXAkcrKpXAJI8AmwGXhzFxcb1vJNDd98wlutKo/bz9m9qnM9MGtWYRxXu64DXe/bngN/s7ZBkG7Ct2/1pkpeGuN4FwJtDnL8i+dORX2Is4zoDWh0XtDu2MzKuM/Bv6mRjn68hx/yrgw6MKtzTp63+307VLmDXqlwsma2qmdV4rw8SxzV5Wh2b45o8o/qD6hywoWd/PXB4RNeSJJ1kVOH+I2BjkouTfBjYAuwd0bUkSScZybJMVb2X5DbgH4CzgIeqav8ortVZle
WdDyDHNXlaHZvjmjCpqsV7SZImit9QlaQGGe6S1KCJDPckZyX5lyR/3+3fleSNJM92P9ePu8blSnIoyfNd/bNd2/lJnkjycvd63rjrXIkBY2thzn4xyTeT/DjJgSS/3cKcDRhXC/P18Z76n03ydpLbW5izfiZyzT3JHwIzwMeq6jNJ7gJ+WlVfHm9lK5fkEDBTVW/2tP0Z8FZV3d09n+e8qvqTcdW4UgPGdheTP2e7gX+uqge6T4V9BLiTCZ+zAeO6nQmfr17dI1LeYOHLlduZ8DnrZ+Lu3JOsB24AHhh3LWfAZmB3t70buHGMtahHko8BnwQeBKiq/6mq/2TC5+w042rNNcC/VtVPmPA5G2Tiwh34c+CPgfdPar8tyXNJHprQX6sK+F6Sfd2jGQAuqqojAN3rhWOrbjj9xgaTPWeXAPPAX3VLhA8kOYfJn7NB44LJnq+TbQEe7rYnfc76mqhwT/IZ4FhV7Tvp0H3ApcAm4Ahwz5mubRVcVVWfYOFJmtuTfHLcBa2ifmOb9DlbA3wCuK+qrgD+G2jh0daDxjXp8/V/uqWmzwJ/N+5aRmmiwh24Cvhst4b7CPC7Sf6mqo5W1fGqeh+4n4WnUk6UqjrcvR4DHmVhDEeTrAXoXo+Nr8KV6ze2BuZsDpirqqe6/W+yEIqTPmd9x9XAfPX6NPBMVR3t9id9zvqaqHCvqjuqan1VTbPwa9U/VdXvn5iYzueAF8ZS4AolOSfJuSe2gU+xMIa9wNau21bgsfFUuHKDxjbpc1ZV/wa8nuTjXdM1LDzSeqLnbNC4Jn2+TnIzP1uSgQmfs0Em8tMyAEmuBv6o+7TMX7Pw62IBh4BbT6yhTYIkl7BwRwsLvxZ/o6p2JvklYA/wK8BrwE1V9daYylyR04xtoucMIMkmFv6w/2HgFeAPWLhhmvQ56zeuv2DC5wsgyUdYeBz5JVX1X13bxP8762diw12SNNhELctIkpbGcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkN+l9qVgmGzeo5PgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Histogram of bag sizes, labelled so the figure stands on its own\n",
"fig, ax = plt.subplots()\n",
"ax.hist(counts)\n",
"ax.set(xlabel=\"Skittles per bag\", ylabel=\"Number of bags\", title=\"Distribution of bag sizes\");"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Strawberry</th>\n",
" <th>Orange</th>\n",
" <th>Lemon</th>\n",
" <th>Apple</th>\n",
" <th>Grape</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.163934</td>\n",
" <td>0.245902</td>\n",
" <td>0.180328</td>\n",
" <td>0.114754</td>\n",
" <td>0.295082</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.084746</td>\n",
" <td>0.203390</td>\n",
" <td>0.288136</td>\n",
" <td>0.254237</td>\n",
" <td>0.169492</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.258065</td>\n",
" <td>0.177419</td>\n",
" <td>0.241935</td>\n",
" <td>0.177419</td>\n",
" <td>0.145161</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.254237</td>\n",
" <td>0.135593</td>\n",
" <td>0.220339</td>\n",
" <td>0.271186</td>\n",
" <td>0.118644</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.183333</td>\n",
" <td>0.233333</td>\n",
" <td>0.333333</td>\n",
" <td>0.133333</td>\n",
" <td>0.116667</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Strawberry Orange Lemon Apple Grape\n",
"0 0.163934 0.245902 0.180328 0.114754 0.295082\n",
"1 0.084746 0.203390 0.288136 0.254237 0.169492\n",
"2 0.258065 0.177419 0.241935 0.177419 0.145161\n",
"3 0.254237 0.135593 0.220339 0.271186 0.118644\n",
"4 0.183333 0.233333 0.333333 0.133333 0.116667"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Convert counts to per-bag proportions: divide each row by its bag total\n",
"normalized = sdata.div(counts, axis=\"rows\")\n",
"normalized.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference: what are the true proportions for each color?\n",
"\n",
"We can assume that the manufacturer is motivated to control the average number of Skittles per bag for cost and packing reasons. What is less clear is whether they exert any intentionality over the proportion of colors.\n",
"\n",
"One thing we might want to look into: are the proportions for each color in a bag supposed to be the same or are any flavors purposely favored or disfavored? And if the differences are intentional, what are the intended proportions for each color?\n",
"\n",
"Eyeballing the data, we might hypothesize that the machines are trying to get about 20% of each color and that fluctuations are due to machinery imprecision. For example, it is possible that it is solely due to random fluctuations that we see slightly more lemon and slightly less apple. The mean per-bag proportions give us a quick estimate of the \"true\" proportions (close to, but not exactly, the multinomial maximum likelihood estimate, which pools the counts from all bags before dividing):"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Strawberry 0.201242\n",
"Orange 0.198415\n",
"Lemon 0.204962\n",
"Apple 0.191098\n",
"Grape 0.204283\n",
"dtype: float64"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Average per-bag proportion of each flavour\n",
"normalized.mean(axis=\"rows\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Choosing the Dirichlet-multinomial\n",
"\n",
"In order to build a Bayesian model, we need to come up with a believable data generating process, which is a story of how our data came to be. Here's one for the Skittles data:\n",
"\n",
"- The factory decides on a target proportion for each color of Skittles, and these proportions sum to 1.\n",
"- The factory makes a large number of Skittles in these proportions and then mixes them all up into a big hopper.\n",
"- For each bag, the Skittles factory selects a number of Skittles to put in that bag.\n",
"- Then the factory draws Skittles from the big hopper and puts them into the bag, and it turns out that a count of each color is observed in the bag.\n",
"\n",
"This corresponds almost exactly to the Dirichlet-multinomial model, where the Dirichlet is the prior distribution over proportions and the multinomial is the distribution for counts of each group. (This is the multi-group generalization of the Beta-Binomial model, which does the same for only two groups, e.g. heads versus tails on flips of a biased coin.) In particular, having proportional amounts of Skittles in a hopper ready to dispense corresponds directly to the Pólya urn interpretation with an extremely large $K$.\n",
"\n",
"The Dirichlet gives us a distribution over simplexes (vectors that add up to one), and the multinomial gives us a distribution over an enumerated set of discrete outcomes. Together, we have a model for the proportions which we can condition on the observed data. Based on the question we're asking above, the parameter we are most interested in is the $p$ parameter of the multinomial which is sampled from the Dirichlet prior - or the idealized proportions for each color before they are probabilistically sampled into actual counts."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"color_names = sdata.columns\n",
"\n",
"# Flat Dirichlet prior (every concentration parameter = 1): uniform over all\n",
"# possible proportion vectors. Equal values above 1 would favour equal\n",
"# proportions; values below 1 favour more uneven ones.\n",
"alpha = np.ones(len(color_names))\n",
"\n",
"with pm.Model(coords={'Colours': color_names}) as Skittles:\n",
"    # Hopper proportion of each colour -- the quantity of interest\n",
"    p = pm.Dirichlet(\"p\", a=alpha, dims='Colours')\n",
"\n",
"    # Observed bag sizes. The 40-100 bounds are hard-coded: a bag outside\n",
"    # them would make the model log-probability -inf.\n",
"    n = pm.DiscreteUniform(\"n\", lower=40, upper=100, observed=sdata.sum(axis=1).values)\n",
"\n",
"    # Colour counts per bag: n draws from the hopper with proportions p\n",
"    k = pm.Multinomial(\"k\", n=n, p=p, observed=sdata)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Strawberry</th>\n",
" <th>Orange</th>\n",
" <th>Lemon</th>\n",
" <th>Apple</th>\n",
" <th>Grape</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>10</td>\n",
" <td>15</td>\n",
" <td>11</td>\n",
" <td>7</td>\n",
" <td>18</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>12</td>\n",
" <td>17</td>\n",
" <td>15</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>16</td>\n",
" <td>11</td>\n",
" <td>15</td>\n",
" <td>11</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>15</td>\n",
" <td>8</td>\n",
" <td>13</td>\n",
" <td>16</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>11</td>\n",
" <td>14</td>\n",
" <td>20</td>\n",
" <td>8</td>\n",
" <td>7</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Strawberry Orange Lemon Apple Grape\n",
"0 10 15 11 7 18\n",
"1 5 12 17 15 10\n",
"2 16 11 15 11 9\n",
"3 15 8 13 16 7\n",
"4 11 14 20 8 7"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sdata.iloc[:5]  # first five bags, shown again next to the model definition"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [p]\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='12000' class='' max='12000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [12000/12000 00:03<00:00 Sampling 4 chains, 0 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 15 seconds.\n"
]
}
],
"source": [
"with Skittles:\n",
"    # Fixed seed so that Restart & Run All reproduces the same posterior draws\n",
"    trace = pm.sample(2000, random_seed=2021)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/aaronmacneil/opt/anaconda3/lib/python3.8/site-packages/arviz/data/io_pymc3.py:87: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_mean</th>\n",
" <th>ess_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>p[0]</th>\n",
" <td>0.201</td>\n",
" <td>0.002</td>\n",
" <td>0.197</td>\n",
" <td>0.206</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>13470.0</td>\n",
" <td>13470.0</td>\n",
" <td>13457.0</td>\n",
" <td>6058.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>p[1]</th>\n",
" <td>0.198</td>\n",
" <td>0.002</td>\n",
" <td>0.194</td>\n",
" <td>0.203</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>12639.0</td>\n",
" <td>12627.0</td>\n",
" <td>12607.0</td>\n",
" <td>6292.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>p[2]</th>\n",
" <td>0.205</td>\n",
" <td>0.002</td>\n",
" <td>0.200</td>\n",
" <td>0.210</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>13383.0</td>\n",
" <td>13382.0</td>\n",
" <td>13360.0</td>\n",
" <td>6352.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>p[3]</th>\n",
" <td>0.191</td>\n",
" <td>0.002</td>\n",
" <td>0.187</td>\n",
" <td>0.195</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>13381.0</td>\n",
" <td>13378.0</td>\n",
" <td>13377.0</td>\n",
" <td>6703.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>p[4]</th>\n",
" <td>0.204</td>\n",
" <td>0.002</td>\n",
" <td>0.200</td>\n",
" <td>0.209</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>13719.0</td>\n",
" <td>13719.0</td>\n",
" <td>13693.0</td>\n",
" <td>6422.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_mean ess_sd \\\n",
"p[0] 0.201 0.002 0.197 0.206 0.0 0.0 13470.0 13470.0 \n",
"p[1] 0.198 0.002 0.194 0.203 0.0 0.0 12639.0 12627.0 \n",
"p[2] 0.205 0.002 0.200 0.210 0.0 0.0 13383.0 13382.0 \n",
"p[3] 0.191 0.002 0.187 0.195 0.0 0.0 13381.0 13378.0 \n",
"p[4] 0.204 0.002 0.200 0.209 0.0 0.0 13719.0 13719.0 \n",
"\n",
" ess_bulk ess_tail r_hat \n",
"p[0] 13457.0 6058.0 1.0 \n",
"p[1] 12607.0 6292.0 1.0 \n",
"p[2] 13360.0 6352.0 1.0 \n",
"p[3] 13377.0 6703.0 1.0 \n",
"p[4] 13693.0 6422.0 1.0 "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Summarize inside the model context so ArviZ receives the model (avoiding the\n",
"# from_pymc3 FutureWarning) and can use its colour coords to label p\n",
"with Skittles:\n",
"    skittles_summary = pm.summary(trace)\n",
"skittles_summary"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/aaronmacneil/opt/anaconda3/lib/python3.8/site-packages/arviz/data/io_pymc3.py:87: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.\n",
" warnings.warn(\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot inside the model context so ArviZ receives the model (avoiding the\n",
"# from_pymc3 FutureWarning) and can label rows with the colour coords\n",
"with Skittles:\n",
"    pm.plot_forest(trace)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
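"It turns out our childhood intuitions about the 'green ones' were borne out: Apple (green, `p[3]`) has the lowest estimated proportion, about 0.191 compared with roughly 0.20 for the other flavours."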
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Strawberry Orange Lemon Apple Grape Uncounted
10 15 11 7 18 0
5 12 17 15 10 0
16 11 15 11 9 0
15 8 13 16 7 0
11 14 20 8 7 1
10 11 11 17 11 0
9 17 6 17 14 0
14 12 13 11 12 0
11 11 14 10 17 0
12 6 18 11 12 1
16 7 5 22 10 0
14 9 13 16 10 0
15 7 11 11 17 0
20 2 8 22 7 0
15 11 7 11 15 0
16 13 15 8 12 0
10 13 15 11 14 0
12 10 9 8 20 0
9 17 10 15 11 0
9 11 16 12 14 0
13 10 12 17 8 0
10 14 9 13 14 0
9 12 20 13 8 0
15 11 12 15 8 0
19 8 8 14 14 0
12 14 13 15 9 0
12 9 11 13 13 0
13 9 15 13 11 0
12 10 12 16 9 0
12 9 13 12 13 0
9 9 11 17 18 0
13 9 13 10 15 0
15 10 11 11 18 0
18 11 9 20 6 0
9 10 15 11 16 0
13 10 9 12 14 0
5 22 14 10 8 0
12 13 9 10 16 0
14 9 9 11 18 0
11 18 8 8 17 1
8 13 11 15 15 0
14 13 12 10 10 0
12 9 14 8 16 0
16 15 13 3 11 0
13 12 14 9 14 0
8 13 16 11 12 0
15 9 15 7 16 1
16 10 17 12 6 0
9 12 11 10 18 0
14 8 19 7 13 0
13 12 15 8 13 0
22 10 12 9 7 0
6 15 16 13 13 0
18 14 11 6 12 0
11 13 15 10 14 1
15 10 10 9 18 0
9 11 24 6 10 0
10 11 13 8 16 0
11 12 15 8 17 0
13 14 12 11 10 1
13 12 10 8 18 0
14 15 12 4 17 0
14 14 11 10 12 0
11 5 19 12 14 0
12 9 14 6 18 0
12 9 17 6 15 0
12 15 17 9 7 0
4 16 16 9 14 0
13 12 11 4 21 0
14 14 10 11 12 0
10 17 8 10 14 0
12 12 14 7 18 0
12 15 12 11 10 0
13 9 9 17 12 0
9 12 8 16 13 0
8 10 11 13 17 0
7 16 11 15 10 0
11 12 18 11 9 0
9 14 12 12 13 0
6 11 12 12 15 1
12 9 9 13 15 0
9 15 11 15 9 0
7 13 10 13 16 0
6 15 15 13 11 0
11 10 8 16 14 0
7 16 13 8 15 0
13 10 16 17 5 0
8 12 13 11 16 1
7 18 14 8 11 0
7 15 14 16 7 0
13 12 9 13 15 0
9 13 16 10 12 0
8 13 15 9 14 0
11 12 14 11 11 0
13 15 14 9 7 0
11 19 14 9 7 0
13 20 15 5 4 0
12 8 16 13 9 1
8 15 12 15 10 0
11 12 12 12 13 0
4 16 15 15 10 0
12 8 17 15 9 0
7 13 14 13 11 1
14 14 9 11 11 0
13 17 6 14 12 0
10 9 15 16 10 0
10 13 13 15 9 0
14 5 17 13 10 0
16 12 9 9 14 0
16 11 13 10 14 0
12 11 17 11 9 1
15 8 10 12 16 0
15 10 17 10 8 1
9 13 10 19 10 0
20 13 10 11 6 0
16 16 13 9 7 0
17 11 13 7 12 0
15 8 14 13 10 0
5 12 8 16 20 0
11 10 19 8 12 0
16 13 11 12 8 0
18 14 4 14 12 0
10 14 13 15 7 0
13 11 14 11 12 0
15 13 17 6 10 0
14 6 16 11 14 0
10 9 14 15 11 1
15 12 9 12 13 0
13 11 12 12 12 1
13 6 17 7 16 0
12 13 12 11 11 0
10 17 13 7 13 0
13 11 10 13 12 0
12 12 16 12 9 0
16 10 12 10 13 0
12 11 9 10 16 0
15 13 12 13 8 0
17 13 11 10 9 0
15 8 13 10 13 0
15 8 16 8 12 0
9 13 16 9 12 0
14 10 10 14 13 0
8 11 15 10 13 0
11 16 8 15 10 1
10 11 10 16 14 0
12 9 12 15 11 0
14 11 11 14 8 0
18 8 9 10 14 0
14 10 8 14 12 0
8 11 16 10 13 1
19 10 5 11 13 1
15 9 7 14 16 0
18 12 9 12 7 0
13 13 7 13 15 0
11 9 8 13 19 0
17 8 12 10 10 0
15 9 13 12 11 0
15 10 11 12 11 0
18 8 17 9 6 0
10 16 10 9 15 0
12 7 14 15 9 1
18 9 6 13 14 0
14 11 7 7 19 1
13 11 16 10 9 0
13 10 8 14 13 1
15 11 18 10 7 0
18 9 8 14 11 0
15 8 13 10 12 0
18 13 7 9 12 0
13 16 5 13 11 1
12 13 14 7 12 0
10 9 9 18 12 1
12 8 15 15 9 0
14 7 12 6 21 0
17 8 11 11 11 0
17 8 8 12 15 0
15 4 13 13 15 0
15 15 5 12 13 0
12 7 13 20 9 0
16 9 10 10 14 0
10 13 7 11 15 0
11 11 10 12 14 0
14 18 10 7 7 0
16 13 10 9 11 0
12 11 11 11 13 0
10 11 13 14 10 0
8 14 9 14 13 0
14 13 8 9 12 0
17 15 9 9 7 0
13 15 13 7 11 0
12 12 8 15 11 0
12 14 10 13 9 1
10 16 16 6 12 0
13 8 12 8 15 0
14 16 9 15 6 0
14 12 10 15 7 0
9 12 13 6 17 0
11 11 12 12 11 0
14 9 9 10 16 0
8 18 13 10 8 0
8 16 14 13 6 0
16 10 10 10 12 0
19 11 11 5 13 1
15 7 14 6 17 0
11 15 7 13 12 0
10 15 13 11 11 0
16 14 7 10 10 0
16 12 12 3 17 0
14 14 11 9 11 1
13 17 20 6 5 0
9 17 10 10 13 0
8 14 9 9 17 1
6 20 14 9 9 0
11 14 15 9 10 0
12 14 13 10 11 0
17 7 14 11 9 0
15 14 8 7 13 0
15 14 15 7 7 0
15 9 10 9 14 0
12 15 10 8 14 0
11 5 10 13 17 0
15 13 10 12 8 0
15 11 17 8 5 0
18 11 10 9 9 1
11 17 11 9 10 0
15 11 12 10 11 0
8 14 12 10 13 0
12 11 10 17 10 1
13 13 13 9 11 0
14 11 6 14 11 0
12 20 7 9 11 0
8 13 15 10 12 0
14 11 7 11 14 0
19 9 12 7 13 0
12 15 14 7 11 0
17 14 12 9 5 0
13 14 12 11 8 0
9 12 12 8 15 1
9 12 9 7 22 0
13 16 11 8 10 0
13 14 8 7 15 1
8 17 15 10 10 0
10 18 10 12 10 0
14 10 11 12 11 0
9 14 11 10 14 0
12 11 17 8 12 0
10 15 8 7 18 1
12 13 17 9 9 0
15 5 12 16 12 0
14 10 18 8 11 0
12 11 9 9 19 0
14 17 10 11 7 0
9 8 12 11 18 0
12 8 13 10 15 0
10 11 11 13 11 1
12 12 8 12 12 0
13 9 15 13 10 0
15 12 9 9 13 0
17 15 11 6 14 1
13 10 13 12 6 1
12 14 12 9 12 0
13 7 15 15 11 0
8 7 13 14 17 0
11 16 7 15 11 0
11 9 13 9 16 0
8 17 11 13 9 0
15 9 15 8 14 0
10 9 14 10 15 1
11 10 18 10 10 0
7 14 12 16 10 0
8 16 13 9 13 0
7 8 15 13 14 0
9 17 10 10 14 0
14 5 14 11 15 0
13 13 9 11 13 0
16 11 17 9 6 0
7 11 12 14 16 0
14 5 15 12 12 0
18 10 17 8 6 0
12 12 14 10 10 1
14 9 7 14 16 0
15 10 7 14 15 0
12 15 11 10 11 0
13 7 16 12 10 0
12 12 10 12 14 0
14 5 13 8 19 1
12 10 10 14 12 0
8 12 11 16 15 0
14 8 17 11 9 0
14 10 8 15 11 0
10 6 7 13 9 1
14 10 18 19 12 1
13 11 13 11 8 1
14 14 15 11 3 0
9 13 8 16 13 0
13 14 14 12 8 0
8 17 6 15 13 0
11 10 17 14 7 1
13 8 12 11 15 0
12 15 5 16 11 0
17 15 9 8 11 0
12 14 11 12 11 0
11 9 13 17 9 0
9 12 13 13 10 1
9 14 10 15 11 0
16 10 10 11 11 0
13 9 12 14 13 0
13 9 12 12 14 0
9 10 12 22 6 0
18 9 12 11 9 0
13 10 7 17 15 0
13 8 15 14 13 0
10 10 11 17 13 1
16 7 11 14 12 0
13 14 15 10 7 0
8 9 10 20 12 0
18 4 14 14 9 0
12 19 9 9 12 0
12 8 8 15 15 0
17 6 10 18 8 0
7 11 17 17 10 0
7 12 16 17 8 0
12 9 17 14 9 0
12 9 9 16 14 0
13 6 13 14 14 0
17 8 11 12 13 0
8 11 12 14 12 0
7 9 16 11 14 0
6 12 16 15 12 0
8 11 10 16 14 0
10 13 16 11 8 0
13 18 11 8 8 0
9 12 11 11 15 0
11 11 12 13 11 0
10 9 10 16 13 1
8 13 21 8 8 0
10 9 15 11 14 0
8 16 14 13 7 0
14 11 13 8 12 0
8 10 22 8 12 1
5 17 14 11 12 0
13 15 11 12 7 0
11 7 10 16 15 0
12 9 12 14 12 1
8 11 16 9 15 0
10 15 13 5 14 0
12 12 12 13 10 0
15 14 10 11 9 0
12 13 11 9 13 1
8 11 11 17 12 0
11 11 14 11 10 0
8 11 16 13 13 0
15 8 14 13 9 0
12 8 10 13 17 0
15 7 12 10 16 0
8 4 14 15 19 0
15 12 12 9 12 0
10 13 13 11 12 0
14 12 13 5 17 0
7 13 10 13 14 0
8 11 12 12 16 1
13 8 15 11 13 0
7 12 12 11 19 0
18 11 14 9 8 0
17 7 14 12 10 0
9 16 12 11 11 0
10 11 16 14 8 0
11 15 13 12 10 0
13 13 8 11 17 0
3 12 19 16 10 0
6 6 17 18 13 0
9 16 11 10 14 0
9 8 17 14 13 0
13 14 11 8 13 0
12 10 11 16 9 0
13 17 10 11 8 1
13 13 11 10 12 0
8 6 22 13 10 0
14 10 14 8 14 1
14 13 10 8 14 0
10 11 20 10 9 0
10 10 19 12 8 0
11 15 11 12 10 1
8 13 14 14 13 0
9 11 17 6 15 0
12 12 11 8 15 0
10 11 13 12 12 0
10 10 15 12 11 1
7 9 12 13 19 0
12 9 13 11 14 0
10 18 10 10 11 0
8 10 15 13 13 0
6 13 19 9 14 0
7 12 11 7 21 0
6 15 12 9 15 0
13 12 10 11 14 0
16 13 11 6 11 0
14 12 12 16 6 0
7 15 15 13 9 0
8 16 11 9 15 0
9 10 15 11 15 0
10 18 9 9 15 0
10 10 12 14 13 0
9 9 11 18 11 0
13 7 10 13 15 0
10 15 11 10 13 0
11 12 15 11 10 0
14 20 8 8 11 0
12 15 15 7 11 0
10 6 11 13 18 0
8 10 12 12 18 0
11 16 13 14 5 0
13 9 14 8 15 0
13 11 11 9 16 0
12 13 20 5 10 0
14 16 12 8 5 0
11 12 11 11 12 0
19 13 11 7 11 0
16 10 8 8 16 0
8 11 10 10 16 1
10 11 10 9 21 0
16 8 13 8 16 0
8 13 6 7 24 0
6 17 17 7 10 0
9 18 10 11 10 0
5 17 15 10 9 1
12 13 12 9 13 0
17 17 9 10 8 0
8 10 12 16 11 0
14 11 9 12 13 0
11 13 16 8 10 0
11 13 10 10 15 0
16 14 10 14 5 1
10 19 13 2 15 0
11 13 10 12 12 0
9 14 9 15 12 0
11 8 15 14 12 0
12 9 12 14 13 0
6 10 16 9 17 0
10 14 8 10 17 0
10 15 9 12 12 0
7 18 14 9 11 0
19 9 10 9 12 0
12 10 10 13 13 0
10 12 13 15 8 0
15 10 12 10 12 0
7 15 11 9 16 0
10 17 10 9 12 0
13 6 13 17 8 0
10 16 15 9 10 1
13 13 7 12 14 0
11 10 13 7 17 0
13 10 2 17 15 0
14 17 13 8 8 0
10 19 9 9 10 0
13 13 9 9 14 0
9 10 14 13 10 0
17 10 10 10 11 0
9 8 17 12 13 0
14 12 15 8 11 0
15 11 10 10 14 0
11 14 5 15 13 0
10 14 14 8 12 0
11 11 12 13 11 0
17 10 8 11 12 0
9 14 12 10 15 0
12 14 11 10 10 0
11 8 12 13 15 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment