Skip to content

Instantly share code, notes, and snippets.

Last active May 30, 2019
What would you like to do?
from typing import List
def process_input(input_data):
# data増強をこちらに
return input_data
def convert_to_class(label, classes):
# ラベル作成に必要なものはこちら
return classes.index(label)
def count_lines(file_path:str):
with open(file_path, 'r') as f:
for line in f:
count += 1
return count
# ここでは、ラベル作成のために必要であろうと思ってclass_mapも追加しておきました、ラベルがすでに数字がされていましたら要りませんです!
def generate_input_data(class_list: List, path: str, batch_size=32):
i = 0
max_i = count_lines(path)
while True:
# リセットします
data = pd.DataFrame()
add_n = batch_size
# たまに、うまく行かないのもあるので、それを外て埋め直したいために、whileに入れる
while len(data) < batch_size:
# batch_sizeに合わせた行数を読みます。skiprows=range(1, i)は項目名(最初のrow)を読むためです.
temp_data = pd.read_csv(path, dtype=str, skiprows=range(1, i),
temp_data['input_x'] = \ d: process_input(d))
temp_data['label_y'] =
lambda label: convert_to_class(label, class_list))
i += add_n
if i > max_i: # ファイルが終わったら最初からやり直す
i = 0
# うまく行かないやつは学習に欲しくないから、外します。
add_n = batch_size - len(temp_data)
data = pd.concat([data, temp_data])
yield data['input_x'], data['input_y']
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment