Skip to content

Instantly share code, notes, and snippets.

@terakun
Created December 6, 2018 12:14
Show Gist options
  • Save terakun/4915b8d06b0b56f7f0b0672f8221ccad to your computer and use it in GitHub Desktop.
Save terakun/4915b8d06b0b56f7f0b0672f8221ccad to your computer and use it in GitHub Desktop.
optunaを使って他プログラムのパラメータを最適化する
#!/usr/bin/python
# coding: UTF-8
# 実行ファイルはコマンドライン引数としてパラメータを受け取り,目的関数値のみを標準出力に出力するものとする
# $ ./a.out 1.2 10
# 42
# configファイルには
# [varname] [vartype] [range]
# を書いておく
# 例えばコマンドライン引数としてdouble x ( 0 <= x <= 1.5 ),int n ( 1 <= n <= 100 )を受け取るプログラム(実行ファイル:a.out)なら
# config.txt:
#
# x uniform 0 1.5
# n int 1 100
#
# のように書いて,
# python opt.py config.txt ./a.out 100   (最後の引数は試行回数 n_trials)
# とすれば最適化してくれる
import subprocess
import sys
import optuna
def res_cmd(cmd):
    """Run *cmd* through the shell and return what it wrote to stdout.

    Only standard output is captured (via a pipe); standard error is not
    redirected. The call blocks until the command has finished.
    """
    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
    captured_stdout, _ = process.communicate()
    return captured_stdout
def set_objective(configfilename, execfilename):
    """Build an Optuna objective that evaluates an external program.

    Each non-blank line of the config file describes one parameter as
    ``[varname] [vartype] [range...]``. Supported ``vartype`` values are
    ``int``, ``uniform`` and ``loguniform`` (two range values) and
    ``discrete_uniform`` (three range values). Range values beyond the
    required count are ignored.

    Args:
        configfilename: Path of the parameter description file.
        execfilename: Command to run; the suggested parameter values are
            appended to it as space-separated command-line arguments, in
            the order they appear in the config file.

    Returns:
        A function ``objective(trial)`` that asks ``trial`` for one value
        per configured parameter, runs the command, and returns the
        command's standard output converted to ``float``.

    Raises:
        ValueError: If a config line has an unknown type, too few range
            values, or a range value that cannot be converted to a number.
    """
    # vartype -> (number of range values required, converter for each one).
    range_spec = {
        "int": (2, int),
        "uniform": (2, float),
        "loguniform": (2, float),
        "discrete_uniform": (3, float),
    }

    varinfo = []
    with open(configfilename) as f:
        for lineno, line in enumerate(f, 1):
            param = line.split()
            if not param:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            if len(param) < 2 or param[1] not in range_spec:
                raise ValueError(
                    "{}:{}: expected '[varname] [vartype] [range]' with "
                    "vartype in {}".format(
                        configfilename, lineno, sorted(range_spec)))
            varname, vartype = param[0], param[1]
            count, convert = range_spec[vartype]
            if len(param) < 2 + count:
                raise ValueError(
                    "{}:{}: type '{}' needs {} range values".format(
                        configfilename, lineno, vartype, count))
            # Convert once here so a bad number is reported before any
            # trial runs instead of in the middle of the optimization.
            bounds = [convert(value) for value in param[2:2 + count]]
            varinfo.append((varname, vartype, bounds))

    def objective(trial):
        varlist = []
        for varname, vartype, bounds in varinfo:
            # Every supported vartype maps to the trial method named
            # ``suggest_<vartype>``; vartype was validated against
            # range_spec above, so no other attribute can be reached.
            suggest = getattr(trial, "suggest_" + vartype)
            varlist.append(suggest(varname, *bounds))
        cmd = execfilename + " " + " ".join(str(var) for var in varlist)
        # The program is expected to print only the objective value.
        return float(res_cmd(cmd))

    return objective
def main():
    """Command-line entry point: optimize an external program's parameters.

    Reads ``[config file] [executable file] [n_trials]`` from ``sys.argv``,
    runs an Optuna study with the objective built by ``set_objective`` for
    ``n_trials`` trials, and prints the best parameters found.

    If fewer than three arguments are given, prints the usage line and
    exits with status 1.
    """
    argv = sys.argv
    if len(argv) < 4:
        print(argv[0]+" [config file] [executable file] [n_trials]")
        # Stop here; continuing would fail on the missing argv entries.
        sys.exit(1)
    study = optuna.create_study()
    study.optimize(set_objective(argv[1], argv[2]), n_trials=int(argv[3]))
    print(study.best_params)
# Run the optimizer only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment