Created
February 23, 2019 03:13
-
-
Save a-mitani/6c8a90905eb4ad934d1551ec523c01c4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
# Gini impurity calculation
def gini_impurity(datasets):
    """Return the size-weighted Gini impurity of a partition of class labels.

    Each element of ``datasets`` holds the class labels that fell into one
    branch of a split. The impurity of each branch, ``1 - sum(p_k ** 2)``
    over the classes ``k`` present in it, is weighted by the fraction of all
    samples that the branch contains, and the weighted values are summed.

    Args:
        datasets: Iterable of array-likes of class labels (NumPy arrays or
            plain sequences). Multi-dimensional inputs are flattened.

    Returns:
        The weighted Gini impurity as a float. Returns 0.0 when there are
        no samples at all (including an empty ``datasets``) or when every
        sample belongs to the same class.
    """
    # Flatten every partition up front so that plain lists and
    # multi-dimensional label arrays are handled uniformly.
    partitions = [np.asarray(dataset).ravel() for dataset in datasets]
    n_all = sum(part.size for part in partitions)  # total number of samples
    if n_all == 0:
        # Nothing to measure; also avoids dividing by zero below.
        return 0.0

    gini = 0.0
    for part in partitions:
        size = part.size
        # An empty partition does not contribute to the impurity.
        if size == 0:
            continue
        # Count each class once instead of rescanning the data per class.
        _, counts = np.unique(part, return_counts=True)
        proportions = counts / size
        score = float(np.sum(proportions * proportions))
        # Weight by the partition's share of all samples (second half of
        # eq. (11.11) in the textbook "Hajimete no Pattern Ninshiki").
        gini += (1.0 - score) * (size / n_all)
    return gini
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demo: print the Gini impurity for a handful of example splits.
# Each entry is one split, given as a list of per-branch label arrays.
demo_splits = [
    [np.array([0, 0, 0, 0, 0, 0])],                      # single class -> 0.0
    [np.array([0, 1, 1, 0, 1, 0])],                      # evenly mixed -> 0.5
    [np.array([1, 1, 1, 1, 1, 1])],                      # single class -> 0.0
    [np.array([0, 0, 0, 0]), np.array([1, 1, 1, 1])],    # perfect split -> 0.0
    [np.array([0, 0, 1, 1]), np.array([1, 1, 0, 0])],    # both branches mixed -> 0.5
    [np.array([0, 0, 1, 1]), np.array([0, 0, 0, 0])],    # one mixed, one pure -> 0.25
    [np.array([0, 0]), np.array([1, 1, 1, 0, 0, 0])],    # small pure, large mixed -> 0.375
]
for split in demo_splits:
    print(gini_impurity(split))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment