Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save StoneRIeverKS/f56e1a29cf2d215fec8227cb358511a4 to your computer and use it in GitHub Desktop.
Save StoneRIeverKS/f56e1a29cf2d215fec8227cb358511a4 to your computer and use it in GitHub Desktop.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# データの読み込み"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2023-05-11T09:36:24.660889Z",
"start_time": "2023-05-11T09:36:24.215634Z"
}
},
"outputs": [],
"source": [
"# Load the toy dataset\n",
"import pandas as pd\n",
"df = pd.read_csv(\"toy_data.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AdaBoostの実装"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ハードコーディングして解消したコード"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2023-05-11T09:36:24.685457Z",
"start_time": "2023-05-11T09:36:24.660889Z"
}
},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"\n",
"class AdaboostHandmade:\n",
"    '''Hand-made discrete AdaBoost for binary labels in {0, 1}.\n",
"\n",
"    n_estimators : number of weak learners to fit sequentially\n",
"    WeakLearner  : weak-learner class (DecisionTreeClassifier is assumed)\n",
"    **params     : keyword parameters forwarded to each learner via set_params\n",
"    '''\n",
"\n",
"    def __init__(self, n_estimators, WeakLearner, **params):\n",
"        self._n_estimators = n_estimators\n",
"        self._WeakLearner = WeakLearner\n",
"        self._params = params\n",
"\n",
"    @property\n",
"    def n_estimators(self):\n",
"        return self._n_estimators\n",
"\n",
"    @property\n",
"    def WeakLearner(self):\n",
"        return self._WeakLearner\n",
"\n",
"    @property\n",
"    def params(self):\n",
"        return self._params\n",
"\n",
"    def fit(self, X, y):\n",
"        '''Fit n_estimators weak learners; returns the list of fitted learners.'''\n",
"        # Fitted weak learners and their vote weights alpha_m, in fitting order.\n",
"        self.weak_learner_list = []\n",
"        self.alpha_m_list = []\n",
"\n",
"        # Start from uniform sample weights.\n",
"        weight = np.full(len(X), 1 / len(X))\n",
"\n",
"        for m in range(self.n_estimators):\n",
"            # Build and configure the m-th weak learner.\n",
"            weak_learner = self._WeakLearner()\n",
"            weak_learner.set_params(**self._params)\n",
"\n",
"            # Train it on the current sample weights.\n",
"            weak_learner.fit(X, y, sample_weight=weight)\n",
"            self.weak_learner_list.append(weak_learner)\n",
"\n",
"            # Weighted error rate err_m of this learner.\n",
"            y_pred = weak_learner.predict(X)\n",
"            miss = (y != y_pred).astype(\"float\")\n",
"            err_m = np.dot(weight, miss) / np.sum(weight)\n",
"            # Clip away exact 0/1 so log((1-err)/err) stays finite\n",
"            # (a perfect learner would otherwise raise a division/log error).\n",
"            eps = np.finfo(float).eps\n",
"            err_m = min(max(err_m, eps), 1 - eps)\n",
"\n",
"            # Vote weight: confident learners (low err_m) get a large alpha_m.\n",
"            alpha_m = math.log((1 - err_m) / err_m)\n",
"            self.alpha_m_list.append(alpha_m)\n",
"\n",
"            # Up-weight only the misclassified samples, then renormalize.\n",
"            weight = weight * np.exp(alpha_m * miss)\n",
"            weight = weight / np.sum(weight)\n",
"\n",
"        return self.weak_learner_list\n",
"\n",
"    def predict(self, X):\n",
"        '''Predict {0, 1} labels by the alpha-weighted majority vote of the learners.'''\n",
"        pred_y_rate_list = np.array([weak_learner.predict(X) for weak_learner in self.weak_learner_list])\n",
"\n",
"        # Print each learner's predictions to make it visible that the learners\n",
"        # differ (reuse the rows already computed above instead of re-predicting).\n",
"        for learner_pred in pred_y_rate_list:\n",
"            print(\"predict:\", learner_pred)\n",
"\n",
"        # Map labels {0, 1} -> {-1, +1}, take the alpha-weighted vote per sample,\n",
"        # and map the sign of the vote back to {0, 1}.\n",
"        pred_decoded = 2 * pred_y_rate_list - 1\n",
"        vote = np.apply_along_axis(lambda gx: np.dot(self.alpha_m_list, gx), axis=0, arr=pred_decoded)\n",
"        return (vote >= 0).astype(\"float\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2023-05-11T09:36:26.644583Z",
"start_time": "2023-05-11T09:36:24.689475Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 1. 0. 1.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 1. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 1. 0. ... 0. 1. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 1. ... 1. 0. 1.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 1. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 1. ... 0. 1. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 1.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [1. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [1. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 1. 0. ... 0. 1. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 1. ... 0. 0. 1.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 1. 0. 1.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [0. 0. 0. ... 0. 0. 1.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 1. 1. ... 1. 1. 1.]\n",
"predict: [0. 0. 0. ... 0. 0. 0.]\n",
"predict: [1. 0. 0. ... 0. 0. 1.]\n",
"acurracy_score: 0.843\n",
"confusion_matrix:\n"
]
},
{
"data": {
"text/plain": [
"array([[966, 49],\n",
" [265, 720]], dtype=int64)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Fit the hand-made AdaBoost with 100 depth-1 trees (decision stumps).\n",
"features_name = [\"X1\", \"X2\", \"X3\", \"X4\", \"X5\", \"X6\", \"X7\", \"X8\", \"X9\", \"X10\"]\n",
"adaboost_handmade = AdaboostHandmade(n_estimators=100, WeakLearner=DecisionTreeClassifier, max_depth=1, random_state=501)\n",
"weak_learner_list = adaboost_handmade.fit(X=df[features_name], y=df[\"label\"])\n",
"pred_labels = adaboost_handmade.predict(X=df[features_name])\n",
"\n",
"# Evaluate the ensemble against the true labels.\n",
"from sklearn.metrics import accuracy_score, confusion_matrix\n",
"print(\"accuracy_score:\", accuracy_score(df[\"label\"], pred_labels))\n",
"print(\"confusion_matrix:\")\n",
"display(confusion_matrix(df[\"label\"], pred_labels))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment