-
-
Save yumenotobira/233d122b64383c0c78d5eb4b539e1e46 to your computer and use it in GitHub Desktop.
beatmania IIDXの譜面からスコア分布を学習
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# encoding: utf-8 | |
import pathlib | |
import numpy as np | |
from keras.models import model_from_json, Model | |
from keras.layers import Dense, Masking, Input, Activation, concatenate, Dropout | |
from keras.layers.recurrent import LSTM | |
from keras.optimizers import Adam | |
from keras import optimizers, regularizers, callbacks | |
from keras import backend as K | |
import os | |
import pickle | |
import random | |
def norm_dist(x, mean, sd):
    """Evaluate the normal density with mean `mean` and standard deviation `sd` at `x`.

    All arithmetic goes through the Keras backend (`K.square`, `K.exp`), so the
    arguments may be backend tensors; `csd_loss` passes per-component column
    slices.

    Args:
        x: Point(s) at which the density is evaluated.
        mean: Mean of the normal distribution.
        sd: Standard deviation of the normal distribution. It is used as a
            divisor and is not guarded against zero here.

    Returns:
        exp(-(x - mean)^2 / (2 * sd^2)) / (sd * sqrt(2 * pi)).
    """
    normaliser = sd * (2 * np.pi)**(1/2)
    squared_distance = K.square(x - mean)
    exponent = -squared_distance / (2 * K.square(sd))
    return K.exp(exponent) / normaliser
# Cauchy-Schwarz divergence between two 5-component, 1-D Gaussian mixtures.
def _mixture_self_overlap(weight, mean, sd, delta):
    """Return the sum that sits inside the logarithm of terms 2 and 3.

    The sum has a per-component part, weight_m^2 / (sd_m * sqrt(2 * pi)), and a
    pairwise part over components m > k that is doubled because every
    unordered pair occurs twice in the full double sum.

    NOTE(review): the closed-form integral of a squared normal density is
    1 / (2 * sd * sqrt(pi)); the per-component part below keeps the constant
    used by the original code, 1 / (sd * sqrt(2 * pi)), which differs by a
    factor of sqrt(2). Confirm against the reference derivation before
    changing it, because doing so shifts the value of the loss.

    Args:
        weight: Mixture weights, indexed as weight[:, m] for m in 0..4.
        mean: Component means, same layout as `weight`.
        sd: Component standard deviations, same layout as `weight`.
        delta: Small positive constant added to every divisor so that a zero
            standard deviation cannot cause a division by zero.

    Returns:
        The self-overlap sum, one value per row of the inputs.
    """
    per_component = 0
    pairwise = 0
    for m in range(5):
        per_component = per_component + K.square(weight[:, m]) / (
            (sd[:, m] + delta) * (2 * np.pi)**(1/2))
        for k in range(m):
            pair_sd = K.sqrt(K.square(sd[:, m]) + K.square(sd[:, k])) + delta
            pairwise = pairwise + (
                weight[:, m] * weight[:, k]
                * norm_dist(mean[:, m], mean[:, k], pair_sd))
    return per_component + 2 * pairwise


def csd_loss(y_true, y_pred):
    """Cauchy-Schwarz divergence loss between a target and a predicted mixture.

    Both arguments are indexed as 2-D tensors with the same column layout:
    columns 0-4 are the mixture weights, columns 5-9 the component means, and
    columns 10-14 hold values from which the standard deviation is recovered
    as sqrt(value / 100).

    The loss is term_1 + term_2 + term_3, where term_1 is minus the log of the
    overlap between the two mixtures and term_2 / term_3 are half the log of
    the self-overlap of the predicted / target mixture.

    The same `delta` guard is applied to every divisor and to the argument of
    the first logarithm, so that a zero standard deviation or an overlap that
    underflows to zero yields a large finite loss instead of inf/NaN.

    Args:
        y_true: Target mixture parameters (layout above).
        y_pred: Predicted mixture parameters (layout above).

    Returns:
        The divergence, one value per row of the inputs.
    """
    delta = 1e-12

    weight_true = y_true[:, 0:5]
    weight_pred = y_pred[:, 0:5]
    mean_true = y_true[:, 5:10]
    mean_pred = y_pred[:, 5:10]
    sd_true = K.sqrt(y_true[:, 10:15] / 100)
    sd_pred = K.sqrt(y_pred[:, 10:15] / 100)

    # Term 1: overlap between the predicted and the target mixture.
    cross_overlap = 0
    for m in range(5):
        for k in range(5):
            pair_sd = K.sqrt(K.square(sd_pred[:, m]) + K.square(sd_true[:, k])) + delta
            cross_overlap = cross_overlap + (
                weight_pred[:, m] * weight_true[:, k]
                * norm_dist(mean_pred[:, m], mean_true[:, k], pair_sd))
    # delta keeps the logarithm finite when the overlap underflows to zero.
    term_1 = -K.log(cross_overlap + delta)

    # Term 2: self-overlap of the predicted mixture.
    term_2 = K.log(_mixture_self_overlap(weight_pred, mean_pred, sd_pred, delta)) / 2

    # Term 3: self-overlap of the target mixture.
    term_3 = K.log(_mixture_self_overlap(weight_true, mean_true, sd_true, delta)) / 2

    return term_1 + term_2 + term_3
# Number of consecutive chart rows that are flattened into one time step.
time_unit = int(96/4)

# List the file names to load. The names are sorted before the seeded shuffle
# because directory listing order is not guaranteed; without the sort the
# fixed seed would not give the same order on every machine.
posixPath = list(pathlib.Path(path_to_chart_folder).glob('*'))
fileNames = sorted(x.name for x in posixPath)

# Load the chart data and the distribution parameters.
charts = list()
params = list()

random.seed(3556)
random.shuffle(fileNames)

for index, fileName in enumerate(fileNames, start=1):
    print(str(index) + ": " + fileName)

    # Load the chart; ndmin=2 keeps a one-row file two-dimensional so the
    # column indexing below still works.
    chart = np.loadtxt(os.path.join(path_to_chart_folder, fileName),
                       delimiter=',', ndmin=2)

    # Scale the BPM (last column) by 1/100.
    chart[:, -1] /= 100

    # Pad with rows of -1 so the length is a multiple of time_unit. As in the
    # original loop, a chart whose length is already a multiple still receives
    # one full extra chunk of -1 rows.
    pad_rows = time_unit - len(chart) % time_unit
    chart = np.concatenate([chart, np.full((pad_rows, 9), -1.0)], axis=0)

    # Flatten every run of time_unit rows (9 columns each) into a single row.
    charts.append(chart.reshape(-1, time_unit*9))

    # Load the parameters stored under the same file name.
    with open(os.path.join(path_to_params_folder, fileName)) as f:
        param = f.read()
    tokens = param.replace('\n', ',').split(',')
    # The last token is always discarded (presumably the empty string left by
    # a trailing newline -- verify against the parameter files).
    tokens.pop(-1)
    values = [float(t) for t in tokens]

    # Multiply entries 10-14 (the sd entries, per the original comment) by 100
    # to raise their order of magnitude.
    values[10:15] = [v * 100 for v in values[10:15]]
    params.append(values)
# Make every chart the same length: shorter charts are extended with rows of
# -1 until they match the longest one.
max_length = max(len(steps) for steps in charts)
for position, steps in enumerate(charts):
    missing = max_length - len(steps)
    if missing != 0:
        filler = [[-1]*(9*time_unit)] * missing
        charts[position] = np.append(steps, filler, axis=0)

# Convert to float32 arrays for training.
charts = np.array(charts).astype('float32')
params = np.array(params).astype('float32')
# Network definition.
# Layers are created in the same order as before so that automatically
# assigned layer names and the weight ordering stay unchanged.
model_input = Input(shape=(None, 9*time_unit))
# Time steps whose features are all -1.0 (the padding value) are masked.
masked_input = Masking(mask_value=-1.0)(model_input)

weight_decay = 1e-4

# One LSTM per output group: weights (w), means (m), standard deviations (s).
lstm_w = LSTM(
    units=128,
    dropout=0.25,
    return_sequences=False,
    kernel_regularizer=regularizers.l2(weight_decay),
)(masked_input)
lstm_m = LSTM(
    units=128,
    dropout=0.25,
    return_sequences=False,
    kernel_regularizer=regularizers.l2(weight_decay),
)(masked_input)
lstm_s = LSTM(
    units=128,
    dropout=0.25,
    return_sequences=False,
    kernel_regularizer=regularizers.l2(weight_decay),
)(masked_input)

# A hidden dense layer followed by dropout on each branch.
hidden_w = Dense(128, activation='relu')(lstm_w)
hidden_w = Dropout(0.25)(hidden_w)

hidden_m = Dense(128, activation='relu')(lstm_m)
hidden_m = Dropout(0.25)(hidden_m)

hidden_s = Dense(128, activation='relu')(lstm_s)
hidden_s = Dropout(0.25)(hidden_s)

# Five outputs per branch: softmax for the weights, sigmoid for the means,
# relu for the standard-deviation values.
logits_w = Dense(5)(hidden_w)
output_w = Activation('softmax')(logits_w)

logits_m = Dense(5)(hidden_m)
output_m = Activation('sigmoid')(logits_m)

logits_s = Dense(5)(hidden_s)
output_s = Activation('relu')(logits_s)

# Concatenate into the 15-column layout that csd_loss slices apart.
merged = concatenate([output_w, output_m, output_s])

model = Model(inputs=[model_input], outputs=[merged])
model.compile(
    loss=csd_loss,
    optimizer=optimizers.Adam(),
    metrics=['accuracy'],
)
# Checkpoint file pattern; Keras fills in the epoch number and validation loss.
checkpoint_dir = '/content/drive/My Drive/beatmania/' + version + '/model_96/'
fpath = os.path.join(checkpoint_dir, 'model_weights.{epoch:05d}-{val_loss:.2f}.hdf5')

# Save weights only, and only when the validation loss improves.
cp_cb = callbacks.ModelCheckpoint(
    filepath=fpath,
    monitor='val_loss',
    verbose=0,
    save_weights_only=True,
    save_best_only=True,
    mode='auto',
)
# Stop once the validation loss has not improved for 300 epochs.
cp_es = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=300,
    verbose=1,
    mode='auto',
)

history = model.fit(
    charts,
    params,
    batch_size=32,
    verbose=2,
    initial_epoch=0,
    epochs=2000,
    validation_split=0.1,
    callbacks=[cp_cb, cp_es],
)
# Save the model architecture, the final weights and the training history.
json_string = model.to_json()
# Use a context manager so the file is flushed and closed even if the write fails.
with open(os.path.join(path_to_result_folder, 'model.json'), 'w') as model_file:
    model_file.write(json_string)

model.save_weights(os.path.join(path_to_result_folder, 'model_weights.hdf5'))

with open(os.path.join(path_to_result_folder, 'history.pickle'), 'wb') as file_pi:
    pickle.dump(history.history, file_pi)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment