Skip to content

Instantly share code, notes, and snippets.

@c-bata
Last active March 13, 2023 12:40
Show Gist options
  • Save c-bata/c08fb89a583adbcdc3eddcf8cf192c1a to your computer and use it in GitHub Desktop.
Save c-bata/c08fb89a583adbcdc3eddcf8cf192c1a to your computer and use it in GitHub Desktop.
# docker run -d --rm -p 3306:3306 -e MYSQL_USER=optuna -e MYSQL_DATABASE=optuna -e MYSQL_PASSWORD=password -e MYSQL_ALLOW_EMPTY_PASSWORD=yes --name optuna-mysql mysql:8.0
from __future__ import annotations
import math
import threading
import time
from sqlalchemy import event
from sqlalchemy.engine.base import Engine
import optuna
# Restrict Optuna's logging to ERROR-level messages so the benchmark output stays readable.
optuna.logging.set_verbosity(optuna.logging.ERROR)
# Storage under test; the URL matches the credentials/database of the docker command above.
storage_url = "mysql+pymysql://optuna:password@127.0.0.1:3306/optuna"
storage = optuna.storages.RDBStorage(storage_url)
# Profiler results shared by EngineProfiler (writer) and main() (reader):
# statement text -> (number of executions, list of per-execution durations).
# Writers hold sql_queries_lock while updating the mapping.
sql_queries_lock = threading.Lock()
sql_queries: dict[str, tuple[int, list[float]]] = {}
# Size of the data set created by main() when the storage is empty.
n_studies = 50
n_trials = 100
n_params = 10
class EngineProfiler:
    """Collect per-statement execution counts and durations for an engine.

    After ``register()`` is called, every cursor execution on ``engine`` is
    timed with ``time.perf_counter()`` and accumulated into the module-level
    ``sql_queries`` mapping (statement text -> ``(count, durations)``) while
    holding ``sql_queries_lock``.
    """

    # Key in ``conn.info`` holding the stack of pending start timestamps.
    # A profiler-specific name avoids clashing with other listeners.
    _START_TIMES_KEY = "engine_profiler_start_times"

    def __init__(self, engine: Engine) -> None:
        """Remember ``engine``; no listeners are attached until ``register()``."""
        self.engine = engine
        # Start of the most recently begun statement. Kept as an attribute for
        # backward compatibility and used as a fallback when a connection has
        # no pending per-connection timestamp.
        self.query_start_time = time.perf_counter()

    def register(self) -> None:
        """Attach the before/after cursor-execute hooks to ``self.engine``."""
        event.listen(self.engine, "before_cursor_execute", self.before_cursor_execute)
        event.listen(self.engine, "after_cursor_execute", self.after_cursor_execute)

    def before_cursor_execute(  # type: ignore
        self, conn, cursor, statement, parameters, context, executemany
    ) -> None:
        """Record the start time of the statement about to be executed."""
        started = time.perf_counter()
        self.query_start_time = started
        # Keep the timestamp on the connection rather than only on the shared
        # profiler instance, so executions on other connections/threads cannot
        # overwrite it before the matching "after" hook runs.
        conn.info.setdefault(self._START_TIMES_KEY, []).append(started)

    def after_cursor_execute(  # type: ignore
        self, conn, cursor, stmt, parameters, context, executemany
    ) -> None:
        """Add the elapsed time of the finished statement to ``sql_queries``."""
        pending = conn.info.get(self._START_TIMES_KEY)
        # Fall back to the instance-level timestamp if this connection has no
        # pending entry (e.g. the hooks were registered mid-execution).
        started = pending.pop() if pending else self.query_start_time
        duration = time.perf_counter() - started
        with sql_queries_lock:
            count, durations = sql_queries.get(stmt, (0, []))
            # Append in place instead of rebuilding the list on every query.
            durations.append(duration)
            sql_queries[stmt] = (count + 1, durations)
def objective(trial: optuna.Trial) -> float:
    """Return the sum of the sines of ``n_params`` values suggested by ``trial``.

    Each value is requested through ``trial.suggest_float`` under the name
    ``param-<i>`` with bounds ``0`` and ``2 * pi``.
    """
    total = 0
    for index in range(n_params):
        angle = trial.suggest_float('param-{}'.format(index), 0, math.pi * 2)
        total += math.sin(angle)
    return total
def main():
    """Seed the storage if it is empty, then profile ``storage.get_all_trials``.

    Steps:
      1. If the storage holds no studies, create ``n_studies`` studies and run
         ``n_trials`` trials of ``objective`` in each.
      2. Register an ``EngineProfiler`` on the storage engine and call
         ``storage.get_all_trials(study_id=1)`` 100 times, timing the loop.
      3. Print the elapsed time and the five statements with the largest
         accumulated execution time.
    """
    # Create trials
    if not storage.get_all_studies():
        for _ in range(n_studies):
            study = optuna.create_study(storage=storage)
            study.optimize(objective, n_trials=n_trials, n_jobs=8)

    # Profile storage.get_all_trials()
    EngineProfiler(storage.engine).register()
    # perf_counter is monotonic and matches the clock used by EngineProfiler;
    # wall-clock time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(100):
        storage.get_all_trials(study_id=1)
    elapsed = time.perf_counter() - start
    print(f"Elapsed: {elapsed:.4f}s ({n_trials=} {n_params=})")

    # Show profiler stats
    totals = [
        (stmt, count, sum(durations))
        for stmt, (count, durations) in sql_queries.items()
    ]
    sort_by_total = sorted(totals, key=lambda row: row[2], reverse=True)
    print("")
    print("Sort by Total:")
    print("Total Time(s)\tQuery Count\tStatement")
    for stmt, count, total in sort_by_total[:5]:
        print(f"{total:.4f}\t{count}\t{stmt}")


if __name__ == '__main__':
    main()
@c-bata
Copy link
Author

c-bata commented Mar 13, 2023

SQLite3

Before:

$ python profile_all_trials.py
Elapsed: 23.0874s (n_trials=100 n_params=10)

Sort by Total:
Total Time(s)	Query Count	Statement
0.0883	500	SELECT studies.study_id AS studies_study_id, studies.study_name AS studies_study_name
FROM studies
WHERE studies.study_id = ?
0.0695	500	SELECT trial_intermediate_values.trial_id AS trial_intermediate_values_trial_id, trial_intermediate_values.trial_intermediate_value_id AS trial_intermediate_values_trial_intermediate_value_id, trial_intermediate_values.step AS trial_intermediate_values_step, trial_intermediate_values.intermediate_value AS trial_intermediate_values_intermediate_value, trial_intermediate_values.intermediate_value_type AS trial_intermediate_values_intermediate_value_type
FROM trial_intermediate_values
WHERE trial_intermediate_values.trial_id IN (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)

After (with the `ix_trials_study_id` index dropped):

$ sqlite3 db.sqlite3
SQLite version 3.39.5 2022-10-14 20:58:05
Enter ".help" for usage hints.
sqlite> .schema trials
CREATE TABLE trials (
	trial_id INTEGER NOT NULL,
	number INTEGER,
	study_id INTEGER,
	state VARCHAR(8) NOT NULL,
	datetime_start DATETIME,
	datetime_complete DATETIME,
	PRIMARY KEY (trial_id),
	FOREIGN KEY(study_id) REFERENCES studies (study_id)
);
CREATE INDEX ix_trials_study_id ON trials (study_id);
sqlite> DROP INDEX ix_trials_study_id;
sqlite> .schema trials
CREATE TABLE trials (
	trial_id INTEGER NOT NULL,
	number INTEGER,
	study_id INTEGER,
	state VARCHAR(8) NOT NULL,
	datetime_start DATETIME,
	datetime_complete DATETIME,
	PRIMARY KEY (trial_id),
	FOREIGN KEY(study_id) REFERENCES studies (study_id)
);
sqlite>
$ python profile_all_trials.py
Elapsed: 22.3219s (n_trials=100 n_params=10)

Sort by Total:
Total Time(s)	Query Count	Statement
0.0726	500	SELECT studies.study_id AS studies_study_id, studies.study_name AS studies_study_name
FROM studies
WHERE studies.study_id = ?
0.0525	500	SELECT trial_values.trial_id AS trial_values_trial_id, trial_values.trial_value_id AS trial_values_trial_value_id, trial_values.objective AS trial_values_objective, trial_values.value AS trial_values_value, trial_values.value_type AS trial_values_value_type
FROM trial_values
WHERE trial_values.trial_id IN (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment