Skip to content

Instantly share code, notes, and snippets.

@YuigaWada
Created December 18, 2021 09:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YuigaWada/22c8138daf7ad07a680767585944e2a2 to your computer and use it in GitHub Desktop.
Save YuigaWada/22c8138daf7ad07a680767585944e2a2 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
import time
# Plotting switch: Tree.draw() returns without drawing anything when this is falsy.
PLOT = 1
class Node:
    """A single node of the decision tree.

    Attributes:
        left, right: child nodes, or ``None`` while the node is a leaf.
        feature, threshold: the split stored on this node (column index and
            cut value), or ``None`` when the node has not been split.
        X, Y: the training samples and integer class labels routed to this node.
        features: list of all column indices of ``X``.
        depth: distance from the root (the root has depth 0).
        class_table: ``np.bincount(Y)``, i.e. the number of samples per class.
        class_label: the majority class, or ``None`` if the node has no samples.
    """

    def __init__(self, X, Y, depth=0):
        """Store the samples that reached this node and tally their classes.

        Args:
            X: 2-D array of samples (rows) by features (columns).
            Y: 1-D array of non-negative integer class labels, one per row of X.
            depth: depth of this node in the tree.
        """
        self.left = None
        self.right = None
        self.feature = None
        self.threshold = None
        self.X = X
        self.Y = Y
        self.features = list(range(np.shape(X)[1]))
        self.depth = depth
        self.class_table = np.bincount(Y)
        # argmax of an empty table raises, so an empty node gets no label.
        if len(self.class_table):
            self.class_label = np.argmax(self.class_table)
        else:
            self.class_label = None

    def get_error(self):
        """Return the misclassification rate ``1 - max_i P(C_i | node)``.

        A node without samples is treated as fully wrong and yields 1.
        """
        total = np.sum(self.class_table)
        if total == 0:
            return 1
        return 1 - np.max(self.class_table) / total
class Tree:
    """Binary decision tree classifier (Gini splits, weakest-link pruning).

    Attributes:
        root: the root ``Node`` holding the full training set.
        X, Y: training samples and integer class labels.
        max_depth: nodes at this depth are never split.
        min_max_table: ``[min, max]`` of every feature column, used by ``draw``.
    """

    def __init__(self, X, Y, max_depth):
        """Create an unfitted tree; call ``fit`` to grow it.

        Args:
            X: 2-D array of samples (rows) by features (columns).
            Y: 1-D array of non-negative integer class labels.
            max_depth: maximum depth at which a node may still be split.
        """
        self.root = Node(X, Y)
        self.X = X
        self.Y = Y
        self.max_depth = max_depth
        self.min_max_table = [
            [np.min(X[:, i]), np.max(X[:, i])] for i in range(np.shape(X)[1])
        ]

    def gini(self, pair):
        """Return the Gini impurity ``1 - sum(p_i ** 2)`` of a class-count table.

        An all-zero (or empty) table has impurity 0.0.
        """
        total = np.sum(pair)
        if total == 0:
            return 0.0
        probabilities = np.asarray(pair) / total
        return 1 - np.sum(probabilities ** 2)

    def get_mean_array(self, x):
        """Return the midpoints between consecutive elements of ``x``.

        Fewer than two elements have no midpoint, so an empty array is
        returned. (``np.convolve(..., mode="valid")`` would swap its operands
        for a single element and yield two bogus values.)
        """
        x = np.asarray(x, dtype=float)
        if len(x) < 2:
            return np.array([])
        return (x[:-1] + x[1:]) / 2

    def get_gini(self, X, Y, feature=None, value=None, separate_type=None):
        """Return ``(sample count, Gini impurity)`` of one side of a split.

        Args:
            X, Y: samples and integer labels.
            feature: column index to split on (unused when separate_type is None).
            value: threshold of the split (unused when separate_type is None).
            separate_type: ``None`` for all samples, ``"left"`` for samples with
                ``X[:, feature] <= value``, ``"right"`` for ``> value``.

        Raises:
            ValueError: if ``separate_type`` is none of the above.
        """
        if separate_type is None:
            labels = Y
        elif separate_type == "left":
            labels = Y[X[:, feature] <= value]
        elif separate_type == "right":
            labels = Y[X[:, feature] > value]
        else:
            raise ValueError(
                "separate_type must be None, 'left' or 'right', got {!r}".format(
                    separate_type
                )
            )
        count_table = np.bincount(labels)
        return np.sum(count_table), self.gini(count_table)

    def split(self, node):
        """Search for the split of ``node`` with the largest Gini gain.

        Every midpoint between consecutive distinct values of every feature is
        tried as a threshold.

        Returns:
            ``(feature, threshold)``, or ``(None, None)`` when no feature has
            two distinct values (nothing can be split).
        """
        X, Y = node.X, node.Y
        _, parent_gini = self.get_gini(X, Y)
        # Starts below zero so that a zero-gain split is still accepted.
        best_gain = -1
        best = None
        threshold = None
        for feature in node.features:
            # np.unique returns sorted values, so no separate sort is needed.
            distinct = np.unique(X[:, feature])
            if len(distinct) <= 1:
                continue
            for value in self.get_mean_array(distinct):
                lcount, lgini = self.get_gini(X, Y, feature, value, "left")
                rcount, rgini = self.get_gini(X, Y, feature, value, "right")
                count = lcount + rcount
                l_probability = lcount / count
                r_probability = rcount / count
                # The gain is non-negative in exact arithmetic, but round-off
                # can make it marginally negative, so it is not asserted.
                gain = parent_gini - (l_probability * lgini + r_probability * rgini)
                if gain > best_gain:
                    best = feature
                    threshold = value
                    best_gain = gain
        return (best, threshold)

    def fit(self):
        """Grow the whole tree from the root."""
        self.grow(node=self.root)

    def grow(self, node):
        """Recursively split ``node`` and its descendants.

        Stops when ``max_depth`` is reached or ``split`` finds nothing to split.
        """
        if node.depth >= self.max_depth:
            return
        feature, threshold = self.split(node)
        if feature is None:
            return
        X, Y = node.X, node.Y
        node.feature = feature
        node.threshold = threshold
        il = X[:, feature] <= threshold
        ir = X[:, feature] > threshold
        ndepth = node.depth + 1
        node.left = Node(X[il], Y[il], depth=ndepth)
        node.right = Node(X[ir], Y[ir], depth=ndepth)
        self.grow(node.left)
        self.grow(node.right)

    def prune(self):
        """Collapse the weakest internal node (smallest ``g``) into a leaf.

        The root is never pruned. The collapsed node also loses its stored
        split so that ``draw`` no longer shows a removed boundary.

        Returns:
            The ``g`` value of the pruned node, or ``1 << 30`` when no node
            other than the root is internal (nothing was pruned).
        """
        INF = 1 << 30
        min_g, target = INF, None
        stack = [self.root]
        while stack:  # DFS
            current = stack.pop()
            children = [
                child for child in (current.left, current.right) if child is not None
            ]
            if children and current is not self.root:
                strength = self.g(current)
                if strength < min_g:
                    min_g, target = strength, current
            stack.extend(children)
        if target is not None:
            target.left = None
            target.right = None
            target.feature = None
            target.threshold = None
        return min_g

    def get_terminal_error(self, node):
        """Return ``M(t) / N``: misclassified samples in leaf ``node`` over the
        size of the whole training set."""
        failure = np.count_nonzero(node.Y != node.class_label)
        return failure / len(self.X)

    def get_nonterminal_error(self, node):
        """Return ``node.get_error()`` weighted by the fraction of the training
        set that reached ``node``."""
        p_t = len(node.X) / len(self.X)
        return node.get_error() * p_t

    def g(self, node):
        """Return the link strength ``(R(t) - R(T_t)) / (leaves(T_t) - 1)``.

        ``R(t)`` is the error of ``node`` treated as a leaf and ``R(T_t)`` is
        the summed error of the leaves below it. For a leaf the quotient is
        undefined; ``inf`` is returned so a leaf is never the weakest link.
        """
        R_t = self.get_nonterminal_error(node)
        terminals = 0
        R_T = 0
        stack = [node]
        while stack:  # DFS
            current = stack.pop()
            children = [
                child for child in (current.left, current.right) if child is not None
            ]
            if children:
                stack.extend(children)
            else:
                R_T += self.get_terminal_error(current)
                terminals += 1
        if terminals <= 1:
            return float("inf")
        return (R_t - R_T) / (terminals - 1)

    def draw(self):
        """Plot every stored split as a green line (no-op when ``PLOT`` is falsy).

        A split on feature 0 is drawn as a vertical line, any other as a
        horizontal line; the line spans the [min, max] range of the mirrored
        feature index, which for two-feature data is simply the other axis.
        """
        if not PLOT:
            return
        stack = [self.root]
        while stack:  # DFS
            current = stack.pop()
            if current.left is not None:
                stack.append(current.left)
            if current.right is not None:
                stack.append(current.right)
            if current.feature is None:
                continue
            cut = [current.threshold, current.threshold]
            span = self.min_max_table[len(self.min_max_table) - 1 - current.feature]
            if current.feature == 0:
                plt.plot(cut, span, color="green")
            else:
                plt.plot(span, cut, color="green")

    def predict(self, X):
        """Classify every row of ``X``.

        Each sample descends from the root, going left when its feature value
        is ``<=`` the node threshold (the same rule ``grow`` uses to partition
        the training data) until it reaches a node without a usable split.

        Returns:
            Array shaped and typed like ``X[:, 0]`` holding the class labels.
        """
        predictions = np.zeros_like(X[:, 0])
        for i, x in enumerate(X):
            current = self.root
            while current.depth < self.max_depth and current.feature is not None:
                if x[current.feature] <= current.threshold:
                    child = current.left
                else:
                    child = current.right
                if child is None:
                    break
                current = child
            predictions[i] = current.class_label
        return predictions
# Main
def split_data(X, Y, size):
    """Build ``size`` train/test folds for cross-validation.

    Fold ``i`` uses the contiguous window ``[i * w, (i + 1) * w)`` with
    ``w = len(Y) // size`` as test data; the ``len(Y) - w`` samples that follow
    it (wrapping around to the start) form the training data. The data is
    taken in its given order; no shuffling is done.

    Args:
        X: array of samples, indexed along its first axis.
        Y: array of labels, one per sample.
        size: number of folds (positive integer).

    Returns:
        List of ``(X_train, Y_train, X_test, Y_test)`` tuples, one per fold.

    Raises:
        ValueError: if ``size`` is smaller than 1.
    """
    if size < 1:
        raise ValueError("size must be a positive integer, got {!r}".format(size))
    window = len(Y) // size
    train_size = len(Y) - window
    # Doubling the data lets a training slice run past the end and wrap around.
    _X, _Y = np.concatenate((X, X)), np.concatenate((Y, Y))
    res = []
    for i in range(size):
        offset = window * i
        test_tail = offset + window
        train_tail = test_tail + train_size
        res.append((
            _X[test_tail:train_tail],
            _Y[test_tail:train_tail],
            _X[offset:test_tail],
            _Y[offset:test_tail],
        ))
    return res
def test(X_test, Y_test, tree, log=False):
    """Evaluate ``tree`` on held-out data and return its error rate.

    Args:
        X_test: samples passed to ``tree.predict``.
        Y_test: expected labels, compared element-wise with the predictions.
        tree: object exposing ``predict(X)``.
        log: when true, print the sample count and the number of hits.

    Returns:
        ``1 - correct / total``.
    """
    predicted = tree.predict(X_test)
    total = len(X_test)
    hits = np.sum(predicted == Y_test)
    miss_rate = 1 - hits / total
    if log:
        print("all:", len(X_test))
        print("correct:", hits)
    return miss_rate
def main():
    """Run the pruning experiment on the first two iris features.

    For every pruning count 0..19 a tree is trained on each of four folds,
    pruned that many times and scored on the fold's test window; a figure is
    saved per pruning count. The pruning count with the lowest mean error is
    selected, its best tree's decision regions are plotted to ``result.png``
    and ``result_with_lines.png``, and the errors are printed.
    """
    import os

    from sklearn.datasets import load_iris

    # Data-set
    size = 4
    searched = []
    iris = load_iris()
    X, Y = iris.data[:, :2], iris.target
    # NOTE(review): split_data does not shuffle; if the dataset is ordered by
    # class, the folds are class-skewed. Confirm this is intended.
    splited = split_data(X, Y, size)

    # savefig does not create missing directories.
    os.makedirs("images", exist_ok=True)

    # Search-HyperParam
    models = {}
    for pruning_count in range(20):
        models[pruning_count] = []

        # Train & Test
        errors = 0
        for X_train, Y_train, X_test, Y_test in splited:
            # Decision-tree
            tree = Tree(np.array(X_train), np.array(Y_train), max_depth=100)
            tree.fit()

            # Pruning!!
            for _ in range(pruning_count):
                tree.prune()

            error = test(X_test, Y_test, tree)
            models[pruning_count].append((error, tree))
            errors += error

        # NOTE(review): the original indentation was lost; draw() is placed
        # after the fold loop, so each figure shows the last fold's tree.
        # Move it into the loop if every fold's splits should be overlaid.
        tree.draw()
        plt.scatter(X[:, 0], X[:, 1], c=Y)
        plt.savefig("images/figure{}.png".format(pruning_count))
        plt.clf()

        errors /= size
        print("{:.2f}, {}".format(errors, pruning_count))
        searched.append((errors, pruning_count))

    # Select-HyperParam (min keeps the first of equal candidates, like a stable sort)
    error, pruning_count = min(searched, key=lambda x: x[0])
    print("pruning_count:", pruning_count)

    # Select-model
    error, tree = min(models[pruning_count], key=lambda x: x[0])

    # Plot & Draw
    mesh = 200
    mx, my = np.meshgrid(
        np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, mesh),
        np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, mesh),
    )
    mX = np.stack([mx.ravel(), my.ravel()], 1)
    mz = tree.predict(mX).reshape(mesh, mesh)
    plt.scatter(X[:, 0], X[:, 1], c=Y)
    plt.contourf(mx, my, mz, alpha=0.4, cmap='plasma', zorder=0)
    plt.savefig("result.png")
    tree.draw()
    plt.savefig("result_with_lines.png")

    r = test(X, Y, tree)
    print("error:", round(error, 3))
    print("r:", round(r, 3))
if __name__ == "__main__":
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    main()
    elapsed_time = time.perf_counter() - start
    print("time:{:.2f}[sec]".format(elapsed_time))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment