Skip to content

Instantly share code, notes, and snippets.

@a-mitani
Created February 23, 2019 03:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save a-mitani/6c8a90905eb4ad934d1551ec523c01c4 to your computer and use it in GitHub Desktop.
Save a-mitani/6c8a90905eb4ad934d1551ec523c01c4 to your computer and use it in GitHub Desktop.
import numpy as np
# Gini impurity of a partition of class labels.
def gini_impurity(datasets):
    """Return the size-weighted Gini impurity of a split.

    Each element of *datasets* is one subset produced by a split: a 1-D
    sequence (NumPy array or plain list/tuple) of class labels.  For every
    non-empty subset the impurity ``1 - sum(p_k ** 2)`` is computed, where
    ``p_k`` is the fraction of the subset's samples that carry class ``k``.
    The subsets' impurities are then averaged, weighted by each subset's
    share of the total sample count (the second half of Eq. (11.11) in
    "Hajimete no Pattern Ninshiki").

    Args:
        datasets: Sequence of 1-D label collections, one per subset.

    Returns:
        The weighted Gini impurity as a float.  It is 0.0 when every subset
        is pure, and also when there are no samples at all (an empty
        *datasets* sequence or only empty subsets).
    """
    # Convert up front: ``==`` between a plain list and a scalar yields a
    # single bool rather than an element-wise mask, which would silently
    # zero every class proportion.
    subsets = [np.asarray(dataset) for dataset in datasets]
    n_all = sum(len(subset) for subset in subsets)  # total number of samples
    if n_all == 0:
        # Nothing to measure; also avoids concatenating/dividing with no data.
        return 0.0

    gini = 0.0
    for subset in subsets:
        size = len(subset)
        # An empty subset contributes nothing to the weighted impurity.
        if size == 0:
            continue
        # Occurrences of each class that is actually present in this subset;
        # absent classes have p == 0 and would add nothing to the sum.
        _, counts = np.unique(subset, return_counts=True)
        proportions = counts / size
        impurity = 1.0 - np.sum(proportions * proportions)
        gini += impurity * (size / n_all)
    return float(gini)
# Demo: print the weighted Gini impurity for a handful of example splits.
# Each entry is one split, given as the list of label subsets it produces.
demo_partitions = [
    [np.array([0, 0, 0, 0, 0, 0])],                    # one pure subset
    [np.array([0, 1, 1, 0, 1, 0])],                    # one evenly mixed subset
    [np.array([1, 1, 1, 1, 1, 1])],                    # one pure subset
    [np.array([0, 0, 0, 0]), np.array([1, 1, 1, 1])],  # two pure subsets
    [np.array([0, 0, 1, 1]), np.array([1, 1, 0, 0])],  # two evenly mixed subsets
    [np.array([0, 0, 1, 1]), np.array([0, 0, 0, 0])],  # one mixed, one pure
    [np.array([0, 0]), np.array([1, 1, 1, 0, 0, 0])],  # subsets of unequal size
]
for partition in demo_partitions:
    print(gini_impurity(partition))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment