Skip to content

Instantly share code, notes, and snippets.

@HYChou0515
Last active February 19, 2022 00:38
Show Gist options
  • Save HYChou0515/7f9b17b6116c48098b804d01528b624c to your computer and use it in GitHub Desktop.
Save HYChou0515/7f9b17b6116c48098b804d01528b624c to your computer and use it in GitHub Desktop.
Poisson mixture model fitted with the EM algorithm
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d04a5a8f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.75it/s]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from functools import lru_cache, partial\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scipy as sp\n",
"import scipy.special\n",
"import seaborn as sns\n",
"from tqdm import tqdm\n",
"\n",
"# Ground truth: real_N Poisson components with known sizes and rates.\n",
"rng = np.random.default_rng(0)  # fixed seed: data and EM init are reproducible\n",
"\n",
"real_N = 5\n",
"Li = np.array([123, 234, 543, 321, 599]) * 4  # samples drawn from each component\n",
"L = sum(Li)  # total number of observations\n",
"real_pi = Li.astype(float) / L  # true mixing weights\n",
"real_lam = np.array([3, 6, 12, 20, 30])  # true Poisson rates\n",
"assert len(Li) == len(real_lam) == real_N\n",
"\n",
"# Sample each component from its own true rate. Taking the rates from\n",
"# real_lam (rather than repeating the literals) keeps the data in sync with\n",
"# the ground truth reported in the comparison table.\n",
"data = np.concatenate(\n",
"    [rng.poisson(rate, size) for rate, size in zip(real_lam, Li)])\n",
"\n",
"# EM state, updated in place by em_step().\n",
"N = 5  # number of components to fit\n",
"p = np.empty((L, N))  # responsibilities p[i, k]\n",
"q = np.empty((L, N))  # NOTE(review): not read anywhere in this notebook\n",
"sumpg = np.empty(L)  # mixture likelihood of each observation\n",
"pi = rng.dirichlet(np.ones(N, float))  # random initial mixing weights\n",
"lam = rng.exponential(10, N)  # random initial rates\n",
"\n",
"\n",
"def real_pdf(x, k):\n",
"    \"\"\"Poisson pmf of the k-th *true* component (rate real_lam[k]) at x.\n",
"\n",
"    Evaluated in log space as exp(x*log(lam) - lam - lgamma(x + 1)). The\n",
"    direct form np.power(real_lam[k], x) overflows for integer x because\n",
"    real_lam has an integer dtype (e.g. 30**13 exceeds int64), and x!\n",
"    overflows for large x; the log form avoids both. xlogy makes the\n",
"    x = 0 term exactly 0.\n",
"    \"\"\"\n",
"    rate = real_lam[k]\n",
"    return np.exp(sp.special.xlogy(x, rate) - rate - sp.special.gammaln(x + 1))\n",
"\n",
"\n",
"@lru_cache(maxsize=128)\n",
"def fact(x):\n",
"    \"\"\"Return x! via scipy.special.factorial, memoized (x must be hashable).\"\"\"\n",
"    value = scipy.special.factorial(x)\n",
"    return value\n",
"\n",
"\n",
"def pdf(x, k):\n",
"    \"\"\"Poisson pmf of the k-th *fitted* component at x, using the current lam.\n",
"\n",
"    Deliberately not memoized: the value depends on the global lam, which\n",
"    em_step overwrites every iteration, so an lru_cache keyed only on (x, k)\n",
"    can hand back densities computed from an earlier iteration's rates.\n",
"    Evaluated in log space (xlogy/gammaln) so lam**x and x! cannot\n",
"    overflow. Without the cache, x and k may also be broadcast-compatible\n",
"    NumPy arrays.\n",
"    \"\"\"\n",
"    rate = lam[k]\n",
"    return np.exp(sp.special.xlogy(x, rate) - rate - sp.special.gammaln(x + 1))\n",
"\n",
"\n",
"def em_step():\n",
"    \"\"\"Run one EM iteration for the Poisson mixture, updating state in place.\n",
"\n",
"    Reads data and the current pi/lam; writes the responsibilities into p,\n",
"    each observation's mixture likelihood into sumpg, and the re-estimated\n",
"    parameters back into pi and lam.\n",
"    \"\"\"\n",
"    x = data[:, None]  # (L, 1) column, broadcasts against the N components\n",
"\n",
"    # E step: log(pi[k] * Poisson(x_i | lam[k])) for every (i, k), computed\n",
"    # from the *current* lam (no cached densities) and vectorized instead of\n",
"    # two Python loops over L * N entries. xlogy makes x*log(lam) 0 at x = 0.\n",
"    with np.errstate(divide='ignore'):  # log(0) = -inf for an empty component\n",
"        log_w = (np.log(pi) + sp.special.xlogy(x, lam) - lam\n",
"                 - sp.special.gammaln(x + 1))\n",
"    # Normalizing with log-sum-exp keeps far-out points from underflowing to\n",
"    # 0/0; each row of p sums to 1 by construction.\n",
"    log_norm = sp.special.logsumexp(log_w, axis=1)\n",
"    sumpg[:] = np.exp(log_norm)\n",
"    p[:] = np.exp(log_w - log_norm[:, None])\n",
"\n",
"    # M step: pi[k] is the mean responsibility, lam[k] the\n",
"    # responsibility-weighted mean of the data.\n",
"    sump = p.sum(axis=0)\n",
"    sumpx = p.T @ data\n",
"    # A component with zero total responsibility keeps its previous rate\n",
"    # instead of becoming 0/0 = nan.\n",
"    lam[:] = np.divide(sumpx, sump, out=lam.copy(), where=sump > 0)\n",
"    pi[:] = sump / L\n",
"\n",
"\n",
"# Iterate EM until no parameter moves by more than TOL, or N_ITER steps.\n",
"N_ITER = 100\n",
"TOL = 1e-8\n",
"for t in tqdm(range(N_ITER)):\n",
"    prev_lam, prev_pi = lam.copy(), pi.copy()\n",
"    em_step()\n",
"    if max(np.abs(lam - prev_lam).max(), np.abs(pi - prev_pi).max()) < TOL:\n",
"        break\n",
"\n",
"\n",
"def mixture(x, pdfs, p):\n",
"    \"\"\"Evaluate a finite mixture at x: sum over k of p[k] * pdfs[k](x).\n",
"\n",
"    pdfs and p are paired positionally (zip stops at the shorter one); an\n",
"    empty mixture evaluates to 0.\n",
"    \"\"\"\n",
"    return sum(weight * density(x) for weight, density in zip(p, pdfs))\n",
"\n",
"\n",
"# Evaluate the fitted and the true mixture densities on a grid over [0, 40].\n",
"est_mixture = np.vectorize(\n",
"    partial(mixture, pdfs=[partial(pdf, k=i) for i in range(N)], p=pi))\n",
"real_mixture = np.vectorize(\n",
"    partial(mixture,\n",
"            pdfs=[partial(real_pdf, k=i) for i in range(real_N)],\n",
"            p=real_pi))\n",
"\n",
"fig, ax = plt.subplots()\n",
"# One bar per integer count, height = fraction of samples: the histogram and\n",
"# the pmf curves then share one y-axis and can be compared directly (a twin\n",
"# axis would autoscale counts and densities independently).\n",
"sns.histplot(data, ax=ax, discrete=True, stat='probability', label='data')\n",
"xx = np.linspace(0, 40, 1000)\n",
"ax.plot(xx, est_mixture(xx), label='est')\n",
"ax.plot(xx, real_mixture(xx), label='real')\n",
"ax.set(title='Poisson mixture: EM estimate vs. ground truth',\n",
"       xlabel='count', ylabel='probability')\n",
"ax.legend()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8e1db19f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>esti_lam</th>\n",
" <th>esti_pi</th>\n",
" <th>real_lam</th>\n",
" <th>real_pi</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3.025379</td>\n",
" <td>0.061927</td>\n",
" <td>3</td>\n",
" <td>0.067582</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5.864372</td>\n",
" <td>0.134654</td>\n",
" <td>6</td>\n",
" <td>0.128571</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>12.277207</td>\n",
" <td>0.311448</td>\n",
" <td>12</td>\n",
" <td>0.298352</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>20.969611</td>\n",
" <td>0.185048</td>\n",
" <td>20</td>\n",
" <td>0.176374</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>30.361888</td>\n",
" <td>0.306923</td>\n",
" <td>30</td>\n",
" <td>0.329121</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" esti_lam esti_pi real_lam real_pi\n",
"0 3.025379 0.061927 3 0.067582\n",
"1 5.864372 0.134654 6 0.128571\n",
"2 12.277207 0.311448 12 0.298352\n",
"3 20.969611 0.185048 20 0.176374\n",
"4 30.361888 0.306923 30 0.329121"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Components are only identified up to relabeling, so sort both the fitted\n",
"# and the true parameters by rate before lining them up side by side.\n",
"pd.concat(\n",
"    [\n",
"        pd.DataFrame({'esti_lam': lam, 'esti_pi': pi})\n",
"        .sort_values('esti_lam', ignore_index=True),\n",
"        pd.DataFrame({'real_lam': real_lam, 'real_pi': real_pi})\n",
"        .sort_values('real_lam', ignore_index=True),\n",
"    ],\n",
"    axis='columns',\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jupyter-home",
"language": "python",
"name": "jupyter-home"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment