Skip to content

Instantly share code, notes, and snippets.

@ykochi
Forked from terakun/opt.py
Created December 7, 2018 14:30
Show Gist options
  • Save ykochi/0bf2f1ac8b84fce6bcd970e76e2f4127 to your computer and use it in GitHub Desktop.
Save ykochi/0bf2f1ac8b84fce6bcd970e76e2f4127 to your computer and use it in GitHub Desktop.
optunaを使って他プログラムのパラメータを最適化する
#!/usr/bin/python
# coding: UTF-8
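# The target executable receives the parameters as command-line arguments and
# prints objective values to standard output, one number per line. Every line is
# reported to Optuna as an intermediate value (used for pruning); the last line
# is taken as the final objective value.
# $ ./a.out 1.2 10
# 42
# The config file lists one parameter per line as
# [varname] [vartype] [range]
# where vartype is one of: int, uniform, loguniform, discrete_uniform
# (discrete_uniform takes a third number, the step q).
# For example, for a program (executable: a.out) that takes
# double x ( 0 <= x <= 1.5 ) and int n ( 1 <= n <= 100 ) as command-line arguments,
# write config.txt:
#
# x uniform 0 1.5
# n int 1 100
#
# and run
# python opt.py config.txt ./a.out 100
# to optimize the parameters; the last argument is the number of trials.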
import asyncio
import sys
import argparse
import optuna
async def res_cmd(cmd, params, trial):
    """Run ``cmd`` with ``params`` and stream its objective values to ``trial``.

    Each non-blank line the program writes to standard output is parsed as a
    float and reported to ``trial`` as an intermediate value, with a step
    counter starting at 0. If the trial asks to be pruned, the program is
    killed and the pruning exception is raised.

    Args:
        cmd: Path of the executable to run.
        params: Command-line arguments (strings) passed to the executable.
        trial: Object providing ``report(value, step)`` and
            ``should_prune(step)``.

    Returns:
        The last value printed by the program.

    Raises:
        optuna.structs.TrialPruned: If the trial should be pruned.
        RuntimeError: If the program exits without printing any value.
        ValueError: If an output line cannot be parsed as a float.
    """
    proc = await asyncio.create_subprocess_exec(
        cmd, *params, stdout=asyncio.subprocess.PIPE)

    value = None
    step = 0
    try:
        while True:
            data = await proc.stdout.readline()
            if not data:
                # EOF: the program closed its standard output.
                break
            text = data.decode("ascii").strip()
            if not text:
                # Tolerate blank lines instead of crashing in float('').
                continue
            value = float(text)
            trial.report(value, step)
            if trial.should_prune(step):
                # NOTE(review): newer Optuna releases appear to expose this as
                # optuna.TrialPruned and to drop the step argument of
                # should_prune -- confirm against the installed version.
                raise optuna.structs.TrialPruned()
            step += 1
    except BaseException:
        # Pruned or failed: stop the child and reap it so that no orphaned
        # or zombie process is left behind, then propagate the error.
        if proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass
        await proc.wait()
        raise

    # Wait for the subprocess to exit.
    await proc.wait()
    if value is None:
        raise RuntimeError(
            "{!r} printed no objective value on standard output".format(cmd))
    return value
def set_objective(configfilename, execfilename):
    """Build an Optuna objective that evaluates an external executable.

    The config file holds one parameter per line in the form
    ``varname vartype low high [q]``. Blank lines and lines starting with
    ``#`` are skipped. Supported ``vartype`` values are ``int``, ``uniform``,
    ``loguniform`` and ``discrete_uniform`` (which also needs the step ``q``).

    Args:
        configfilename: Path of the parameter config file.
        execfilename: Path of the executable to optimize.

    Returns:
        A function ``objective(trial)`` that suggests one value per
        configured parameter, passes them as command-line arguments to the
        executable in config order, and returns the result of ``res_cmd``.

    Raises:
        ValueError: If a config line has an unknown ``vartype``, too few
            fields, or bounds that are not numbers.
    """
    # For each supported distribution: how many numeric arguments it takes
    # and how to convert them.
    specs = {
        "int": (2, int),
        "uniform": (2, float),
        "loguniform": (2, float),
        "discrete_uniform": (3, float),
    }

    varinfo = []
    with open(configfilename) as f:
        for lineno, line in enumerate(f, start=1):
            param = line.split()
            if not param or param[0].startswith("#"):
                continue
            if len(param) < 2 or param[1] not in specs:
                # A silently dropped parameter would shift every following
                # command-line argument, so reject the line explicitly.
                raise ValueError(
                    "{}:{}: expected '<name> <type> <range>' with type in {}"
                    .format(configfilename, lineno, sorted(specs)))
            varname, vartype = param[:2]
            count, convert = specs[vartype]
            if len(param) < 2 + count:
                raise ValueError(
                    "{}:{}: type {!r} needs {} numeric values"
                    .format(configfilename, lineno, vartype, count))
            # Convert once here rather than on every trial.
            bounds = [convert(p) for p in param[2:2 + count]]
            varinfo.append((varname, vartype, bounds))

    def objective(trial):
        # vartype was validated against specs above, and each supported type
        # maps to the trial method of the same name (suggest_int,
        # suggest_uniform, suggest_loguniform, suggest_discrete_uniform).
        valstrlist = [
            str(getattr(trial, "suggest_" + vartype)(varname, *bounds))
            for varname, vartype, bounds in varinfo
        ]
        # asyncio.run creates and closes a fresh event loop for each trial,
        # so no pre-existing current loop is required.
        return asyncio.run(res_cmd(execfilename, valstrlist, trial))

    return objective
def main():
    """Parse the command line, run the study and print the best parameters.

    Positional arguments: the config file, the executable to optimize, and
    the number of trials (an integer).
    """
    cli = argparse.ArgumentParser()
    for positional in ("config", "executable"):
        cli.add_argument(positional)
    cli.add_argument("n_trials", type=int)
    opts = cli.parse_args()

    # The study is created before the config file is read, and uses a median
    # pruner to decide which trials to prune.
    pruner = optuna.pruners.MedianPruner()
    study = optuna.create_study(pruner=pruner)
    objective = set_objective(opts.config, opts.executable)
    study.optimize(objective, n_trials=opts.n_trials)

    print(study.best_params)


if __name__ == "__main__":
    main()
@ykochi
Copy link
Author

ykochi commented Dec 7, 2018

pruner対応しました

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment