Skip to content

Instantly share code, notes, and snippets.

@Keycatowo
Last active August 28, 2020 10:38
Show Gist options
  • Save Keycatowo/32914afdf12ecf1b50b8d81c0e67ff43 to your computer and use it in GitHub Desktop.
Save Keycatowo/32914afdf12ecf1b50b8d81c0e67ff43 to your computer and use it in GitHub Desktop.
import time
import random
import sys
from pathlib import Path, PureWindowsPath
def _dict_to_tag(mapping, ndigits=2):
    """Serialise ``mapping`` into a filename-friendly ``key-value_`` string.

    Float values are rounded to ``ndigits`` decimal places; all other values
    are formatted unchanged. Every pair is followed by ``_``, e.g.
    ``{0: 1.0, 1: 0.333}`` -> ``'0-1.0_1-0.33_'``.
    """
    return ''.join(
        '{}-{}_'.format(key, round(val, ndigits) if isinstance(val, float) else val)
        for key, val in mapping.items()
    )


# Windows extended-length path prefix ("\\?\"): lets Win32 file APIs accept
# paths longer than MAX_PATH (260 characters).
_LONG_PATH_PREFIX = '\\\\?\\'


def _save_path_str(path):
    """Return the string to pass to ``model.save()`` for the absolute ``path``.

    On Windows, when the target directory exists, the extended-length prefix
    is prepended so that a long checkpoint name does not make the save fail.
    On other platforms, or when the directory is missing, ``str(path)`` is
    returned unchanged.

    NOTE(review): the prefix lifts only the total-path limit; a single file
    name component longer than 255 characters will still be rejected.
    """
    if sys.platform != 'win32':
        return str(path)
    long_path = Path(_LONG_PATH_PREFIX + str(path))
    if long_path.parent.exists():
        return str(long_path)
    return str(path)


# `run` (not `i`) so the inner fold loop below does not shadow the outer index.
for run in range(1):
    print("loop: ", run)
    # Re-import and reload the scoring helpers on every run so that edits to
    # lib/score_function.py are picked up without restarting the interpreter.
    import lib.score_function
    from importlib import reload
    reload(lib.score_function)
    from lib.score_function import print_score

    model_list = []
    # epoch = random.randrange(5, 80)
    epoch = 10
    # Train one model on every 5th entry of the per-split data lists.
    for i in range(0, len(X_dict_of_list['train']), 5):
        print('multiple:', i + 1)
        print(X_dict_of_list['train_res'][i])

        # seed = random.randrange(sys.maxsize)
        # seed = 6887378808282378165
        seed = 1  # fixed so runs are repeatable
        print("Seed was:", seed)
        class_weights = [1.0, 1.0, 1.0]

        # NOTE(review): 'train_res' presumably holds the resampled/augmented
        # training set; 'train' is only used for the sizes in the file name.
        model = splited_DNN(X_dict_of_list['train_res'][i], y_dict_of_list['train_res'][i], num_or_size_splits=docvec_size, bottleneck_size=60, class_weights=class_weights, seed=seed)
        valid_accuracy = model.train(X_dict_of_list['train_res'][i], y_dict_of_list['train_res'][i], X_dict_of_list['val'][i], y_dict_of_list['val'][i], epoch=epoch)

        print("test with data augmentation: ")
        # print_score returns (accuracy, roc_auc); roc_auc is used as a mapping.
        accuracy, roc_auc = print_score(model, X_dict_of_list['test'][i], y_dict_of_list['test'][i], show_threshold=False)
        print('\n')
        # print("test without data augmentation: ")
        # print_score(model, X_dict_of_list['test'][0], y_dict_of_list['test'][0], show_threshold=False)
        # print('\n')

        # Encode the run's configuration and scores into the checkpoint name.
        current_time = time.strftime("%Y%m%d%H%M%S", time.localtime())
        roc_auc_str = _dict_to_tag(roc_auc)
        class_weights_str = _dict_to_tag(dict(enumerate(class_weights)))
        class_num = y_dict_of_list['train'][i].shape[1]
        modelFilePath = ("./models/SplitDNN判決結果分類/"
                         "SplitDNN({}_class)"
                         "_train{}_val{}_test{}"
                         "_class_weights{}_epoch{}"
                         "_valid_accuracy{:.4f}_accuracy{:.4f}"
                         "_roc_auc{}_seed{}_{}.ckpt").format(
            class_num,
            len(X_dict_of_list['train'][i]),
            len(X_dict_of_list['val'][i]),
            len(X_dict_of_list['test'][i]),
            class_weights_str,
            epoch,
            valid_accuracy,
            accuracy,
            roc_auc_str,
            seed,
            current_time)
        modelFilePath = Path(modelFilePath).absolute()

        # The descriptive file name can push the full path past Windows'
        # MAX_PATH, which would crash model.save(); _save_path_str adds the
        # extended-length prefix on Windows when it is safe to do so.
        save_path = _save_path_str(modelFilePath)
        print(save_path)
        model.save(save_path)

        # Loading only works with a normal (non-prefixed) path, so also save
        # a short, fixed-name checkpoint next to the descriptive one.
        latest_path = str(modelFilePath.parent / 'latest.ckpt')
        print(latest_path)
        model.save(latest_path)
        model_list.append(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment