@hoto17296
Last active May 13, 2018 05:17
ちゅら.ai Algorithm Implementation Meetup #1: Decision Trees
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"- [Data Science with R 13: Tree Models](https://www.slideshare.net/takemikami/r13-9821987)\n",
"- [2-5. Gini Coefficient | Time for Statistics | Statistics WEB](https://bellcurve.jp/statistics/course/3798.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
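"## Computing the Gini impurity\n",
"For a node whose samples fall into classes with proportions $p_k$, the Gini impurity (also called the Gini index in decision-tree learning) is $G = 1 - \\sum_k p_k^2$. It is 0 for a pure node and grows as the classes become more mixed."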
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x0</th>\n",
" <th>x1</th>\n",
" <th>x2</th>\n",
" <th>x3</th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.6</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5.0</td>\n",
" <td>3.6</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" x0 x1 x2 x3 y\n",
"0 5.1 3.5 1.4 0.2 0\n",
"1 4.9 3.0 1.4 0.2 0\n",
"2 4.7 3.2 1.3 0.2 0\n",
"3 4.6 3.1 1.5 0.2 0\n",
"4 5.0 3.6 1.4 0.2 0"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris = load_iris()\n",
"df = pd.DataFrame(iris.data, columns=['x{}'.format(i) for i in range(4)])\n",
"df['y'] = iris.target\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7f7d273b49b0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1152x288 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure(figsize=(16, 4))\n",
"ax = fig.add_subplot(121)\n",
"ax.set_title('Sepal')\n",
"ax.set_xlabel('Length (x0)')\n",
"ax.set_ylabel('Width (x1)')\n",
"ax.scatter(x=df.x0, y=df.x1, c=df.y, alpha=0.5)\n",
"ax = fig.add_subplot(122)\n",
"ax.set_title('Petal')\n",
"ax.set_xlabel('Length (x2)')\n",
"ax.set_ylabel('Width (x3)')\n",
"ax.scatter(x=df.x2, y=df.x3, c=df.y, alpha=0.5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
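"## First, compute the Gini impurity of the whole Iris dataset\n",
"Iris has 50 samples in each of its 3 classes, so $G = 1 - 3 \\cdot (1/3)^2 = 2/3 \\approx 0.667$."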
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.66666666666666674"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def gini_index(df, target='y'):\n",
"    # Gini impurity: 1 - sum of squared class proportions of the target column\n",
"    return 1 - ((df.groupby(target).size() / len(df)) ** 2).sum()\n",
"\n",
"gini_index(df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
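"## Compute the Gini gain of a binary split on x0\n",
"The gain is the parent's impurity minus the sample-weighted impurity of the two parts, $\\Delta G = G - (n_0 G_0 + n_1 G_1) / n$, where part 0 holds the rows below the boundary and part 1 the rest. There is no particular reason for the boundary chosen here. As a first try, we split at the mean of x0."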
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.17476190476190478"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
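"def binary(df, root_gi, x, boundary):\n",
"    # Split into x < boundary and x >= boundary, then return the Gini gain\n",
"    # (parent impurity minus the sample-weighted impurity of the two parts)\n",
"    df0 = df[x < boundary]\n",
"    df1 = df[x >= boundary]\n",
"    gi = root_gi - (len(df0) * gini_index(df0) + len(df1) * gini_index(df1)) / len(df)\n",
"    return gi, (df0, df1)\n",
"\n",
"gi, _ = binary(df, gini_index(df), df.x0, df.x0.mean())\n",
"gi"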
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Find the binary split with the largest Gini gain over all explanatory variables"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class Node:\n",
"\n",
"    def classify(self, row):\n",
"        raise NotImplementedError()\n",
"\n",
"\n",
"class Tree(Node):\n",
"    \"\"\"Internal node: rows with row[col] < boundary go to nodes[0], the rest to nodes[1].\"\"\"\n",
"\n",
"    def __init__(self, col, boundary, gi, depth=0):\n",
"        self.col = col\n",
"        self.boundary = boundary\n",
"        self.gi = gi\n",
"        self.depth = depth\n",
"        self.nodes = []\n",
"\n",
"    def __repr__(self):\n",
"        text = '<{} col={col} boundary={boundary} gi={gi} depth={depth}>'.format(type(self).__name__, **self.__dict__)\n",
"        for node in self.nodes:\n",
"            text += '\\n'\n",
"            text += '\\t' * (self.depth + 1)\n",
"            text += repr(node)\n",
"        return text\n",
"\n",
"    def classify(self, row):\n",
"        if row[self.col] < self.boundary:\n",
"            return self.nodes[0].classify(row)\n",
"        else:\n",
"            return self.nodes[1].classify(row)\n",
"\n",
"\n",
"class Leaf(Node):\n",
"    \"\"\"Terminal node that always predicts its category.\"\"\"\n",
"\n",
"    def __init__(self, category):\n",
"        self.category = category\n",
"\n",
"    def __repr__(self):\n",
"        return '<{} category={category}>'.format(type(self).__name__, **self.__dict__)\n",
"\n",
"    def classify(self, row):\n",
"        return self.category"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Tree col=x3 boundary=1.0 gi=0.3333333333333334 depth=0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def search_boundary(df, target='y'):\n",
"    # Try every distinct value of every explanatory variable as a boundary and keep\n",
"    # the split with the largest Gini gain (ties go to the later candidate)\n",
"    root_gi = gini_index(df, target)\n",
"    max_node = Tree(None, None, 0.0)\n",
"    max_separated = ()\n",
"    for col in df.columns:\n",
"        if col == target: continue\n",
"        for boundary in df[col].drop_duplicates().sort_values():\n",
"            gi, separated = binary(df, root_gi, df[col], boundary)\n",
"            if max_node.gi <= gi:\n",
"                max_node = Tree(col, boundary, gi)\n",
"                max_separated = separated\n",
"    return max_node, max_separated\n",
"\n",
"tree, separated = search_boundary(df)\n",
"tree"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Tree col=x3 boundary=1.8 gi=0.38969404186795487 depth=0>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"search_boundary(separated[1])[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the tree by splitting repeatedly"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def make_tree(df, depth=0, max_depth=4, target='y'):\n",
"    \"\"\"Return a Tree or Leaf object\"\"\"\n",
"    global importance\n",
"    categories = df[target].unique()\n",
"    assert categories.size > 0\n",
"    if categories.size == 1:\n",
"        # Pure node: every sample has the same class\n",
"        return Leaf(categories[0])\n",
"    if depth > max_depth:\n",
"        # Depth limit reached: predict the majority class\n",
"        return Leaf(df.groupby(target).size().idxmax())\n",
"    tree, separated = search_boundary(df, target)\n",
"    # Accumulate the sample-weighted Gini gain of each variable as its importance\n",
"    importance[tree.col] += tree.gi * len(df)\n",
"    tree.depth = depth\n",
"    tree.nodes = [make_tree(_df, depth+1, max_depth, target) for _df in separated]\n",
"    return tree"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Tree col=x3 boundary=1.0 gi=0.3333333333333334 depth=0>\n",
"\t<Leaf category=0>\n",
"\t<Tree col=x3 boundary=1.8 gi=0.38969404186795487 depth=1>\n",
"\t\t<Tree col=x2 boundary=5.0 gi=0.08239026063100137 depth=2>\n",
"\t\t\t<Tree col=x3 boundary=1.7 gi=0.04079861111111116 depth=3>\n",
"\t\t\t\t<Leaf category=1>\n",
"\t\t\t\t<Leaf category=2>\n",
"\t\t\t<Tree col=x3 boundary=1.6 gi=0.2222222222222222 depth=3>\n",
"\t\t\t\t<Leaf category=2>\n",
"\t\t\t\t<Tree col=x2 boundary=5.8 gi=0.4444444444444444 depth=4>\n",
"\t\t\t\t\t<Leaf category=1>\n",
"\t\t\t\t\t<Leaf category=2>\n",
"\t\t<Tree col=x2 boundary=4.9 gi=0.013547574039067499 depth=2>\n",
"\t\t\t<Tree col=x1 boundary=3.2 gi=0.4444444444444444 depth=3>\n",
"\t\t\t\t<Leaf category=2>\n",
"\t\t\t\t<Leaf category=1>\n",
"\t\t\t<Leaf category=2>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"importance = pd.Series(0.0, index=df.columns.drop('y'))\n",
"tree = make_tree(df)\n",
"importance /= importance.sum()\n",
"tree"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"x0 0.000000\n",
"x1 0.013333\n",
"x2 0.064056\n",
"x3 0.922611\n",
"dtype: float64"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"importance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classifying with the tree"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(A classify method has been implemented on the Tree and Leaf classes above.)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree.classify(df.iloc[100])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 0\n",
"1 0\n",
"2 0\n",
"3 0\n",
"4 0\n",
"dtype: int64"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred = pd.Series([tree.classify(row) for _, row in df.iterrows()])\n",
"pred.head()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(df.y == pred).all()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
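"Feeding the training data straight back in, every sample appears to be classified correctly. This only shows that the tree fits its own training data, not how well it does on unseen data."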
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute the accuracy\n",
"Randomly split the data into 80% for training and 20% for validation, repeat this 20 times, and average the accuracy."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.94833333333333325"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracies = []\n",
"for _ in range(20):\n",
"    # Shuffle the rows, then use the first 80% for training and the rest for testing\n",
"    shuffled_df = df.reindex(np.random.permutation(df.index)).reset_index(drop=True)\n",
"    p = int(0.8 * len(df))\n",
"    train_df = shuffled_df.iloc[:p, :].reset_index(drop=True)\n",
"    test_df = shuffled_df.iloc[p:, :].reset_index(drop=True)\n",
"    # train\n",
"    tree = make_tree(train_df)\n",
"    # predict\n",
"    pred = pd.Series([tree.classify(row) for _, row in test_df.iterrows()])\n",
"    # evaluate: accuracy = fraction of correct predictions\n",
"    accuracies.append((pred == test_df.y).sum() / len(test_df))\n",
"np.array(accuracies).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Rewrite\n",
"- Give the model a scikit-learn-like interface\n",
"- The Tree and Leaf classes defined above are reused as-is"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"class BinaryDecisionTree:\n",
"\n",
"    def __init__(self, max_depth=4):\n",
"        \"\"\"Set the model parameters\"\"\"\n",
"        self.max_depth = max_depth\n",
"\n",
"    def fit(self, df, target='y'):\n",
"        \"\"\"Build the decision tree from the training data\"\"\"\n",
"        self.importance = pd.Series(0.0, index=df.columns.drop(target))\n",
"        self.target = target\n",
"        self.tree = self._make_tree(df)\n",
"        return self\n",
"\n",
"    def _make_tree(self, df, depth=0):\n",
"        categories = df[self.target].unique()\n",
"        assert categories.size > 0\n",
"        if categories.size == 1:\n",
"            return Leaf(categories[0])\n",
"        if depth > self.max_depth:\n",
"            return Leaf(df.groupby(self.target).size().idxmax())\n",
"        tree, separated = self._search_boundary(df)\n",
"        self.importance[tree.col] += tree.gi * len(df)\n",
"        tree.depth = depth  # TODO: clumsy to set the depth after construction\n",
"        tree.nodes = [self._make_tree(_df, depth+1) for _df in separated]\n",
"        return tree\n",
"\n",
"    def _search_boundary(self, df):\n",
"        \"\"\"Search the given data for the best split\"\"\"\n",
"        root_gi = self._gini_index(df)\n",
"        max_node = Tree(None, None, 0.0)\n",
"        max_separated = ()\n",
"        for col in df.columns:\n",
"            if col == self.target: continue\n",
"            for boundary in df[col].drop_duplicates().sort_values():\n",
"                gi, separated = self._binary(df, root_gi, df[col], boundary)\n",
"                if max_node.gi <= gi:\n",
"                    max_node = Tree(col, boundary, gi)\n",
"                    max_separated = separated\n",
"        return max_node, max_separated\n",
"\n",
"    def _binary(self, df, root_gi, x, boundary):\n",
"        \"\"\"Return the information gain (decrease in Gini impurity) of splitting at the given boundary, plus the two parts\"\"\"\n",
"        df0 = df[x < boundary]\n",
"        df1 = df[x >= boundary]\n",
"        gi = root_gi - (len(df0) * self._gini_index(df0) + len(df1) * self._gini_index(df1)) / len(df)\n",
"        return gi, (df0, df1)\n",
"\n",
"    def _gini_index(self, df):\n",
"        \"\"\"Compute the Gini impurity of the given data\"\"\"\n",
"        return 1 - ((df.groupby(self.target).size() / len(df)) ** 2).sum()\n",
"\n",
"    def predict(self, df):\n",
"        \"\"\"Predict the target variable from the explanatory variables\"\"\"\n",
"        return pd.Series([self.tree.classify(row) for _, row in df.iterrows()], index=df.index)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"shuffled_df = df.reindex(np.random.permutation(df.index)).reset_index(drop=True)\n",
"p = int(0.8 * len(df))\n",
"train_df = shuffled_df.iloc[:p, :].reset_index(drop=True)\n",
"test_df = shuffled_df.iloc[p:, :].reset_index(drop=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"x0 0.000000\n",
"x1 1.333333\n",
"x2 6.240821\n",
"x3 72.359179\n",
"dtype: float64"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = BinaryDecisionTree()\n",
"model.fit(train_df)\n",
"model.importance"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Tree col=x3 boundary=1.0 gi=0.32464769647696473 depth=0>\n",
"\t<Leaf category=0>\n",
"\t<Tree col=x3 boundary=1.8 gi=0.3672941686999054 depth=1>\n",
"\t\t<Tree col=x2 boundary=5.0 gi=0.09339949590422182 depth=2>\n",
"\t\t\t<Tree col=x3 boundary=1.7 gi=0.04875000000000007 depth=3>\n",
"\t\t\t\t<Leaf category=1>\n",
"\t\t\t\t<Leaf category=2>\n",
"\t\t\t<Tree col=x3 boundary=1.6 gi=0.2222222222222222 depth=3>\n",
"\t\t\t\t<Leaf category=2>\n",
"\t\t\t\t<Tree col=x2 boundary=5.8 gi=0.4444444444444444 depth=4>\n",
"\t\t\t\t\t<Leaf category=1>\n",
"\t\t\t\t\t<Leaf category=2>\n",
"\t\t<Tree col=x2 boundary=4.9 gi=0.01697530864197544 depth=2>\n",
"\t\t\t<Tree col=x1 boundary=3.2 gi=0.4444444444444444 depth=3>\n",
"\t\t\t\t<Leaf category=2>\n",
"\t\t\t\t<Leaf category=1>\n",
"\t\t\t<Leaf category=2>"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.tree"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred = model.predict(test_df)\n",
"(pred == test_df.y).sum() / len(test_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
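The two quantities the notebook relies on, the Gini impurity (`gini_index`) and the Gini gain of a split (`binary`), can be checked by hand. Below is a minimal stdlib-only sketch. The names `gini` and `gini_gain` are hypothetical, and the code works on plain Python lists instead of pandas DataFrames. It reproduces the whole-dataset value of 2/3, since Iris has 50 samples in each of its 3 classes. It also reproduces the gain of 1/3 that the notebook reports for its first split, `x3 < 1.0`, which isolates class 0. The feature values used for that split are made up for illustration and are not the real Iris measurements.

```python
from collections import Counter


def gini(labels):
    """Gini impurity 1 - sum_k p_k^2 of a list of class labels (0.0 if empty)."""
    n = len(labels)
    if n == 0:
        return 0.0  # an empty part is weighted by 0 in gini_gain anyway
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def gini_gain(values, labels, boundary):
    """Decrease in Gini impurity when splitting into value < boundary / value >= boundary.

    Mirrors the notebook's `binary`: each part's impurity is weighted by its sample count.
    """
    left = [y for x, y in zip(values, labels) if x < boundary]
    right = [y for x, y in zip(values, labels) if x >= boundary]
    weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
    return gini(labels) - weighted


# Class labels with Iris's balance: 50 samples in each of 3 classes.
iris_like = [0] * 50 + [1] * 50 + [2] * 50
print(round(gini(iris_like), 4))  # 0.6667

# Made-up feature values in which class 0 lies entirely below 1.0,
# so the boundary 1.0 isolates it: left impurity 0, right impurity 0.5,
# gain = 2/3 - (100/150) * 0.5 = 1/3.
toy_x = [0.2] * 50 + [1.3] * 50 + [2.0] * 50
print(round(gini_gain(toy_x, iris_like, 1.0), 4))  # 0.3333
```

For a given node the parent impurity is fixed, so maximizing this gain is the same as minimizing the weighted impurity of the two parts. That is why the notebook can compare candidate splits by `gi` alone.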
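The recursive procedure of `make_tree` and `Tree.classify` can also be sketched without pandas. This is an illustrative stdlib re-implementation under stated assumptions, not the notebook's code. Rows are tuples of numbers, and the tree is nested tuples rather than `Tree`/`Leaf` objects. The helpers `best_split`, `majority` and the tuple-based `make_tree`/`classify` are hypothetical stand-ins. The sketch adds two assumed guards that the notebook lacks. The best-split search starts from `-inf` so that some candidate is always chosen. A majority leaf is returned when no split puts rows on both sides, for example when all rows at a node are identical but their labels differ. In that case the notebook would recurse into an empty part and fail its `assert`.

```python
from collections import Counter


def gini(labels):
    """Gini impurity 1 - sum_k p_k^2 (0.0 for an empty list)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def majority(labels):
    """Most frequent label; ties go to the label seen first."""
    return Counter(labels).most_common(1)[0][0]


def best_split(rows, labels):
    """Return (gain, col, boundary) of the best `row[col] < boundary` split.

    Like the notebook's search_boundary, every distinct value of every column is
    tried as a boundary and ties go to the later candidate (`<=`). Starting from
    -inf (the notebook starts from 0.0) is an assumed guard so a candidate is always kept.
    """
    root = gini(labels)
    n = len(labels)
    best = (float('-inf'), None, None)
    for col in range(len(rows[0])):
        for boundary in sorted({row[col] for row in rows}):
            left = [y for row, y in zip(rows, labels) if row[col] < boundary]
            right = [y for row, y in zip(rows, labels) if row[col] >= boundary]
            gain = root - (len(left) * gini(left) + len(right) * gini(right)) / n
            if best[0] <= gain:
                best = (gain, col, boundary)
    return best


def make_tree(rows, labels, depth=0, max_depth=4):
    """Return ('leaf', label) or ('node', col, boundary, left_subtree, right_subtree)."""
    if len(set(labels)) == 1:
        return ('leaf', labels[0])            # pure node
    if depth > max_depth:
        return ('leaf', majority(labels))     # same depth test as the notebook
    _, col, boundary = best_split(rows, labels)
    left = [i for i, row in enumerate(rows) if row[col] < boundary]
    right = [i for i, row in enumerate(rows) if row[col] >= boundary]
    if not left or not right:
        # Assumed guard: no split separates the rows, so stop with the majority class
        return ('leaf', majority(labels))

    def subtree(idx):
        return make_tree([rows[i] for i in idx], [labels[i] for i in idx], depth + 1, max_depth)

    return ('node', col, boundary, subtree(left), subtree(right))


def classify(tree, row):
    """Go left when row[col] < boundary, otherwise right, until a leaf is reached."""
    while tree[0] == 'node':
        _, col, boundary, left, right = tree
        tree = left if row[col] < boundary else right
    return tree[1]


# Tiny made-up dataset with one feature and three classes.
rows = [(0.2,), (0.3,), (1.5,), (1.6,), (2.5,), (2.6,)]
labels = [0, 0, 1, 1, 2, 2]
tree = make_tree(rows, labels)
print([classify(tree, r) for r in rows])  # [0, 0, 1, 1, 2, 2]
```

Trying every distinct value as a boundary follows the notebook. It is quadratic per column, which is fine for Iris-sized data but slow for large tables.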