Created
May 11, 2023 09:37
-
-
Save StoneRIeverKS/f56e1a29cf2d215fec8227cb358511a4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# データの読み込み" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2023-05-11T09:36:24.660889Z", | |
"start_time": "2023-05-11T09:36:24.215634Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
# Load the toy dataset into a DataFrame (expects toy_data.csv in the working directory).
import pandas as pd
df = pd.read_csv("toy_data.csv")
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# AdaBoostの実装" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## ハードコーディングして解消したコード" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2023-05-11T09:36:24.685457Z", | |
"start_time": "2023-05-11T09:36:24.660889Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
import math
import numpy as np

class AdaboostHandmade:
    '''
    Hand-rolled AdaBoost.M1 for binary labels in {0, 1}.

    n_estimators: number of weak learners to train
    WeakLearner: weak learner class (DecisionTreeClassifier is assumed); must
                 support set_params(**params), fit(X, y, sample_weight) and predict(X)
    **params: keyword parameters forwarded to every weak learner instance
    '''

    # Clip constant keeping alpha_m finite when a learner is perfect
    # (err_m == 0, would divide by zero) or fully wrong (err_m == 1, log(0)).
    _EPS = 1e-10

    def __init__(self, n_estimators, WeakLearner, **params):
        self._n_estimators = n_estimators
        self._WeakLearner = WeakLearner
        self._params = params

    @property
    def n_estimators(self):
        return self._n_estimators

    @property
    def WeakLearner(self):
        return self._WeakLearner

    @property
    def params(self):
        return self._params

    def fit(self, X, y):
        '''Train n_estimators weak learners in sequence.

        Returns the list of fitted weak learners (also kept on
        self.weak_learner_list; vote weights go to self.alpha_m_list).
        '''
        # Fitted weak learners and their vote weights alpha_m.
        self.weak_learner_list = []
        self.alpha_m_list = []

        # Start from uniform sample weights.
        weight = np.full(len(X), 1 / len(X))

        for m in range(self.n_estimators):
            # Build and fit the m-th weak learner on the weighted sample.
            weak_learner = self._WeakLearner()
            weak_learner.set_params(**self._params)
            weak_learner.fit(X, y, sample_weight=weight)
            self.weak_learner_list.append(weak_learner)

            # Weighted training error rate err_m of this learner.
            y_pred = weak_learner.predict(X)
            err_rate = (y != y_pred).astype("float")
            err_m = np.dot(weight, err_rate) / np.sum(weight)
            # Clip into (0, 1): the original crashed with ZeroDivisionError /
            # math domain error when a learner classified everything right
            # (or everything wrong).
            err_m = min(max(err_m, self._EPS), 1 - self._EPS)

            # Vote weight: alpha_m = log((1 - err_m) / err_m).
            self.alpha_m = math.log((1 - err_m) / err_m)
            self.alpha_m_list.append(self.alpha_m)

            # Up-weight misclassified samples, then renormalize so the
            # weights sum to 1 again.
            weight = weight * np.exp(self.alpha_m * err_rate)
            weight = weight / np.sum(weight)

        return self.weak_learner_list

    def predict(self, X):
        '''Weighted-majority vote of the fitted learners; returns labels in {0., 1.}.'''
        pred_y_rate_list = np.array(
            [weak_learner.predict(X) for weak_learner in self.weak_learner_list]
        )

        # Debug print kept from the original to show the learners differ;
        # reuse the predictions computed above instead of predicting again.
        for row in pred_y_rate_list:
            print("predict:", row)

        # Map {0, 1} labels to {-1, +1}, take the alpha-weighted sum per sample,
        # then threshold the sign back to {0., 1.}.
        pred_y_rate_list_decode = 2 * pred_y_rate_list - 1
        pred_y_rate_list_encode = np.apply_along_axis(
            lambda gx: np.dot(self.alpha_m_list, gx), axis=0, arr=pred_y_rate_list_decode
        )
        pred_y_list_decode = (pred_y_rate_list_encode >= 0).astype("float")

        return pred_y_list_decode
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2023-05-11T09:36:26.644583Z", | |
"start_time": "2023-05-11T09:36:24.689475Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 1. 0. 1.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 1. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 1. 0. ... 0. 1. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 1. ... 1. 0. 1.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 1. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 1. ... 0. 1. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 1.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [1. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [1. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 1. 0. ... 0. 1. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 1. ... 0. 0. 1.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 1. 0. 1.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 1.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 1. 1. ... 1. 1. 1.]\n", | |
"predict: [0. 0. 0. ... 0. 0. 0.]\n", | |
"predict: [1. 0. 0. ... 0. 0. 1.]\n", | |
"acurracy_score: 0.843\n", | |
"confusion_matrix:\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[966, 49],\n", | |
" [265, 720]], dtype=int64)" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
# Train the hand-made AdaBoost on the toy data and evaluate it.
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.tree import DecisionTreeClassifier

features_name = ["X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8", "X9", "X10"]
adaboost_handmade = AdaboostHandmade(
    n_estimators=100,
    WeakLearner=DecisionTreeClassifier,
    max_depth=1,        # decision stumps
    random_state=501,
)
weak_learner_list = adaboost_handmade.fit(X=df[features_name], y=df["label"])
pred_labels = adaboost_handmade.predict(X=df[features_name])

# NOTE(review): metrics are computed on the same data the model was fit on,
# so this measures training fit, not generalization.
print("accuracy_score:", accuracy_score(df["label"], pred_labels))
print("confusion_matrix:")
display(confusion_matrix(df["label"], pred_labels))
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.5" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment