Created
August 30, 2019 08:26
-
-
Save Yunaka12/99dddb2df0b52f9d627e0ba34f6f6b4d to your computer and use it in GitHub Desktop.
各クラスの比率を保ったまま、学習データ、テストデータ、検証用データにk分割
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def split_data(data, split_num):
    """Split ``data`` into consecutive chunks of ``split_num`` items and flatten each chunk.

    Every chunk is returned as a flat Python list.  Nested lists/tuples (and
    NumPy arrays) inside a chunk are flattened recursively, so a chunk of rows
    becomes one list holding all of the rows' elements in order.

    Unlike a round trip through ``np.array(...).flatten()``, the elements keep
    their original Python types (a chunk such as ``[1, "a"]`` is not coerced to
    ``["1", "a"]``) and rows of unequal length are accepted.

    Args:
        data: Sliceable sequence (list, tuple, NumPy array, ...) to split.
        split_num: Number of top-level items per chunk; must be >= 1.

    Returns:
        A list of flat lists.  The last chunk may be shorter than the others
        when ``len(data)`` is not a multiple of ``split_num``.

    Raises:
        ValueError: If ``split_num`` is smaller than 1.
    """
    if split_num < 1:
        raise ValueError(
            "split_num must be a positive integer, got {!r}".format(split_num)
        )

    def _flatten(items):
        # Recursively expand list-like items; everything else is kept as-is.
        flat = []
        for item in items:
            if isinstance(item, (np.ndarray, np.generic)):
                # Convert NumPy containers/scalars to plain Python objects.
                item = item.tolist()
            if isinstance(item, (list, tuple)):
                flat.extend(_flatten(item))
            else:
                flat.append(item)
        return flat

    return [
        _flatten(data[start:start + split_num])
        for start in range(0, len(data), split_num)
    ]
def recomb_data(data, each_class_size):
    """Regroup per-class rows so that each output row takes one item from every class.

    Output row ``p`` collects the element at position ``p`` of every row in
    ``data`` (i.e. the rows are transposed).  A row that has no element at
    position ``p`` -- for example a final class that is shorter than
    ``each_class_size`` -- is skipped for that position instead of raising
    ``IndexError``.

    Args:
        data: List of rows, one row per class.
        each_class_size: Number of positions to extract (an integer); this is
            also the number of rows returned.

    Returns:
        A list with ``each_class_size`` rows.  For rectangular input every row
        has ``len(data)`` elements; with short classes some rows are shorter
        (possibly empty).
    """
    return [
        [members[position] for members in data if position < len(members)]
        for position in range(each_class_size)
    ]
def make_k_class(data, k):
    """Partition ``data`` front to back into exactly ``k`` flattened groups.

    When ``len(data)`` is a multiple of ``k`` every group covers
    ``len(data) // k`` consecutive rows.  Otherwise the first
    ``len(data) % k`` groups take one extra row, so the result always has
    ``k`` groups (a fixed chunk size of ``int(len(data) / k)`` would yield
    more than ``k`` groups in that case).

    Args:
        data: Rows to distribute (e.g. the output of ``recomb_data``).
        k: Number of groups to build.

    Returns:
        A list of ``k`` groups, each flattened by ``split_data``.

    Raises:
        ValueError: If ``k`` is positive but ``data`` has fewer than ``k`` rows.
    """
    base, extra = divmod(len(data), k)
    if k > 0 and base == 0:
        raise ValueError(
            "cannot build {} groups from {} rows".format(k, len(data))
        )

    k_arr = []
    start = 0
    for group_no in range(k):
        # The first `extra` groups absorb the remainder, one row each.
        stop = start + base + (1 if group_no < extra else 0)
        # Chunk size == slice length, so split_data returns a single
        # flattened chunk for this group.
        k_arr.extend(split_data(data[start:stop], stop - start))
        start = stop
    return k_arr
def k_comb(k, k_arr):
    """Build the test / validation / train fold assignment for each of ``k`` rounds.

    In round ``r`` the test fold is ``k_arr[r]``, the validation fold is the
    preceding one ``k_arr[r - 1]`` (round 0 wraps around to the last fold) and
    the training set is the list of all remaining folds, i.e. a 1 : 1 : k-2
    split when ``k_arr`` holds ``k`` folds.

    Args:
        k: Number of rounds to generate.
        k_arr: List of folds.

    Returns:
        A tuple ``(tests, validations, trains)`` of three lists with ``k``
        entries each; every ``trains`` entry is itself a list of folds.
    """
    tests, validations, trains = [], [], []
    for fold in range(k):
        previous = fold - 1
        tests.append(k_arr[fold])
        validations.append(k_arr[previous])
        if fold:
            remaining = k_arr[:previous] + k_arr[fold + 1:]
        else:
            # Round 0: validation is the last fold, so drop both ends.
            remaining = k_arr[1:-1]
        trains.append(remaining)
    return tests, validations, trains
def _print_kfold_report(k, data, split_arr, recomb_arr, k_arr,
                        test_arr, val_arr, train_arr):
    """Print every intermediate result of the stratified k-fold split."""
    print("【全データ】")
    print(data)
    print("")
    print("【各クラスに属する値】")
    for class_no, members in enumerate(split_arr):
        print("Class-{}:{}".format(class_no, members))
    print("")
    print("【1つの配列に各クラスのデータが1つずつ含まれるように組み換え】")
    for position, row in enumerate(recomb_arr):
        print("各クラスの{}番目を集めた配列:{}".format(position, row))
    print("")
    print("【上記の配列をk={}個の組に分ける】".format(k))
    print(k_arr)
    print("")
    print("【クロスバリデーションで使う配列の組み合わせ】")
    rounds = zip(test_arr, val_arr, train_arr)
    for round_no, (test_fold, val_fold, train_folds) in enumerate(rounds, start=1):
        print("k={}".format(round_no))
        print("{}:{}".format("test", test_fold))
        print("{}:{}".format("validation", val_fold))
        print("{}:{}".format("train", train_folds))
        print("")


def strait_kFold(k, data, each_class_size, index, verbose=True):
    """Split ``data`` into k test / validation / train combinations, keeping class ratios.

    ``data`` is treated as consecutive runs of ``each_class_size`` items, one
    run per class.  The runs are regrouped so that each row holds one item of
    every class, the rows are divided into ``k`` folds, and the folds are
    combined into ``k`` rounds of (test, validation, train).

    Args:
        k: Number of folds / rounds.
        data: Samples ordered class by class.
        each_class_size: Number of consecutive samples that form one class.
        index: If truthy, split the positions ``0 .. len(data) - 1`` instead
            of the sample values themselves.  Any truthy flag is accepted
            (e.g. ``numpy.bool_``), not only the literal ``True``.
        verbose: If true (the default), print every intermediate step.

    Returns:
        A tuple ``(test_arr, val_arr, train_arr)`` with ``k`` entries each, as
        produced by ``k_comb``.
    """
    if index:
        data = list(range(len(data)))
    split_arr = split_data(data, each_class_size)
    recomb_arr = recomb_data(split_arr, each_class_size)
    k_arr = make_k_class(recomb_arr, k)
    test_arr, val_arr, train_arr = k_comb(k, k_arr)
    if verbose:
        _print_kfold_report(k, data, split_arr, recomb_arr, k_arr,
                            test_arr, val_arr, train_arr)
    return test_arr, val_arr, train_arr
# Demo run: 24 samples (letters A-X) read as 4 classes of 6 consecutive items,
# split into k=3 folds; index=True makes the folds hold positions, not letters.
X = list("ABCDEFGHIJKLMNOPQRSTUVWX")
test, validation, train = strait_kFold(
    k=3,
    data=X,
    each_class_size=6,
    index=True,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment