Skip to content

Instantly share code, notes, and snippets.

@Yunaka12
Created August 30, 2019 08:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Yunaka12/99dddb2df0b52f9d627e0ba34f6f6b4d to your computer and use it in GitHub Desktop.
Save Yunaka12/99dddb2df0b52f9d627e0ba34f6f6b4d to your computer and use it in GitHub Desktop.
各クラスの比率を保ったまま、学習データ、テストデータ、検証用データにk分割
import numpy as np
def split_data(data, split_num):
    """Split *data* into consecutive chunks of *split_num* elements.

    The last chunk is shorter when ``len(data)`` is not a multiple of
    ``split_num``.  Each chunk is returned as a flat Python list: nested
    lists / tuples / ndarrays inside a chunk are expanded into their
    elements, but -- unlike ``np.array(chunk).flatten()`` -- the elements
    keep their own Python types (no coercion of ints to floats or of
    numbers to strings when types are mixed).

    Args:
        data: sliceable sequence (list, tuple, 1-D or N-D ndarray).
        split_num: chunk length; must be a positive integer.

    Returns:
        list of flat lists, one per chunk.

    Raises:
        ValueError: if ``split_num`` is not positive.
    """
    if split_num <= 0:
        raise ValueError("split_num must be a positive integer")
    return [_flatten(data[start:start + split_num])
            for start in range(0, len(data), split_num)]


def _flatten(items):
    """Recursively flatten lists/tuples/ndarrays into one list, preserving element types."""
    flat = []
    for item in items:
        if isinstance(item, (list, tuple)):
            flat.extend(_flatten(item))
        elif isinstance(item, np.ndarray):
            # An ndarray is already homogeneous, so numpy's own conversion is safe here.
            flat.extend(item.ravel().tolist())
        elif isinstance(item, np.generic):
            # Convert numpy scalars to Python scalars, as ndarray.tolist() would.
            flat.append(item.item())
        else:
            flat.append(item)
    return flat
def recomb_data(data, each_class_size):
    """Regroup per-class lists so that each new row holds one sample per class.

    Row ``pos`` of the result collects ``members[pos]`` from every class list
    in *data* (in class order), for ``pos = 0, 1, ...`` while
    ``pos < each_class_size``.  In other words, this is the transpose of the
    first ``each_class_size`` columns of *data*.  A class with fewer than
    ``each_class_size`` members raises ``IndexError``.
    """
    rows = []
    pos = 0
    while pos < each_class_size:
        rows.append([members[pos] for members in data])
        pos += 1
    return rows
def make_k_class(data, k):
    """Split the rows of *data* into exactly *k* consecutive groups.

    Rows are taken in order.  Group sizes differ by at most one row: the
    first ``len(data) % k`` groups receive one extra row (the same rule as
    ``numpy.array_split``).  Previously the group size was
    ``int(len(data) / k)``, which produced more than *k* groups whenever
    ``len(data)`` was not a multiple of *k*.

    Each group is returned as one flat list formed by concatenating its rows
    (one level: list/tuple/ndarray rows are expanded, scalar rows are kept
    as single elements).  Element types are preserved.

    Args:
        data: sequence of rows (e.g. the output of ``recomb_data``).
        k: number of groups; must satisfy ``1 <= k <= len(data)``.

    Returns:
        list of exactly *k* flat lists.

    Raises:
        ValueError: if *k* is out of range.
    """
    if k <= 0:
        raise ValueError("k must be a positive integer")
    if k > len(data):
        raise ValueError(
            "k ({}) must not exceed the number of rows ({})".format(k, len(data))
        )
    base, extra = divmod(len(data), k)
    k_arr = []
    start = 0
    for group_idx in range(k):
        stop = start + base + (1 if group_idx < extra else 0)
        group = []
        for row in data[start:stop]:
            if isinstance(row, (list, tuple)):
                group.extend(row)
            elif isinstance(row, np.ndarray):
                group.extend(row.ravel().tolist())
            else:
                group.append(row)
        k_arr.append(group)
        start = stop
    return k_arr
def k_comb(k, k_arr):
    """Build the (test, validation, train) assignment for each of the *k* folds.

    For fold ``i``:

    * test       = ``k_arr[i]``
    * validation = ``k_arr[(i - 1) % k]`` (fold 0 wraps around to the last group)
    * train      = every other group, in original order (a list of groups)

    so each fold is split 1 : 1 : (k - 2) in groups.  With ``k < 3`` the
    training set is empty, and with ``k == 1`` test and validation are the
    same group.

    Args:
        k: number of folds.
        k_arr: exactly *k* groups (e.g. the output of ``make_k_class``).

    Returns:
        ``(test_list, val_list, train_list)``, each of length *k*.

    Raises:
        ValueError: if ``len(k_arr) != k``.  Extra groups would otherwise
            shift fold 0's validation group and would never serve as test data.
    """
    if len(k_arr) != k:
        raise ValueError(
            "k_arr must contain exactly k={} groups, got {}".format(k, len(k_arr))
        )
    k_comb_test_arr = []
    k_comb_val_arr = []
    k_comb_train_arr = []
    for i in range(k):
        val_idx = (i - 1) % k
        k_comb_test_arr.append(k_arr[i])
        k_comb_val_arr.append(k_arr[val_idx])
        k_comb_train_arr.append(
            [group for j, group in enumerate(k_arr) if j != i and j != val_idx]
        )
    return k_comb_test_arr, k_comb_val_arr, k_comb_train_arr
def strait_kFold(k, data, each_class_size, index=False):
    """Split class-ordered *data* into *k* stratified folds and print every step.

    *data* must be laid out class by class: the first ``each_class_size``
    elements form class 0, the next ``each_class_size`` form class 1, and so
    on.  The classes are interleaved so that every fold receives the same
    number of samples from each class, and each fold then gets a test,
    validation and training part via ``k_comb``.

    Args:
        k: number of folds; must satisfy ``1 <= k <= each_class_size``.
        data: class-ordered samples.
        each_class_size: number of samples per class; must be positive and
            divide ``len(data)``.
        index: when exactly ``True`` (identity check, as in the original),
            the folds are built over the positions ``0 .. len(data)-1``
            instead of over the values of *data*.

    Returns:
        ``(test_arr, val_arr, train_arr)`` as returned by ``k_comb``.

    Raises:
        ValueError: if the sizes above are inconsistent.  These cases
            previously surfaced as IndexError / ZeroDivisionError / range() errors.
    """
    if each_class_size <= 0:
        raise ValueError("each_class_size must be a positive integer")
    if len(data) % each_class_size != 0:
        raise ValueError(
            "len(data) ({}) must be a multiple of each_class_size ({})".format(
                len(data), each_class_size
            )
        )
    if not 1 <= k <= each_class_size:
        raise ValueError(
            "k must satisfy 1 <= k <= each_class_size ({}), got {}".format(
                each_class_size, k
            )
        )
    if index is True:
        data = list(range(len(data)))
    split_arr = split_data(data, each_class_size)
    recomb_arr = recomb_data(split_arr, each_class_size)
    k_arr = make_k_class(recomb_arr, k)
    test_arr, val_arr, train_arr = k_comb(k, k_arr)

    # Report (output text is unchanged from the original script).
    print("【全データ】")
    print(data)
    print("")
    print("【各クラスに属する値】")
    for i, members in enumerate(split_arr):
        print("Class-{}:{}".format(i, members))
    print("")
    print("【1つの配列に各クラスのデータが1つずつ含まれるように組み換え】")
    for i, row in enumerate(recomb_arr):
        print("各クラスの{}番目を集めた配列:{}".format(i, row))
    print("")
    print("【上記の配列をk={}個の組に分ける】".format(k))
    print(k_arr)
    print("")
    print("【クロスバリデーションで使う配列の組み合わせ】")
    for fold_no, (test_part, val_part, train_part) in enumerate(
        zip(test_arr, val_arr, train_arr), start=1
    ):
        print("k={}".format(fold_no))
        print("{}:{}".format("test", test_part))
        print("{}:{}".format("validation", val_part))
        print("{}:{}".format("train", train_part))
        print("")
    return test_arr, val_arr, train_arr
# 24 sample labels "A".."X": four classes of six samples each, stored class by class.
X = list("ABCDEFGHIJKLMNOPQRSTUVWX")
test, validation, train = strait_kFold(
    k=3,
    data=X,
    each_class_size=6,
    index=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment