Skip to content

Instantly share code, notes, and snippets.

@c-bata
Last active March 22, 2023 02:16
Show Gist options
  • Save c-bata/98532a60609a8a5f9e1e4dd162d45886 to your computer and use it in GitHub Desktop.
Save c-bata/98532a60609a8a5f9e1e4dd162d45886 to your computer and use it in GitHub Desktop.
# $ docker run -d --rm --platform linux/amd64 -p 5432:5432 -e POSTGRES_USER=root -e POSTGRES_PASSWORD=root -e POSTGRES_DB=optuna --name optuna-postgres postgres:12.10
# $ docker run -it --rm --platform linux/amd64 --network host -v $(pwd):/usr/src python:3.10 bash
# # cd /usr/src
# # pip install -U setuptools pip psycopg2
# # pip install -e .
from __future__ import annotations
import math
import threading
import time
from sqlalchemy import event
from sqlalchemy.engine.base import Engine
import optuna
# Restrict Optuna's own logging to messages at ERROR level.
optuna.logging.set_verbosity(optuna.logging.ERROR)
# PostgreSQL connection URL; user, password and database name match the
# `docker run` command shown at the top of this file.
storage_url = "postgresql+psycopg2://root:root@127.0.0.1/optuna"
# Created at import time; main() and the profiler both use this storage.
storage = optuna.storages.RDBStorage(storage_url)
# Guards every update of `sql_queries` made by the profiler hooks.
sql_queries_lock = threading.Lock()
# Profiler statistics keyed by SQL statement text:
# statement -> (execution count, list of per-execution durations in seconds).
sql_queries: dict[str, tuple[int, list[float]]] = {}
# Number of studies created when the database has no studies yet.
n_studies = 100
# Number of trials run for each of those seeded studies.
n_trials = 500
# Number of float parameters suggested per trial by `objective`.
n_params = 10
class EngineProfiler:
    """Record how often and how long each SQL statement runs on an engine.

    Results are accumulated in the module-level ``sql_queries`` mapping
    (statement text -> ``(count, durations)``); every update is made while
    holding ``sql_queries_lock``.
    """

    def __init__(self, engine: Engine) -> None:
        """Store the engine to profile; no listeners are attached yet."""
        self.engine = engine
        # Start timestamp of the most recent statement; overwritten by
        # ``before_cursor_execute`` each time a statement is about to run.
        self.query_start_time = time.perf_counter()

    def register(self) -> None:
        """Attach the timing hooks to the engine's cursor-execute events."""
        event.listen(self.engine, "before_cursor_execute", self.before_cursor_execute)
        event.listen(self.engine, "after_cursor_execute", self.after_cursor_execute)

    def before_cursor_execute(  # type: ignore
        self, conn, cursor, statement, parameters, context, executemany
    ) -> None:
        """Remember the moment the statement starts executing."""
        # NOTE(review): a single timestamp is kept per profiler, so statements
        # that overlap on different threads would overwrite each other's start
        # time. Fine for a single-threaded profiling run; confirm before
        # registering the profiler around multi-threaded work.
        self.query_start_time = time.perf_counter()

    def after_cursor_execute(  # type: ignore
        self, conn, cursor, stmt, parameters, context, executemany
    ) -> None:
        """Add the elapsed time of ``stmt`` to the shared statistics."""
        duration = time.perf_counter() - self.query_start_time
        with sql_queries_lock:
            count, durations = sql_queries.get(stmt, (0, []))
            # Append in place: building ``durations + [duration]`` copied the
            # whole history on every statement, which is quadratic overall.
            durations.append(duration)
            sql_queries[stmt] = (count + 1, durations)
def objective(trial: optuna.Trial) -> float:
    """Return the sum of ``sin(x)`` over ``n_params`` suggested values.

    Each value is requested from the trial under the name ``param-<index>``
    with bounds 0 and 2*pi.
    """
    total = 0
    for index in range(n_params):
        angle = trial.suggest_float('param-{}'.format(index), 0, math.pi * 2)
        total += math.sin(angle)
    return total
def main():
    """Seed the database if needed, profile a batch of small studies, and report.

    Steps:
      1. If the storage holds no studies, create ``n_studies`` studies with
         ``n_trials`` trials each so the tables are populated.
      2. Register an ``EngineProfiler`` and run 100 temporary studies of
         10 trials each while timing the whole batch.
      3. Print the five statements with the largest total execution time.
      4. Delete the temporary studies, even if profiling failed part-way.
    """
    # Create trials
    if not storage.get_all_studies():
        for i in range(n_studies):
            print(i)
            study = optuna.create_study(storage=storage)
            study.optimize(objective, n_trials=n_trials, n_jobs=8)

    # Profile study.optimize()
    EngineProfiler(storage.engine).register()
    tmp_studies = []
    try:
        # perf_counter is monotonic, so the measurement is not affected by
        # system clock adjustments during the run.
        start = time.perf_counter()
        for _ in range(100):
            tmp_study = optuna.create_study(storage=storage)
            # Track the study before optimizing it so the cleanup below also
            # removes a study whose optimize() call raised.
            tmp_studies.append(tmp_study)
            tmp_study.optimize(objective, n_trials=10)
        elapsed = time.perf_counter() - start
        # NOTE: the printed n_trials is the seeding configuration; the
        # profiled workload above is 100 studies of 10 trials each.
        print(f"Elapsed: {elapsed:.4f}s ({n_trials=} {n_params=})")

        # Show profiler stats
        summary = [
            (stmt, count, f"{sum(durations):.4f}", sum(durations))
            for stmt, (count, durations) in sql_queries.items()
        ]
        sort_by_total = sorted(summary, key=lambda r: r[3], reverse=True)
        print("Sort by Total:")
        print("Total Time(s)\tQuery Count\tStatement")
        for q in sort_by_total[:5]:
            print(f"{q[2]}\t{q[1]}\t{q[0]}")
    finally:
        # Clean up: without this, an exception or Ctrl-C during profiling
        # would leave the temporary studies behind in the database.
        for s in tmp_studies:
            optuna.delete_study(study_name=s.study_name, storage=storage)
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
@c-bata
Copy link
Author

c-bata commented Mar 22, 2023

Before

# python profiler.py
Elapsed: 157.2042s (n_trials=500 n_params=10)

Sort by Total:
Total Time(s)	Query Count	Statement
43.5951	11000	SELECT trials.trial_id AS trials_trial_id
FROM trials
WHERE trials.study_id = %(study_id_1)s
40.9885	10000	SELECT trial_params.param_id AS trial_params_param_id, trial_params.trial_id AS trial_params_trial_id, trial_params.param_name AS trial_params_param_name, trial_params.param_value AS trial_params_param_value, trial_params.distribution_json AS trial_params_distribution_json
FROM trial_params JOIN trials ON trials.trial_id = trial_params.trial_id
WHERE trials.study_id = %(study_id_1)s AND trial_params.param_name = %(param_name_1)s
 LIMIT %(param_1)s
4.7336	21000	SELECT trials.trial_id AS trials_trial_id, trials.number AS trials_number, trials.study_id AS trials_study_id, trials.state AS trials_state, trials.datetime_start AS trials_datetime_start, trials.datetime_complete AS trials_datetime_complete
FROM trials
WHERE trials.trial_id = %(trial_id_1)s
4.5883	1000	SELECT count(trials.trial_id) AS count_1
FROM trials
WHERE trials.study_id = %(study_id_1)s AND trials.trial_id < %(trial_id_1)s
3.5560	11100	SELECT studies.study_id AS studies_study_id, studies.study_name AS studies_study_name
FROM studies
WHERE studies.study_id = %(study_id_1)s

After

# python profiler.py
Elapsed: 66.0051s (n_trials=500 n_params=10)

Sort by Total:
Total Time(s)	Query Count	Statement
4.3927	21000	SELECT trials.trial_id AS trials_trial_id, trials.number AS trials_number, trials.study_id AS trials_study_id, trials.state AS trials_state, trials.datetime_start AS trials_datetime_start, trials.datetime_complete AS trials_datetime_complete
FROM trials
WHERE trials.trial_id = %(trial_id_1)s
3.0978	10000	SELECT trial_params.param_id AS trial_params_param_id, trial_params.trial_id AS trial_params_trial_id, trial_params.param_name AS trial_params_param_name, trial_params.param_value AS trial_params_param_value, trial_params.distribution_json AS trial_params_distribution_json
FROM trial_params JOIN trials ON trials.trial_id = trial_params.trial_id
WHERE trials.study_id = %(study_id_1)s AND trial_params.param_name = %(param_name_1)s
 LIMIT %(param_1)s
2.9455	11100	SELECT studies.study_id AS studies_study_id, studies.study_name AS studies_study_name
FROM studies
WHERE studies.study_id = %(study_id_1)s
2.3931	10000	INSERT INTO trial_params (trial_id, param_name, param_value, distribution_json) VALUES (%(trial_id)s, %(param_name)s, %(param_value)s, %(distribution_json)s) RETURNING trial_params.param_id
1.8892	11000	SELECT trials.trial_id AS trials_trial_id, trials.number AS trials_number, trials.study_id AS trials_study_id, trials.state AS trials_state, trials.datetime_start AS trials_datetime_start, trials.datetime_complete AS trials_datetime_complete
FROM trials
WHERE trials.trial_id IN (NULL) AND (1 != 1) AND trials.study_id = %(study_id_1)s ORDER BY trials.trial_id

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment