@narrowlyapplicable
Last active March 13, 2023 17:54
Example of computing WAIC & WBIC with pystan
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# WAIC&WBIC with pystan"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- WAIC and WBIC were computed with pystan on synthetic data.\n",
" - The data are the simulated signal-peak data created in <https://github.com/narrowlyapplicable/peak_separation_whth_WBIC>.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Importing the required libraries"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Python version: 3.6.6\n",
"- Libraries used:\n",
" - numpy 1.15.1\n",
" - pandas 0.23.4\n",
" - pystan 2.17.1.0\n",
" - matplotlib 2.2.2\n",
" - seaborn 0.9.0"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import scipy.stats as sct\n",
"from pystan import StanModel\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"plt.style.use('ggplot')\n",
"sns.set_palette('deep')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Creating the data"
]
},
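{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Consider a measurement in which several peak signals of Lorentzian shape (the probability density function of the Cauchy distribution) are superimposed.\n",
"- Using scipy.stats, three Lorentzians were summed and Gaussian noise was added to create the data.\n",
" - The assumed situation is two overlapping peaks plus one shoulder peak."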
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(1)\n",
"x = np.arange(-10, 10.5, 0.5)\n",
"k_true = 3\n",
"peak1 = sct.cauchy.pdf(x = x, loc = 0, scale = 1.0)\n",
"# peak2 = 0.16 * sct.cauchy.pdf(x = x, loc = 1, scale = 0.4)\n",
"peak3 = 0.09 * sct.cauchy.pdf(x = x, loc = -0.7, scale = 0.3)\n",
"peak4 = 0.36 * sct.cauchy.pdf(x = x, loc = 3.0, scale = 0.6)\n",
"\n",
"data = pd.DataFrame(data={\"x\":x, \"y\":peak1+peak3+peak4+np.random.normal(size=x.shape[0], scale=0.01)})"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.plot(x=\"x\", marker=\".\", linestyle=\"--\") # uses the pandas plot method\n",
"plt.tight_layout()"
]
},
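{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Model definition\n",
"- The model assumes the following prior knowledge.\n",
" - Peak widths are bounded: the scale of the Cauchy probability density function is at most about 2.\n",
" - The observation noise follows a normal distribution whose sd is at most 0.05.\n",
"- To reflect these conditions, half-t distributions are used as weakly informative priors."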
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.1. Stan code for WAIC"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"code_waic = \"\"\"\n",
"data {\n",
" int N;\n",
" int K;\n",
" vector[N] X;\n",
" vector[N] Y;\n",
"}\n",
"\n",
"parameters {\n",
" vector[K] mu;\n",
" vector<lower=0>[K] sigma;\n",
" real<lower=0> s_mu;\n",
" real<lower=0> s_noise;\n",
"}\n",
"\n",
"model {\n",
" mu ~ normal(0, s_mu);\n",
" sigma ~ student_t(4, 0, 2);\n",
" s_noise ~ student_t(4, 0, 0.05);\n",
" for (n in 1:N) {\n",
" vector[K] line;\n",
" for (k in 1:K){\n",
" line[k] = log(sigma[k]^2) + cauchy_lpdf(X[n] | mu[k], sigma[k]);\n",
" }\n",
" target += normal_lpdf(Y[n] | exp(log_sum_exp(line)), s_noise);\n",
" }\n",
"}\n",
"\n",
"generated quantities {\n",
" vector[N] log_likelihood;\n",
" for(n in 1:N){\n",
" vector[K] line;\n",
" for (k in 1:K){\n",
" line[k] = log(sigma[k]^2) + cauchy_lpdf(X[n] | mu[k], sigma[k]);\n",
" }\n",
" log_likelihood[n] = normal_lpdf(Y[n] | exp(log_sum_exp(line)), s_noise);\n",
" }\n",
"}\n",
"\n",
"\"\"\""
]
},
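{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reading aid for the Stan code above (no additional modelling is implied): `exp(log_sum_exp(line))` is a sum of Cauchy densities, each weighted by $\\sigma_k^2$, so the mean signal $f(x_n)$ and the observation model can be written as\n",
"\n",
"$$\n",
"f(x_n) = \\sum_{k=1}^{K} \\sigma_k^{2}\\,\\frac{1}{\\pi\\sigma_k\\left[1+\\left(\\frac{x_n-\\mu_k}{\\sigma_k}\\right)^{2}\\right]} = \\sum_{k=1}^{K} \\frac{\\sigma_k}{\\pi\\left[1+\\left(\\frac{x_n-\\mu_k}{\\sigma_k}\\right)^{2}\\right]}, \\qquad y_n \\sim \\mathcal{N}\\left(f(x_n),\\, s_{\\mathrm{noise}}\\right)\n",
"$$\n",
"\n",
"Here $\\mu_k$, $\\sigma_k$ and $s_{\\mathrm{noise}}$ stand for `mu[k]`, `sigma[k]` and `s_noise`, and the second argument of $\\mathcal{N}$ is a standard deviation, as in Stan. Each peak's area is therefore tied to its width. This matches the synthetic data, whose amplitudes 1, 0.09 and 0.36 are the squares of the scales 1.0, 0.3 and 0.6."
]
},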
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Load the compiled model if it exists; otherwise compile it and save it.\n",
"try:\n",
" with open('./model_peakwaic.pkl', \"rb\") as f:\n",
" stanmodel_waic = pickle.load(f)\n",
"except FileNotFoundError:\n",
" stanmodel_waic = StanModel(model_code=code_waic, model_name=\"model_peakwaic\")\n",
" with open('./model_peakwaic.pkl', \"wb\") as f:\n",
" pickle.dump(stanmodel_waic, f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2. Stan code for WBIC\n",
"- In the model block, the `target +=` notation defines the posterior at inverse temperature $1/\\log(N)$ (N is the number of data points).\n",
"- In the generated quantities block, the log-likelihood is computed under this posterior.\n",
"- The Stan code is taken unchanged from an earlier implementation: <https://github.com/narrowlyapplicable/peak_separation_whth_WBIC/blob/master/wbic-mix-cauchy_lpdf.stan>.\n"
]
},
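{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a compact restatement of the bullets above (the notation is introduced here for illustration: $w$ for all parameters, $\\varphi(w)$ for the prior, $\\beta$ for the inverse temperature), the tempered posterior sampled by this model and the resulting WBIC are\n",
"\n",
"$$\n",
"p_\\beta(w \\mid D) \\propto \\varphi(w)\\prod_{n=1}^{N} p(y_n \\mid x_n, w)^{\\beta}, \\qquad \\beta = \\frac{1}{\\log N}\n",
"$$\n",
"\n",
"$$\n",
"\\mathrm{WBIC} = \\mathbb{E}_{w \\sim p_\\beta}\\left[-\\sum_{n=1}^{N} \\log p(y_n \\mid x_n, w)\\right] \\approx -\\frac{1}{S}\\sum_{s=1}^{S}\\sum_{n=1}^{N} \\log p(y_n \\mid x_n, w^{(s)})\n",
"$$\n",
"\n",
"where $w^{(1)}, \\dots, w^{(S)}$ are the MCMC draws. Only the likelihood is raised to the power $\\beta$; the priors enter untempered, which is why the `1/log(N)` factor appears only on the `normal_lpdf` term."
]
},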
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"code_wbic = \"\"\"\n",
"data {\n",
" int N;\n",
" int K;\n",
" vector[N] X;\n",
" vector[N] Y;\n",
"}\n",
"\n",
"parameters {\n",
" vector[K] mu;\n",
" vector<lower=0>[K] sigma;\n",
" real<lower=0> s_mu;\n",
" real<lower=0> s_noise;\n",
"}\n",
"\n",
"model {\n",
" mu ~ normal(0, s_mu);\n",
" sigma ~ student_t(4, 0, 2);\n",
" s_noise ~ student_t(4, 0, 0.05);\n",
" for (n in 1:N) {\n",
" vector[K] line;\n",
" for (k in 1:K){\n",
" line[k] = log(sigma[k]^2) + cauchy_lpdf(X[n] | mu[k], sigma[k]);\n",
" }\n",
" target += 1/log(N) * normal_lpdf(Y[n] | exp(log_sum_exp(line)), s_noise);\n",
" }\n",
"}\n",
"\n",
"generated quantities {\n",
" vector[N] log_likelihood;\n",
" for(n in 1:N){\n",
" vector[K] line;\n",
" for (k in 1:K){\n",
" line[k] = log(sigma[k]^2) + cauchy_lpdf(X[n] | mu[k], sigma[k]);\n",
" }\n",
" log_likelihood[n] = normal_lpdf(Y[n] | exp(log_sum_exp(line)), s_noise);\n",
" }\n",
"}\n",
"\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Load the compiled model if it exists; otherwise compile it and save it.\n",
"try:\n",
" with open('./model_peakwbic.pkl', \"rb\") as f:\n",
" stanmodel_wbic = pickle.load(f)\n",
"except FileNotFoundError:\n",
" stanmodel_wbic = StanModel(model_code=code_wbic, model_name=\"model_peakwbic\")\n",
" with open('./model_peakwbic.pkl', \"wb\") as f:\n",
" pickle.dump(stanmodel_wbic, f)"
]
},
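{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Running MCMC & computing the information criteria\n",
"- Since there are at least two peaks, the candidate numbers of peaks were set to 2, 3, 4, and 5.\n",
"- WAIC and WBIC were computed for each of these four candidates.\n",
" - The WBIC computation is again taken unchanged from the earlier implementation.\n",
" - The WAIC computation is somewhat more involved, so a waic() function is defined for it."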
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.1. Computing WAIC"
]
},
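{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def waic(samples):\n",
"    # samples: log-likelihood draws with shape (n_draws, N)\n",
"    # training loss: minus the mean log posterior-predictive density\n",
"    tE = - np.mean(np.log(np.mean(np.exp(samples), axis=0)))\n",
"    # functional variance: posterior variance of the log-likelihood, summed over data points\n",
"    fVar = np.sum(np.mean(samples**2, axis=0) - np.mean(samples, axis=0)**2)\n",
"    # the penalty is divided by the number of data points N, not the number of draws\n",
"    return tE + fVar/samples.shape[1]"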
]
},
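{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a sketch of the standard WAIC definition in its per-data-point scaling, using notation introduced here for illustration: $S$ posterior draws $w^{(s)}$, $N$ data points, and $\\ell_{sn} = \\log p(y_n \\mid x_n, w^{(s)})$, i.e. the `log_likelihood` array of shape $(S, N)$ returned by `fit.extract()`:\n",
"\n",
"$$\n",
"T_N = -\\frac{1}{N}\\sum_{n=1}^{N} \\log\\left(\\frac{1}{S}\\sum_{s=1}^{S} e^{\\ell_{sn}}\\right), \\qquad V_N = \\sum_{n=1}^{N}\\left[\\frac{1}{S}\\sum_{s=1}^{S}\\ell_{sn}^{2} - \\left(\\frac{1}{S}\\sum_{s=1}^{S}\\ell_{sn}\\right)^{2}\\right]\n",
"$$\n",
"\n",
"$$\n",
"\\mathrm{WAIC} = T_N + \\frac{V_N}{N}\n",
"$$\n",
"\n",
"$T_N$ is the training loss and $V_N$ the functional variance. The penalty term divides by the number of data points $N$, not by the number of draws $S$."
]
},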
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"//anaconda/lib/python3.6/site-packages/pystan/misc.py:399: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" elif np.issubdtype(np.asarray(v).dtype, float):\n"
]
}
],
"source": [
"dict_waic = {}\n",
"for k_cand in [2,3,4,5]:\n",
" # bundle the data into a dict for pystan\n",
" standata = {\"N\":data.shape[0], \"K\":k_cand, \"X\":data[\"x\"], \"Y\":data[\"y\"]}\n",
" fit = stanmodel_waic.sampling(data = standata, iter=10000, warmup=4000, seed=1234)\n",
" ms = fit.extract()\n",
" dict_waic[k_cand] = waic(ms[\"log_likelihood\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"df_waic = pd.DataFrame(data={\"n_peak\":list(dict_waic.keys()), \"WAIC\":list(dict_waic.values())})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.2. Computing WBIC"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"//anaconda/lib/python3.6/site-packages/pystan/misc.py:399: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" elif np.issubdtype(np.asarray(v).dtype, float):\n"
]
}
],
"source": [
"dict_wbic = {}\n",
"for k_cand in [2,3,4,5]:\n",
" # bundle the data into a dict for pystan\n",
" standata = {\"N\":data.shape[0], \"K\":k_cand, \"X\":data[\"x\"], \"Y\":data[\"y\"]}\n",
" fit = stanmodel_wbic.sampling(data = standata, iter=10000, warmup=4000, seed=1234)\n",
" ms = fit.extract()\n",
" dict_wbic[k_cand] = - np.mean(np.sum(ms[\"log_likelihood\"], axis=1))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"df_wbic = pd.DataFrame(data={\"n_peak\":list(dict_wbic.keys()), \"WBIC\":list(dict_wbic.values())})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- The computed WAIC and WBIC values are as follows."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>n_peak</th>\n",
" <th>WAIC</th>\n",
" <th>WBIC</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>-2.684415</td>\n",
" <td>-95.823023</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3</td>\n",
" <td>-3.228095</td>\n",
" <td>-109.945743</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4</td>\n",
" <td>-3.236460</td>\n",
" <td>-100.927044</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5</td>\n",
" <td>-3.239078</td>\n",
" <td>-98.677328</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" n_peak WAIC WBIC\n",
"0 2 -2.684415 -95.823023\n",
"1 3 -3.228095 -109.945743\n",
"2 4 -3.236460 -100.927044\n",
"3 5 -3.239078 -98.677328"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.merge(df_waic, df_wbic, on=\"n_peak\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*[Table] WAIC and WBIC values for each candidate model (number of peaks)*"
]
},
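{
"cell_type": "markdown",
"metadata": {},
"source": [
"- These values are shown as bar charts in the figure below.\n",
" - Because WAIC changes only slightly between 3 and 5 peaks, that range is shown magnified. As a result, the value at 2 peaks is cut off at the top of the plot."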
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x504 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"### prepare the matplotlib figure\n",
"f, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7))#, sharex=True)\n",
"\n",
"### WAIC\n",
"# baseline adjustment\n",
"baseline_waic = 4\n",
"tempDf_waic = df_waic.copy()\n",
"tempDf_waic[\"WAIC\"] += baseline_waic\n",
"# plot WAIC-values\n",
"ax1.set_title('WAIC')\n",
"sns.barplot(x = \"n_peak\", y = \"WAIC\", palette=\"rocket\",data=tempDf_waic, ax=ax1, bottom=-baseline_waic)\n",
"sns.lineplot(x=[0,1,2,3], y = \"WAIC\", palette=\"rocket\",data=df_waic, ax=ax1, marker=\"o\")\n",
"ax1.set_xlabel('n_peak')\n",
"ax1.set_ylim([-3.3,-3.2])\n",
"\n",
"### WBIC\n",
"# baseline adjustment\n",
"baseline_wbic = 150\n",
"tempDf_wbic = df_wbic.copy()\n",
"tempDf_wbic[\"WBIC\"] += baseline_wbic\n",
"# plot WBIC-values\n",
"ax2.set_title(\"WBIC\")\n",
"sns.barplot(x = \"n_peak\", y = \"WBIC\", palette=\"rocket\",data=tempDf_wbic, ax=ax2, bottom=-baseline_wbic)\n",
"sns.lineplot(x=[0,1,2,3], y = \"WBIC\", palette=\"rocket\",data=df_wbic, ax=ax2, marker=\"o\")\n",
"ax2.set_xlabel('n_peak')\n",
"ax2.set_ylim([-120, -90])\n",
"\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*[Figure] Top: WAIC for each candidate model (number of peaks). Bottom: WBIC for the same candidates.*"
]
},
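{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Summary\n",
"- WBIC reached its minimum at 3 peaks.\n",
"- WAIC, in contrast, decreased monotonically as the number of peaks increased.\n",
"- The true number of peaks was 3, so WBIC selected the true number, whereas WAIC selected an overly complex model.\n",
" - WAIC is a criterion for choosing the model with the best predictive performance; it does not estimate the \"true number of peaks\".\n",
" - WBIC is a consistent criterion, and here it did identify the true number of peaks."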
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda env:anaconda]",
"language": "python",
"name": "conda-env-anaconda-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}