-
-
Save vortexkd/708ad407a3c6ca095a02d10db43e5375 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
def process_input(input_data): | |
# data増強をこちらに | |
return input_data | |
def convert_to_class(label, classes): | |
# ラベル作成に必要なものはこちら | |
#one_hotベクトル作成なども必要に応じてこちらにできます。 | |
return classes.index(label) | |
def count_lines(file_path:str): | |
count=0 | |
with open(file_path, 'r') as f: | |
for line in f: | |
count += 1 | |
return count | |
# ここでは、ラベル作成のために必要であろうと思ってclass_mapも追加しておきました、ラベルがすでに数字がされていましたら要りませんです! | |
def generate_input_data(class_list: List, path: str, batch_size=32): | |
i = 0 | |
max_i = count_lines(path) | |
while True: | |
# リセットします | |
data = pd.DataFrame() | |
add_n = batch_size | |
# たまに、うまく行かないのもあるので、それを外て埋め直したいために、whileに入れる | |
while len(data) < batch_size: | |
# batch_sizeに合わせた行数を読みます。skiprows=range(1, i)は項目名(最初のrow)を読むためです. | |
temp_data = pd.read_csv(path, dtype=str, skiprows=range(1, i), | |
nrows=add_n) | |
temp_data['input_x'] = \ | |
temp_data.RawInput1.map(lambda d: process_input(d)) | |
temp_data['label_y'] = temp_data.Label.map( | |
lambda label: convert_to_class(label, class_list)) | |
i += add_n | |
if i > max_i: # ファイルが終わったら最初からやり直す | |
i = 0 | |
# うまく行かないやつは学習に欲しくないから、外します。 | |
temp_data.dropna(inplace=True) | |
add_n = batch_size - len(temp_data) | |
data = pd.concat([data, temp_data]) | |
yield data['input_x'], data['input_y'] | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment