Skip to content

Instantly share code, notes, and snippets.

@c-bata
Last active February 19, 2021 18:49
Show Gist options
  • Save c-bata/a2bbd44d9e4ef76f5798bf3d3d69c5a9 to your computer and use it in GitHub Desktop.
Save c-bata/a2bbd44d9e4ef76f5798bf3d3d69c5a9 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"os/signal"
"runtime"
"sync"
"syscall"
"github.com/c-bata/goptuna"
"github.com/c-bata/goptuna/rdb"
"github.com/c-bata/goptuna/tpe"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
// objective trains one libffm model by running the external ./ffm-train
// binary with hyperparameters suggested by trial, and returns the best
// validation loss (best_va_loss) that ffm-train writes to its --json-meta
// file.
//
// Suggested parameters: lambda and eta (log-uniform in [1e-6, 1]) and
// latent (integer in [1, 16]).
//
// Side effects: ffm-train writes ./data/optuna/ffm-meta-<trial number>.json,
// and best_iteration, stdout and stderr are stored as trial user attrs. The
// subprocess runs under trial.GetContext(), so cancelling that context kills
// it.
func objective(trial goptuna.Trial) (float64, error) {
	lmd, err := trial.SuggestLogUniform("lambda", 1e-6, 1)
	if err != nil {
		return -1, err
	}
	eta, err := trial.SuggestLogUniform("eta", 1e-6, 1)
	if err != nil {
		return -1, err
	}
	latent, err := trial.SuggestInt("latent", 1, 16)
	if err != nil {
		return -1, err
	}
	number, err := trial.Number()
	if err != nil {
		return -1, err
	}
	jsonMetaPath := fmt.Sprintf("./data/optuna/ffm-meta-%d.json", number)

	// The same path template is used by other studies (trial numbers restart
	// per study), so clear any leftover file to never read back a stale result.
	if err := os.Remove(jsonMetaPath); err != nil && !os.IsNotExist(err) {
		return -1, fmt.Errorf("failed to remove stale json meta %q: %w", jsonMetaPath, err)
	}

	var stdout, stderr bytes.Buffer
	cmd := exec.CommandContext(
		trial.GetContext(),
		"./ffm-train",
		"-p", "./data/valid2.txt",
		"--auto-stop", "--auto-stop-threshold", "3",
		// %g keeps full precision. %f rounds to 6 decimals, which collapses
		// log-uniform samples near the 1e-6 bound (3.4e-6 -> "0.000003").
		"-l", fmt.Sprintf("%g", lmd),
		"-r", fmt.Sprintf("%g", eta),
		"-k", fmt.Sprintf("%d", latent),
		"-t", "500",
		"--json-meta", jsonMetaPath,
		"./data/train2.txt",
	)
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		return -1, fmt.Errorf("ffm-train exited with error: %w (stderr: %s)", err, stderr.String())
	}

	// Pointer fields distinguish "key absent/null" from a genuine zero value.
	var meta struct {
		BestIteration *int     `json:"best_iteration"`
		BestVALoss    *float64 `json:"best_va_loss"`
	}
	raw, err := ioutil.ReadFile(jsonMetaPath)
	if err != nil {
		return -1, fmt.Errorf("failed to read json: %w", err)
	}
	if err := json.Unmarshal(raw, &meta); err != nil {
		return -1, fmt.Errorf("failed to parse json: %w", err)
	}
	if meta.BestIteration == nil || meta.BestVALoss == nil {
		return -1, errors.New("json meta is missing best_iteration or best_va_loss")
	}

	// User attrs are diagnostics only; failing to store them must not
	// discard an otherwise valid evaluation, so their errors are ignored.
	_ = trial.SetUserAttr("best_iteration", fmt.Sprintf("%d", *meta.BestIteration))
	_ = trial.SetUserAttr("stdout", stdout.String())
	_ = trial.SetUserAttr("stderr", stderr.String())
	return *meta.BestVALoss, nil
}
// main resumes the "goptuna-libffm" study stored in ./db.sqlite3 and runs
// nTrials TPE-sampled evaluations of objective across NumCPU()-1 workers
// (at least one). It then logs the best value and parameters found.
//
// SIGINT, SIGTERM and SIGQUIT cancel the study context. objective runs
// ffm-train under trial.GetContext(), presumably derived from this context,
// so in-flight runs are killed and the workers return early. The best result
// found so far is still logged.
func main() {
	// Total number of trials shared among all workers.
	const nTrials = 1000

	// setup storage
	db, err := gorm.Open("sqlite3", "db.sqlite3")
	if err != nil {
		log.Fatalf("failed to open db: %v", err)
	}
	defer db.Close()
	// A single connection serializes all workers' access to the SQLite file.
	db.DB().SetMaxOpenConns(1)
	storage := rdb.NewStorage(db)

	// load the existing study
	study, err := goptuna.LoadStudy(
		"goptuna-libffm",
		goptuna.StudyOptionStorage(storage),
		goptuna.StudyOptionSampler(tpe.NewSampler()),
	)
	if err != nil {
		log.Fatalf("failed to load study: %v", err)
	}

	// create a context with cancel function
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	study.WithContext(ctx)

	// Watch for termination signals until the workers finish. sigch is never
	// closed: the signal package may still send to it, and a send on a closed
	// channel panics. signal.Stop unregisters it instead.
	sigch := make(chan os.Signal, 1)
	signal.Notify(sigch, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
	defer signal.Stop(sigch)
	workersDone := make(chan struct{})
	watcherDone := make(chan struct{})
	go func() {
		defer close(watcherDone)
		select {
		case sig := <-sigch:
			log.Printf("catch a kill signal: %s", sig)
			cancel()
		case <-workersDone:
		}
	}()

	// run optimize with context
	concurrency := runtime.NumCPU() - 1
	if concurrency < 1 {
		// On a single-CPU host NumCPU()-1 is 0 and no worker would start.
		concurrency = 1
	}
	var wg sync.WaitGroup
	for i := 0; i < concurrency; i++ {
		// Split nTrials evenly; the first nTrials%concurrency workers take one
		// extra trial so exactly nTrials are requested in total.
		n := nTrials / concurrency
		if i < nTrials%concurrency {
			n++
		}
		if n == 0 {
			continue
		}
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			if err := study.Optimize(objective, n); err != nil {
				log.Printf("optimize catch error: %v", err)
			}
		}(n)
	}
	wg.Wait()
	// The watcher is deliberately not part of wg: otherwise wg.Wait would
	// block until a signal arrived. Release it now that the workers are done.
	close(workersDone)
	<-watcherDone

	// print best hyper-parameters and the result
	v, err := study.GetBestValue()
	if err != nil {
		log.Printf("failed to get best value: %v", err)
		return
	}
	params, err := study.GetBestParams()
	if err != nil {
		log.Printf("failed to get best params: %v", err)
		return
	}
	// Format params with %v rather than asserting float64: their dynamic
	// types depend on the distribution ("latent" comes from SuggestInt), and
	// a mismatched type assertion would panic.
	log.Printf("Best evaluation=%f (lambda=%v, eta=%v, latent=%v)",
		v, params["lambda"], params["eta"], params["latent"])
}
import optuna
import json
import subprocess
import multiprocessing
def get_meta_path(trial_number: int):
    """Build the path of the JSON meta file passed to ffm-train's --json-meta.

    Args:
        trial_number: optuna trial number; it makes the filename unique
            within one study.

    Returns:
        A path of the form ``./data/optuna/ffm-meta-<trial_number>.json``.
    """
    template = "./data/optuna/ffm-meta-{}.json"
    return template.format(trial_number)
def objective(trial: optuna.Trial):
    """Train one libffm model with ``./ffm-train`` and return its best validation loss.

    Suggests ``lambda`` and ``eta`` (log-uniform in [1e-6, 1]). It then runs
    ffm-train with auto-stop against ``./data/valid2.txt``, using a fixed
    latent size of 4, and reads the ``--json-meta`` file that ffm-train writes.

    Side effects: ffm-train writes ``get_meta_path(trial.number)``. The command
    line (``args``) and ``best_iteration`` are recorded as trial user attrs.

    Returns:
        The ``best_va_loss`` value from the JSON meta file.

    Raises:
        RuntimeError: if ffm-train exits with a non-zero status.
        ValueError: if the JSON meta lacks ``best_iteration`` or ``best_va_loss``.
    """
    import os  # only needed here, to clear a stale meta file

    lmd = trial.suggest_loguniform("lambda", 1e-6, 1)
    eta = trial.suggest_loguniform("eta", 1e-6, 1)
    json_meta_path = get_meta_path(trial.number)
    # Trial numbers restart per study and other studies use the same path
    # template, so remove any leftover file to never read back a stale result.
    try:
        os.remove(json_meta_path)
    except FileNotFoundError:
        pass
    commands = [
        "./ffm-train",
        "-p", "./data/valid2.txt",
        "--auto-stop", "--auto-stop-threshold", "3",
        "-l", str(lmd),
        "-r", str(eta),
        "-k", "4",
        "-t", str(500),
        "--json-meta", json_meta_path,
        "./data/train2.txt",
    ]
    # Passing `encoding` already enables text mode (universal newlines).
    result = subprocess.run(commands, capture_output=True, encoding="utf-8")
    trial.set_user_attr("args", result.args)
    if result.returncode != 0:
        raise RuntimeError(
            f"ffm-train exited with status {result.returncode}: {result.stderr}")
    with open(json_meta_path) as f:
        json_dict = json.load(f)
    best_iteration = json_dict.get("best_iteration")
    best_va_loss = json_dict.get("best_va_loss")
    if best_iteration is None or best_va_loss is None:
        raise ValueError("json meta is missing best_iteration or best_va_loss")
    trial.set_user_attr("best_iteration", best_iteration)
    return best_va_loss
def main():
    """Resume the "dynalyst-ffm-gp" study in ./db.sqlite3 and run 256 trials in parallel.

    Uses optuna's scikit-optimize integration sampler. The study is loaded,
    not created, so it is assumed to already exist in the storage. Prints the
    best trial number, params and value when done.
    """
    storage = optuna.storages.RDBStorage(
        "sqlite:///db.sqlite3",
        engine_kwargs={"pool_size": 1})
    sampler = optuna.integration.SkoptSampler()
    study = optuna.load_study(
        study_name="dynalyst-ffm-gp",
        storage=storage,
        sampler=sampler)
    # Presumably leaves one core free. Never go below one job: on a
    # single-core host cpu_count() - 1 is 0, which is not a valid n_jobs.
    n_jobs = max(1, multiprocessing.cpu_count() - 1)
    study.optimize(
        objective,
        n_trials=256,
        n_jobs=n_jobs,
        catch=())
    print("best_trial", study.best_trial.number)
    print("best_params", study.best_params)
    print("best_value", study.best_value)
if __name__ == "__main__":
    # Start the optimization only when run as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment