Skip to content

Instantly share code, notes, and snippets.

@dinaber
Last active March 24, 2020 20:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dinaber/0e8e12639a4a82f4a0faa11ad7774bde to your computer and use it in GitHub Desktop.
Save dinaber/0e8e12639a4a82f4a0faa11ad7774bde to your computer and use it in GitHub Desktop.
pooling_efficiency.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Worst-case calculation:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>...</th>\n",
" <th>23</th>\n",
" <th>24</th>\n",
" <th>25</th>\n",
" <th>26</th>\n",
" <th>27</th>\n",
" <th>28</th>\n",
" <th>29</th>\n",
" <th>30</th>\n",
" <th>31</th>\n",
" <th>32</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>3.0</td>\n",
" <td>3.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>7.0</td>\n",
" <td>7.0</td>\n",
" <td>7.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>1.0</td>\n",
" <td>7.0</td>\n",
" <td>11.0</td>\n",
" <td>13.0</td>\n",
" <td>15.0</td>\n",
" <td>15.0</td>\n",
" <td>15.0</td>\n",
" <td>15.0</td>\n",
" <td>15.0</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>1.0</td>\n",
" <td>9.0</td>\n",
" <td>15.0</td>\n",
" <td>19.0</td>\n",
" <td>23.0</td>\n",
" <td>25.0</td>\n",
" <td>27.0</td>\n",
" <td>29.0</td>\n",
" <td>31.0</td>\n",
" <td>31.0</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>1.0</td>\n",
" <td>11.0</td>\n",
" <td>19.0</td>\n",
" <td>25.0</td>\n",
" <td>31.0</td>\n",
" <td>35.0</td>\n",
" <td>39.0</td>\n",
" <td>43.0</td>\n",
" <td>47.0</td>\n",
" <td>49.0</td>\n",
" <td>...</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6 rows × 33 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 7 8 9 ... 23 \\\n",
"1 1.0 1.0 NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
"2 1.0 3.0 3.0 NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
"4 1.0 5.0 7.0 7.0 7.0 NaN NaN NaN NaN NaN ... NaN \n",
"8 1.0 7.0 11.0 13.0 15.0 15.0 15.0 15.0 15.0 NaN ... NaN \n",
"16 1.0 9.0 15.0 19.0 23.0 25.0 27.0 29.0 31.0 31.0 ... NaN \n",
"32 1.0 11.0 19.0 25.0 31.0 35.0 39.0 43.0 47.0 49.0 ... 63.0 \n",
"\n",
" 24 25 26 27 28 29 30 31 32 \n",
"1 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"2 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"4 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"8 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"16 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"32 63.0 63.0 63.0 63.0 63.0 63.0 63.0 63.0 63.0 \n",
"\n",
"[6 rows x 33 columns]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from scipy.special import comb\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline\n",
"# Notation used throughout this cell:\n",
"#   k - the pool degree; a pool holds 2**k samples.\n",
"#   i - number of infected people in the pool.\n",
"#   G[k, i] - worst-case cost (number of tests) for a pool of 2**k samples with i infected.\n",
"\n",
"MAX_K = 5\n",
"MAX_S = 2 ** MAX_K\n",
"# Row k is the pool of size 2**k; column i is the number of infected people in it.\n",
"# Infeasible cells (i > 2**k) stay NaN. np.nan is used because the np.NaN alias was removed in NumPy 2.0.\n",
"G = np.full((MAX_K + 1, MAX_S + 1), fill_value=np.nan)\n",
"G[:, 0] = 1  # when there are no infected people, exactly 1 test is needed\n",
"G[0, :2] = 1  # with only one person in the group, 0 or 1 infected is settled by a single test\n",
"\n",
"for k in range(1, MAX_K + 1):  # start from the smallest non-trivial groups (2 people)\n",
"    s = 2 ** k\n",
"    half_s = s // 2\n",
"    for i in range(1, s + 1):  # every non-trivial number of infected people\n",
"        # j infected land in one half and i - j in the other; both counts must fit in a half.\n",
"        min_j = max(0, i - half_s)\n",
"        max_j = min(i, half_s)\n",
"        # One test for the whole pool, plus the costliest split between the two halves.\n",
"        G[k, i] = 1 + max(G[k - 1, j] + G[k - 1, i - j] for j in range(min_j, max_j + 1))\n",
"\n",
"\n",
"def prob(k, i, p):\n",
"    \"\"\"Binomial probability that exactly i of the 2**k pooled samples are infected.\n",
"\n",
"    k is the pool degree (pool size 2**k), i the number of infected samples and\n",
"    p the probability that a single sample is infected.\n",
"    \"\"\"\n",
"    pool_size = 2 ** k\n",
"    n_choices = comb(pool_size, i)\n",
"    return n_choices * (p ** i) * (1 - p) ** (pool_size - i)\n",
"\n",
"\n",
"def exp(k, p):\n",
"    \"\"\"Expected cost of one pool of 2**k samples when each sample is infected with probability p.\n",
"\n",
"    Weights every cost-table entry G[k, i] by the binomial probability of i infected samples.\n",
"    G is looked up at call time, so the result follows whatever table G currently holds.\n",
"    \"\"\"\n",
"    expected_cost = 0\n",
"    for i in range(2 ** k + 1):\n",
"        expected_cost += prob(k, i, p) * G[k, i]\n",
"    return expected_cost\n",
"\n",
"# For the \"test every positive pool\" strategy:\n",
"# the chance that at least one member is infected, times the pool size, plus the pooled test itself.\n",
"def all_exp(k, p):\n",
"    \"\"\"Expected number of tests when every member of a positive pool of 2**k samples is retested.\n",
"\n",
"    p is the probability that a single sample is infected.\n",
"    \"\"\"\n",
"    pool_size = 2 ** k\n",
"    p_pool_positive = 1 - (1 - p) ** pool_size  # at least one member infected\n",
"    return 1 + (p_pool_positive * pool_size)\n",
"\n",
"\n",
"# Label rows by pool size (2**k), derived from MAX_K so the labels always match G's row count.\n",
"pd.DataFrame(G, index=[2 ** k for k in range(MAX_K + 1)])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Infection probabilities to evaluate.\n",
"ps = [0.001, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2, 0.3, 0.5, 1]\n",
"# E[k, ix]: expected cost of screening 2**MAX_K samples with pools of size 2**k at probability ps[ix],\n",
"# i.e. the per-pool expectation exp(k, p) times the number of pools.\n",
"E = np.zeros([MAX_K + 1, len(ps)])\n",
"for k in range(MAX_K + 1):\n",
"    n_pools = 2 ** (MAX_K - k)\n",
"    for ix, p in enumerate(ps):\n",
"        E[k, ix] = exp(k, p) * n_pools\n",
"E_df = pd.DataFrame(E, index=[2 ** k for k in range(MAX_K + 1)], columns=ps)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0.001</th>\n",
" <th>0.01</th>\n",
" <th>0.02</th>\n",
" <th>0.05</th>\n",
" <th>0.1</th>\n",
" <th>0.15</th>\n",
" <th>0.2</th>\n",
" <th>0.3</th>\n",
" <th>0.5</th>\n",
" <th>1.0</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>16.063968</td>\n",
" <td>16.636800</td>\n",
" <td>17.267200</td>\n",
" <td>19.120000</td>\n",
" <td>22.080000</td>\n",
" <td>24.880000</td>\n",
" <td>27.520000</td>\n",
" <td>32.320000</td>\n",
" <td>40.000000</td>\n",
" <td>48.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>8.127904</td>\n",
" <td>9.270400</td>\n",
" <td>10.521603</td>\n",
" <td>14.160100</td>\n",
" <td>19.841600</td>\n",
" <td>25.048100</td>\n",
" <td>29.785600</td>\n",
" <td>37.889600</td>\n",
" <td>49.000000</td>\n",
" <td>56.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>4.191776</td>\n",
" <td>5.897605</td>\n",
" <td>7.750485</td>\n",
" <td>13.043104</td>\n",
" <td>21.004040</td>\n",
" <td>27.957821</td>\n",
" <td>33.995315</td>\n",
" <td>43.670319</td>\n",
" <td>55.281250</td>\n",
" <td>60.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>2.255520</td>\n",
" <td>4.512068</td>\n",
" <td>6.929011</td>\n",
" <td>13.632170</td>\n",
" <td>23.175155</td>\n",
" <td>31.031972</td>\n",
" <td>37.529379</td>\n",
" <td>47.431678</td>\n",
" <td>58.804749</td>\n",
" <td>62.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>1.319008</td>\n",
" <td>4.101410</td>\n",
" <td>7.011544</td>\n",
" <td>14.733220</td>\n",
" <td>25.025670</td>\n",
" <td>33.104152</td>\n",
" <td>39.655088</td>\n",
" <td>49.367411</td>\n",
" <td>60.758026</td>\n",
" <td>63.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0.001 0.010 0.020 0.050 0.100 0.150 \\\n",
"1 32.000000 32.000000 32.000000 32.000000 32.000000 32.000000 \n",
"2 16.063968 16.636800 17.267200 19.120000 22.080000 24.880000 \n",
"4 8.127904 9.270400 10.521603 14.160100 19.841600 25.048100 \n",
"8 4.191776 5.897605 7.750485 13.043104 21.004040 27.957821 \n",
"16 2.255520 4.512068 6.929011 13.632170 23.175155 31.031972 \n",
"32 1.319008 4.101410 7.011544 14.733220 25.025670 33.104152 \n",
"\n",
" 0.200 0.300 0.500 1.000 \n",
"1 32.000000 32.000000 32.000000 32.0 \n",
"2 27.520000 32.320000 40.000000 48.0 \n",
"4 29.785600 37.889600 49.000000 56.0 \n",
"8 33.995315 43.670319 55.281250 60.0 \n",
"16 37.529379 47.431678 58.804749 62.0 \n",
"32 39.655088 49.367411 60.758026 63.0 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Render E_df: rows are pool sizes, columns are infection probabilities p.\n",
"display(E_df)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# For pool size 32 also show the \"test every positive pool\" strategy:\n",
"k = 5\n",
"test_all = np.zeros([1, len(ps)])\n",
"for ix, p in enumerate(ps):\n",
"    test_all[0, ix] = all_exp(k, p) * 2 ** (5 - k)\n",
"test_all_df = pd.DataFrame(test_all, index=['test_all'], columns=ps)\n",
"# Drop any 'test_all' row left by an earlier run so re-executing this cell does not append duplicates.\n",
"E_df = pd.concat([E_df.drop(index='test_all', errors='ignore'), test_all_df])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x227756189e8>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x864 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Transpose so that each infection probability becomes one group of bars, one bar per row of E_df.\n",
"tE_df = E_df.transpose()\n",
"ax = tE_df.plot(kind=\"bar\", figsize=(15,12), fontsize=22, colormap='tab20')\n",
"ax.set_xlabel(\"p - probability of infection\", fontsize=22)\n",
"ax.set_ylabel(\"E(p,s) for 32 samples\", fontsize=22)\n",
"\n",
"bars = ax.patches\n",
"# One hatch character per bar: len(tE_df) blanks followed by len(tE_df) 'x' characters.\n",
"hatches = ''.join(h*len(tE_df) for h in ' x')\n",
"\n",
"# NOTE(review): zip stops at the shorter sequence, so if there are more bars than hatch characters\n",
"# the remaining bars are left untouched - confirm which series are meant to be hatched.\n",
"for bar, hatch in zip(bars, hatches):\n",
" bar.set_hatch(hatch)\n",
"\n",
"ax.legend(title=\"pool size\", prop=dict(size=22), title_fontsize='22')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Average case calculation:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>...</th>\n",
" <th>23</th>\n",
" <th>24</th>\n",
" <th>25</th>\n",
" <th>26</th>\n",
" <th>27</th>\n",
" <th>28</th>\n",
" <th>29</th>\n",
" <th>30</th>\n",
" <th>31</th>\n",
" <th>32</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>3.0</td>\n",
" <td>3.000000</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>6.333333</td>\n",
" <td>7.000000</td>\n",
" <td>7.000000</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>1.0</td>\n",
" <td>7.0</td>\n",
" <td>9.857143</td>\n",
" <td>11.857143</td>\n",
" <td>13.228571</td>\n",
" <td>14.142857</td>\n",
" <td>14.714286</td>\n",
" <td>15.000000</td>\n",
" <td>15.000000</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>1.0</td>\n",
" <td>9.0</td>\n",
" <td>13.533333</td>\n",
" <td>17.057143</td>\n",
" <td>19.870330</td>\n",
" <td>22.164835</td>\n",
" <td>24.062937</td>\n",
" <td>25.643357</td>\n",
" <td>26.958664</td>\n",
" <td>28.046154</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>1.0</td>\n",
" <td>11.0</td>\n",
" <td>17.322581</td>\n",
" <td>22.522581</td>\n",
" <td>26.936374</td>\n",
" <td>30.770857</td>\n",
" <td>34.156841</td>\n",
" <td>37.180466</td>\n",
" <td>39.901350</td>\n",
" <td>42.362932</td>\n",
" <td>...</td>\n",
" <td>60.62135</td>\n",
" <td>61.162402</td>\n",
" <td>61.629588</td>\n",
" <td>62.025584</td>\n",
" <td>62.352614</td>\n",
" <td>62.612458</td>\n",
" <td>62.806452</td>\n",
" <td>62.935484</td>\n",
" <td>63.0</td>\n",
" <td>63.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6 rows × 33 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 \\\n",
"1 1.0 1.0 NaN NaN NaN NaN NaN \n",
"2 1.0 3.0 3.000000 NaN NaN NaN NaN \n",
"4 1.0 5.0 6.333333 7.000000 7.000000 NaN NaN \n",
"8 1.0 7.0 9.857143 11.857143 13.228571 14.142857 14.714286 \n",
"16 1.0 9.0 13.533333 17.057143 19.870330 22.164835 24.062937 \n",
"32 1.0 11.0 17.322581 22.522581 26.936374 30.770857 34.156841 \n",
"\n",
" 7 8 9 ... 23 24 25 \\\n",
"1 NaN NaN NaN ... NaN NaN NaN \n",
"2 NaN NaN NaN ... NaN NaN NaN \n",
"4 NaN NaN NaN ... NaN NaN NaN \n",
"8 15.000000 15.000000 NaN ... NaN NaN NaN \n",
"16 25.643357 26.958664 28.046154 ... NaN NaN NaN \n",
"32 37.180466 39.901350 42.362932 ... 60.62135 61.162402 61.629588 \n",
"\n",
" 26 27 28 29 30 31 32 \n",
"1 NaN NaN NaN NaN NaN NaN NaN \n",
"2 NaN NaN NaN NaN NaN NaN NaN \n",
"4 NaN NaN NaN NaN NaN NaN NaN \n",
"8 NaN NaN NaN NaN NaN NaN NaN \n",
"16 NaN NaN NaN NaN NaN NaN NaN \n",
"32 62.025584 62.352614 62.612458 62.806452 62.935484 63.0 63.0 \n",
"\n",
"[6 rows x 33 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from scipy.special import comb\n",
"from matplotlib import pyplot as plt\n",
"from math import factorial\n",
"%matplotlib inline\n",
"# Notation used throughout this cell:\n",
"#   k - the group degree; a group holds 2**k samples.\n",
"#   i - number of sick people in the group.\n",
"#   G[k, i] - average-case cost (number of tests) for a group of 2**k samples with i sick.\n",
"# NOTE: this rebinds the global G that exp() from the worst-case section also reads at call time.\n",
"\n",
"MAX_K = 5\n",
"MAX_S = 2 ** MAX_K\n",
"# Row k is the group of size 2**k; column i is the number of sick people in it.\n",
"# Infeasible cells (i > 2**k) stay NaN. np.nan is used because the np.NaN alias was removed in NumPy 2.0.\n",
"G = np.full((MAX_K + 1, MAX_S + 1), fill_value=np.nan)\n",
"G[:, 0] = 1  # when there are no sick people, exactly 1 test is needed\n",
"G[0, :2] = 1  # with only one person in the group, 0 or 1 sick is settled by a single test\n",
"\n",
"def P(k, j, i):\n",
"    \"\"\"Probability that j of the i sick people fall in the left of two equal groups.\n",
"\n",
"    There are 2 groups of 2**k people each and i sick people in total. Assuming every\n",
"    placement is equally likely, the count in the left group is hypergeometric:\n",
"    C(s, j) * C(s, i - j) / C(2s, i) with s = 2**k.\n",
"\n",
"    Returns 0 when j is not a feasible count for the left group.\n",
"    \"\"\"\n",
"    s = 2 ** k\n",
"    N = 2 * s\n",
"    if j > min(s, i) or j < max(0, i - s):\n",
"        return 0\n",
"    # Exact integer binomials avoid the OverflowError that multiplying a float by the huge\n",
"    # factorial products raises for large groups; only the final division yields a float.\n",
"    return comb(s, j, exact=True) * comb(s, i - j, exact=True) / comb(N, i, exact=True)\n",
" \n",
"for k in range(1, MAX_K + 1):  # start from the smallest non-trivial groups (2 people)\n",
"    s = 2 ** k\n",
"    half_s = s // 2\n",
"    for i in range(1, s + 1):  # every non-trivial number of sick people\n",
"        # j sick people land in the left half and i - j in the right; both counts must fit in a half.\n",
"        min_j = max(0, i - half_s)\n",
"        max_j = min(i, half_s)\n",
"        # One test for the whole group, plus the cost of both halves averaged over\n",
"        # how the i sick people split between them.\n",
"        G[k, i] = 1 + sum(\n",
"            P(k - 1, j, i) * (G[k - 1, j] + G[k - 1, i - j])\n",
"            for j in range(min_j, max_j + 1)\n",
"        )\n",
"\n",
"\n",
"def prob(k, i, p):\n",
"    \"\"\"Binomial probability of exactly i sick people among 2**k, each sick with probability p.\"\"\"\n",
"    group_size = 2 ** k\n",
"    n_choices = comb(group_size, i)\n",
"    return n_choices * (p ** i) * (1 - p) ** (group_size - i)\n",
"\n",
"\n",
"def avg_exp(k, p):\n",
"    \"\"\"Expected cost of one group of 2**k samples when each sample is sick with probability p.\n",
"\n",
"    Weights every cost-table entry G[k, i] by the binomial probability of i sick samples.\n",
"    \"\"\"\n",
"    expected_cost = 0\n",
"    for i in range(2 ** k + 1):\n",
"        expected_cost += prob(k, i, p) * G[k, i]\n",
"    return expected_cost\n",
"\n",
"# Label rows by group size (2**k), derived from MAX_K so the labels always match G's row count.\n",
"pd.DataFrame(G, index=[2 ** k for k in range(MAX_K + 1)])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Infection probabilities to evaluate.\n",
"ps = [0.001, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2, 0.3, 0.5, 1]\n",
"# E[k, ix]: expected cost of screening 2**MAX_K samples with groups of size 2**k at probability ps[ix],\n",
"# i.e. the per-group expectation avg_exp(k, p) times the number of groups.\n",
"E = np.zeros([MAX_K + 1, len(ps)])\n",
"for k in range(MAX_K + 1):\n",
"    n_groups = 2 ** (MAX_K - k)\n",
"    for ix, p in enumerate(ps):\n",
"        E[k, ix] = avg_exp(k, p) * n_groups\n",
"avg_E_df = pd.DataFrame(E, index=[2 ** k for k in range(MAX_K + 1)], columns=ps)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0.001</th>\n",
" <th>0.01</th>\n",
" <th>0.02</th>\n",
" <th>0.05</th>\n",
" <th>0.1</th>\n",
" <th>0.15</th>\n",
" <th>0.2</th>\n",
" <th>0.3</th>\n",
" <th>0.5</th>\n",
" <th>1.0</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.000000</td>\n",
" <td>32.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>16.063968</td>\n",
" <td>16.636800</td>\n",
" <td>17.267200</td>\n",
" <td>19.120000</td>\n",
" <td>22.080000</td>\n",
" <td>24.880000</td>\n",
" <td>27.520000</td>\n",
" <td>32.320000</td>\n",
" <td>40.000000</td>\n",
" <td>48.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>8.127872</td>\n",
" <td>9.267264</td>\n",
" <td>10.509309</td>\n",
" <td>14.087900</td>\n",
" <td>19.582400</td>\n",
" <td>24.527900</td>\n",
" <td>28.966400</td>\n",
" <td>36.478400</td>\n",
" <td>47.000000</td>\n",
" <td>56.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>4.191649</td>\n",
" <td>5.885306</td>\n",
" <td>7.703205</td>\n",
" <td>12.780537</td>\n",
" <td>20.138662</td>\n",
" <td>26.347976</td>\n",
" <td>31.624223</td>\n",
" <td>40.017216</td>\n",
" <td>50.968750</td>\n",
" <td>60.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>2.255171</td>\n",
" <td>4.479475</td>\n",
" <td>6.808014</td>\n",
" <td>13.020030</td>\n",
" <td>21.397454</td>\n",
" <td>28.050971</td>\n",
" <td>33.511633</td>\n",
" <td>42.003923</td>\n",
" <td>52.968689</td>\n",
" <td>62.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>1.318189</td>\n",
" <td>4.029515</td>\n",
" <td>6.760248</td>\n",
" <td>13.632607</td>\n",
" <td>22.328781</td>\n",
" <td>29.039945</td>\n",
" <td>34.510048</td>\n",
" <td>43.003901</td>\n",
" <td>53.968689</td>\n",
" <td>63.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0.001 0.010 0.020 0.050 0.100 0.150 \\\n",
"1 32.000000 32.000000 32.000000 32.000000 32.000000 32.000000 \n",
"2 16.063968 16.636800 17.267200 19.120000 22.080000 24.880000 \n",
"4 8.127872 9.267264 10.509309 14.087900 19.582400 24.527900 \n",
"8 4.191649 5.885306 7.703205 12.780537 20.138662 26.347976 \n",
"16 2.255171 4.479475 6.808014 13.020030 21.397454 28.050971 \n",
"32 1.318189 4.029515 6.760248 13.632607 22.328781 29.039945 \n",
"\n",
" 0.200 0.300 0.500 1.000 \n",
"1 32.000000 32.000000 32.000000 32.0 \n",
"2 27.520000 32.320000 40.000000 48.0 \n",
"4 28.966400 36.478400 47.000000 56.0 \n",
"8 31.624223 40.017216 50.968750 60.0 \n",
"16 33.511633 42.003923 52.968689 62.0 \n",
"32 34.510048 43.003901 53.968689 63.0 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Render avg_E_df: rows are group sizes, columns are infection probabilities p.\n",
"display(avg_E_df)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x22775616630>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x864 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# One group of bars per infection probability, one bar per group size.\n",
"tE_df = avg_E_df.T\n",
"ax = tE_df.plot.bar(figsize=(15, 12), fontsize=22)\n",
"ax.set_xlabel(\"p - probability of infection\", fontsize=22)\n",
"ax.set_ylabel(\"E(p,s) for 32 samples\", fontsize=22)\n",
"ax.legend(title=\"pool size\", prop={\"size\": 22}, title_fontsize='22')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment