Created
December 18, 2021 09:36
-
-
Save YuigaWada/22c8138daf7ad07a680767585944e2a2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import time | |
PLOT = 1 | |
class Node:  # decision-tree node
    """One node of a binary decision tree, together with the samples it holds.

    A node is a leaf while ``left`` and ``right`` are both ``None``.
    """

    def __init__(self, X, Y, depth=0):
        """Store the node's samples and derive its majority class.

        Args:
            X: 2-D array of feature values, one row per sample (indexed as
                ``X[:, i]`` elsewhere in the file).
            Y: 1-D array of non-negative integer class labels, as required
                by ``np.bincount``. Must be non-empty, otherwise
                ``np.argmax`` raises ``ValueError``.
            depth: Depth of this node; the root uses the default 0.
        """
        self.left = None  # child for samples with value <= threshold
        self.right = None  # child for samples with value > threshold
        self.feature = None  # column index this node splits on
        self.threshold = None  # split value for that column
        self.X = X
        self.Y = Y
        self.features = list(range(np.shape(X)[1]))  # candidate columns
        self.depth = depth
        self.class_table = np.bincount(Y)  # samples per class label
        # Majority class: arg max of the per-class counts.
        self.class_label = np.argmax(self.class_table)

    def get_error(self):  # 1 - max{P(C_i | node=self)}
        """Return the resubstitution error rate of this node.

        This is the fraction of the node's samples that do not belong to
        the majority class. A node whose class table sums to zero reports
        an error of 1.
        """
        table = self.class_table
        total = np.sum(table)
        # Check the empty case first: it guards the division below.
        if total == 0:
            return 1
        return 1 - np.max(table) / total
class Tree:  # decision tree
    """Binary decision-tree classifier using Gini gain and weakest-link pruning.

    Samples whose feature value is ``<= threshold`` go to the left child and
    all others to the right child; training and prediction use the same rule.
    """

    def __init__(self, X, Y, max_depth):
        """Create an unfitted tree whose root holds all training samples.

        Args:
            X: 2-D array of training features, one row per sample.
            Y: 1-D array of non-negative integer class labels.
            max_depth: Nodes at this depth (root is 0) are never split.
        """
        self.root = Node(X, Y)
        self.X = X
        self.Y = Y
        self.max_depth = max_depth
        # Per-feature [min, max] of the training data; draw() uses these as
        # the end points of the split lines.
        self.min_max_table = [
            [min(X[:, i]), max(X[:, i])] for i in range(np.shape(X)[1])
        ]

    def gini(self, pair):  # compute the Gini impurity
        """Return the Gini impurity ``1 - sum(p_i ** 2)`` of a count table.

        Args:
            pair: Array of per-class sample counts.

        Returns:
            The impurity, or 0.0 when the counts sum to zero.
        """
        counts = np.asarray(pair)
        total = np.sum(counts)
        if total == 0:
            return 0.0
        probabilities = counts / total
        return 1 - np.sum(probabilities ** 2)

    def get_mean_array(self, x):  # midpoints of neighbouring values
        """Return the midpoints between consecutive elements of ``x``.

        Args:
            x: 1-D sequence of values (the caller passes sorted unique values).

        Returns:
            An array of ``len(x) - 1`` midpoints, or an empty list when ``x``
            has fewer than two elements and therefore no midpoint.
        """
        if len(x) < 2:
            return []
        values = np.asarray(x, dtype=float)
        return (values[:-1] + values[1:]) / 2

    def get_gini(self, X, Y, feature=None, value=None, separate_type=None):
        """Return the sample count and Gini impurity of one side of a split.

        Args:
            X: 2-D feature array.
            Y: 1-D integer label array aligned with ``X``.
            feature: Column index to split on (unused when
                ``separate_type`` is None).
            value: Threshold for that column.
            separate_type: ``None`` for all samples, ``"left"`` for samples
                with ``X[:, feature] <= value``, ``"right"`` for the rest.

        Returns:
            Tuple ``(count, gini)`` for the selected samples.

        Raises:
            ValueError: If ``separate_type`` is not one of the three options.
        """
        if separate_type is None:
            selected = Y
        elif separate_type == "left":
            selected = Y[X[:, feature] <= value]
        elif separate_type == "right":
            selected = Y[X[:, feature] > value]
        else:
            raise ValueError(
                "separate_type must be None, 'left' or 'right', got {!r}".format(
                    separate_type
                )
            )
        count_table = np.bincount(selected)
        return np.sum(count_table), self.gini(count_table)

    def split(self, node):  # search for the best split point
        """Find the (feature, threshold) pair with the largest Gini gain.

        Every midpoint between consecutive unique values of every feature is
        tried. Ties keep the first candidate found.

        Args:
            node: Object exposing ``X``, ``Y`` and ``features``.

        Returns:
            Tuple ``(feature, threshold)``, or ``(None, None)`` when no
            feature has more than one distinct value.
        """
        X, Y = node.X, node.Y
        _, parent_gini = self.get_gini(X, Y)
        best_gain = -1
        best_feature = None
        best_threshold = None
        for feature in node.features:
            # np.unique returns sorted values, so no separate sort is needed;
            # the boolean masks in get_gini do not depend on row order.
            levels = np.unique(X[:, feature])
            if len(levels) <= 1:
                continue
            for value in self.get_mean_array(levels):
                lcount, lgini = self.get_gini(X, Y, feature, value, "left")
                rcount, rgini = self.get_gini(X, Y, feature, value, "right")
                count = lcount + rcount
                l_probability = lcount / count
                r_probability = rcount / count
                # The gain is non-negative in exact arithmetic, but rounding
                # can make it marginally negative, so it is not asserted.
                gain = parent_gini - (
                    l_probability * lgini + r_probability * rgini
                )
                if gain > best_gain:
                    best_feature = feature
                    best_threshold = value
                    best_gain = gain
        return (best_feature, best_threshold)

    def fit(self):  # build the decision tree
        """Grow the tree from the root using the training data."""
        self.grow(node=self.root)

    def grow(self, node):  # grow the tree below `node`
        """Recursively split ``node`` until ``max_depth`` or no split exists.

        Args:
            node: The node to expand in place.
        """
        if node.depth >= self.max_depth:
            return
        X, Y = node.X, node.Y
        feature, threshold = self.split(node)
        if feature is None:
            return
        node.feature = feature
        node.threshold = threshold
        goes_left = X[:, feature] <= threshold
        goes_right = X[:, feature] > threshold
        child_depth = node.depth + 1
        node.left = Node(X[goes_left], Y[goes_left], depth=child_depth)
        node.right = Node(X[goes_right], Y[goes_right], depth=child_depth)
        self.grow(node.left)
        self.grow(node.right)

    def prune(self):  # remove all descendants of arg min g(t)
        """Collapse the non-root internal node with the smallest ``g`` value.

        The chosen node becomes a leaf: its children are removed and its
        split is cleared so that ``draw`` no longer renders it.

        Returns:
            The ``g`` value of the pruned node, or ``1 << 30`` when there
            was no non-root internal node to prune.
        """
        INF = 1 << 30
        weakest_value, weakest_node = INF, None
        stack = [self.root]
        while stack:  # DFS
            current = stack.pop()
            has_left = current.left is not None
            has_right = current.right is not None
            # Only internal nodes have a link strength; the root is never
            # pruned.
            if (has_left or has_right) and current is not self.root:
                strength = self.g(current)
                if strength < weakest_value:
                    weakest_value, weakest_node = strength, current
            if has_left:
                stack.append(current.left)
            if has_right:
                stack.append(current.right)
        if weakest_node is not None:
            weakest_node.left = None
            weakest_node.right = None
            weakest_node.feature = None
            weakest_node.threshold = None
        return weakest_value

    def get_terminal_error(self, node):  # leaf error = M(t) / N
        """Return a leaf's misclassified-sample count over the training size.

        Args:
            node: Leaf exposing ``Y`` and ``class_label``.
        """
        misclassified = np.count_nonzero(node.Y != node.class_label)
        return misclassified / len(self.X)

    def get_nonterminal_error(self, node):  # resubstitution error * p(t)
        """Return the node's own error rate weighted by its share of samples.

        Args:
            node: Node exposing ``X`` and ``get_error()``.
        """
        p_t = len(node.X) / len(self.X)  # marginal probability of the node
        return node.get_error() * p_t

    def g(self, node):  # g(t) = link strength of node t
        """Return ``(R(t) - R(T_t)) / (leaves - 1)`` for an internal node.

        ``R(t)`` is the node's own weighted error and ``R(T_t)`` the summed
        error of the leaves below it.

        Args:
            node: An internal node. A leaf has exactly one terminal (itself),
                which makes the denominator zero.
        """
        R_t = self.get_nonterminal_error(node)
        stack = [node]
        terminals = 0
        R_T = 0
        while stack:  # DFS
            current = stack.pop()
            children = [
                child
                for child in (current.left, current.right)
                if child is not None
            ]
            if children:
                stack.extend(children)
            else:
                R_T += self.get_terminal_error(current)
                terminals += 1
        return (R_t - R_T) / (terminals - 1)

    def draw(self):  # plot the split boundaries
        """Plot every split of the tree as a green line; no-op if PLOT is falsy.

        Assumes exactly two features: ``~0 == -1`` and ``~1 == -2`` select the
        *other* feature's [min, max] range from the two-entry table.
        """
        if not PLOT:
            return
        stack = [self.root]
        while stack:  # DFS
            current = stack.pop()
            if current.left is not None:
                stack.append(current.left)
            if current.right is not None:
                stack.append(current.right)
            if current.feature is None:
                continue
            span = self.min_max_table[~current.feature]
            level = [current.threshold, current.threshold]
            if not current.feature:
                # Split on feature 0: vertical line across feature 1's range.
                plt.plot(level, span, color="green")
            else:
                # Split on feature 1: horizontal line across feature 0's range.
                plt.plot(span, level, color="green")

    def predict(self, X):  # classifier
        """Predict a class label for every row of ``X``.

        Args:
            X: 2-D feature array, one row per sample.

        Returns:
            1-D array shaped and typed like ``X[:, 0]`` holding the labels.
        """
        predictions = np.zeros_like(X[:, 0])
        for i, x in enumerate(X):
            current = self.root
            while current.depth < self.max_depth and current.feature is not None:
                # Same rule as training: values equal to the threshold go left.
                if x[current.feature] <= current.threshold:
                    child = current.left
                else:
                    child = current.right
                if child is None:
                    break
                current = child
            predictions[i] = current.class_label
        return predictions
# Main
def split_data(X, Y, size):  # split the data into folds
    """Build ``size`` train/test folds by sliding a window over the data.

    Fold ``k`` tests on the ``k``-th contiguous window of
    ``len(Y) // size`` samples and trains on the ``len(Y) - window``
    samples that follow it, wrapping around to the start of the data.
    The data is not shuffled.

    Args:
        X: Feature array, one row per sample.
        Y: Label array aligned with ``X``.
        size: Number of folds.

    Returns:
        List of ``(X_train, Y_train, X_test, Y_test)`` tuples, one per fold.
    """
    fold_len = len(Y) // size
    keep_len = len(Y) - fold_len
    # Doubling the data lets the training window run past the end with a
    # plain slice instead of explicit wrap-around indexing.
    doubled_X = np.concatenate((X, X))
    doubled_Y = np.concatenate((Y, Y))

    def make_fold(start):
        stop = start + fold_len
        held_out = slice(start, stop)
        training = slice(stop, stop + keep_len)
        return (
            doubled_X[training],
            doubled_Y[training],
            doubled_X[held_out],
            doubled_Y[held_out],
        )

    return [make_fold(fold_len * k) for k in range(size)]
def test(X_test, Y_test, tree, log=False): # モデルをテスト | |
Xsubset = X_test | |
res = tree.predict(Xsubset) | |
allcount = len(Xsubset) | |
correct = np.sum(res == Y_test) | |
error = 1 - correct / allcount | |
if log: | |
print("all:", len(Xsubset)) | |
print("correct:", correct) | |
return error | |
def main():
    """Tune the pruning count by cross-validation on iris and plot the result.

    Uses the first two iris features. For each pruning count from 0 to 19 a
    tree is fitted and pruned on every fold, the mean test error is recorded
    and the split lines are saved under ``images/``. The pruning count with
    the lowest mean error is selected, and its best-scoring fold model is
    plotted to ``result.png`` and ``result_with_lines.png``.
    """
    import os

    from sklearn.datasets import load_iris

    # Data-set
    size = 4
    searched = []
    iris = load_iris()
    X, Y = iris.data[:, :2], iris.target
    splited = split_data(X, Y, size)

    # savefig does not create missing directories, so make sure the output
    # directory exists before the first figure is written.
    os.makedirs("images", exist_ok=True)

    # Search-HyperParam
    models = {}
    for pruning_count in range(20):
        models[pruning_count] = []

        # Train & Test
        errors = 0
        for X_train, Y_train, X_test, Y_test in splited:
            # Decision-tree
            tree = Tree(np.array(X_train), np.array(Y_train), max_depth=100)
            tree.fit()

            # Pruning!!
            for _ in range(pruning_count):
                tree.prune()

            error = test(X_test, Y_test, tree)
            models[pruning_count].append((error, tree))
            errors += error
            # Every fold's split lines are drawn onto the same figure.
            tree.draw()

        plt.scatter(X[:, 0], X[:, 1], c=Y)
        plt.savefig("images/figure{}.png".format(pruning_count))
        plt.clf()

        errors /= size
        print("{:.2f}, {}".format(errors, pruning_count))
        searched.append((errors, pruning_count))

    # Select-HyperParam: lowest mean error first (stable sort keeps the
    # smaller pruning count on ties).
    searched.sort(key=lambda x: x[0])
    error, pruning_count = searched[0]
    print("pruning_count:", pruning_count)

    # Select-model: best fold for the chosen pruning count.
    models[pruning_count].sort(key=lambda x: x[0])
    error, tree = models[pruning_count][0]

    # Plot & Draw
    mesh = 200
    mx, my = np.meshgrid(
        np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, mesh),
        np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, mesh),
    )
    mX = np.stack([mx.ravel(), my.ravel()], 1)
    mz = tree.predict(mX).reshape(mesh, mesh)

    plt.scatter(X[:, 0], X[:, 1], c=Y)
    plt.contourf(mx, my, mz, alpha=0.4, cmap='plasma', zorder=0)
    plt.savefig("result.png")

    tree.draw()
    plt.savefig("result_with_lines.png")

    # Error of the selected model on the full data set.
    r = test(X, Y, tree)
    print("error:", round(error, 3))
    print("r:", round(r, 3))
if __name__ == "__main__":
    # Run the experiment and report the elapsed wall-clock time in seconds.
    began = time.time()
    main()
    print("time:{:.2f}[sec]".format(time.time() - began))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment