Skip to content

Instantly share code, notes, and snippets.

@vortexkd
Last active May 30, 2019 04:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vortexkd/708ad407a3c6ca095a02d10db43e5375 to your computer and use it in GitHub Desktop.
Save vortexkd/708ad407a3c6ca095a02d10db43e5375 to your computer and use it in GitHub Desktop.
from typing import List
def process_input(input_data):
# data増強をこちらに
return input_data
def convert_to_class(label, classes):
# ラベル作成に必要なものはこちら
#one_hotベクトル作成なども必要に応じてこちらにできます。
return classes.index(label)
def count_lines(file_path:str):
count=0
with open(file_path, 'r') as f:
for line in f:
count += 1
return count
# ここでは、ラベル作成のために必要であろうと思ってclass_mapも追加しておきました、ラベルがすでに数字がされていましたら要りませんです!
def generate_input_data(class_list: List, path: str, batch_size=32):
i = 0
max_i = count_lines(path)
while True:
# リセットします
data = pd.DataFrame()
add_n = batch_size
# たまに、うまく行かないのもあるので、それを外て埋め直したいために、whileに入れる
while len(data) < batch_size:
# batch_sizeに合わせた行数を読みます。skiprows=range(1, i)は項目名(最初のrow)を読むためです.
temp_data = pd.read_csv(path, dtype=str, skiprows=range(1, i),
nrows=add_n)
temp_data['input_x'] = \
temp_data.RawInput1.map(lambda d: process_input(d))
temp_data['label_y'] = temp_data.Label.map(
lambda label: convert_to_class(label, class_list))
i += add_n
if i > max_i: # ファイルが終わったら最初からやり直す
i = 0
# うまく行かないやつは学習に欲しくないから、外します。
temp_data.dropna(inplace=True)
add_n = batch_size - len(temp_data)
data = pd.concat([data, temp_data])
yield data['input_x'], data['input_y']
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment