Skip to content

Instantly share code, notes, and snippets.

@sappho192
Created June 7, 2024 07:19
Show Gist options
  • Save sappho192/91035698e9b6b112080973af14e7d1ce to your computer and use it in GitHub Desktop.
Save sappho192/91035698e9b6b112080973af14e7d1ce to your computer and use it in GitHub Desktop.
Optuna + Comet example: each Optuna trial is logged as its own Comet experiment, with the study persisted in PostgreSQL
from comet_ml import Experiment
import optuna
# --- Comet.ml settings ------------------------------------------------------
# Placeholders: replace with real values before running.
comet_api_key = "<API_KEY>"
workspace = "<WORKSPACE_NAME>"
# Also reused as the Optuna study name in run_experiment().
project_name = "<PROJECT_NAME>"

# --- PostgreSQL settings (interpolated into the Optuna storage URL) ---------
psql_host = "<PSQL_HOST>"
psql_port = 5432
psql_id = "<PSQL_ID>"
psql_pw = "<PSQL_PW>"
# NOTE: despite the name, this value fills the *database* segment of the
# storage URL built in run_experiment(), not a table name.
psql_table = "optuna"
def run_experiment():
    """Run one Optuna trial inside its own Comet experiment.

    A new Comet ``Experiment`` is opened, a single trial of the study named
    ``project_name`` is optimised against the PostgreSQL storage, and the
    sampled parameters and the objective value are logged to Comet.

    Returns:
        The ``optuna.study.Study`` that was created or loaded, so the caller
        can read ``best_params`` / ``best_value``.
    """
    experiment = Experiment(
        api_key=comet_api_key,
        project_name=project_name,
        workspace=workspace,
    )

    def objective(trial):
        """Sample ``x`` and ``y``, log them, and return (x-2)^2 + (y-3)^2."""
        # suggest_float(..., step=q) is the supported replacement for the
        # deprecated suggest_discrete_uniform(name, low, high, q).
        x = trial.suggest_float('x', -10, 10, step=0.01)
        y = trial.suggest_float('y', -5, 5, step=0.1)
        experiment.log_parameter('x', x)
        experiment.log_parameter('y', y)

        result = (x - 2) ** 2 + (y - 3) ** 2
        experiment.log_metric('result', result)
        return result

    try:
        # load_if_exists=True lets repeated calls share one study by name.
        # NOTE(review): credentials are interpolated verbatim; presumably they
        # must already be URL-safe -- confirm for passwords with special chars.
        study = optuna.create_study(
            direction='minimize',
            study_name=project_name,
            storage=f'postgresql://{psql_id}:{psql_pw}@{psql_host}:{psql_port}/{psql_table}',
            load_if_exists=True,
        )
        # One trial per call, so each Comet experiment holds exactly one trial.
        study.optimize(objective, n_trials=1)
    finally:
        # Close the Comet experiment even if study creation or the trial fails.
        experiment.end()

    return study
# Run 100 single-trial experiments. run_experiment() returns the Optuna study
# (not the Comet experiment), so the result is named accordingly.
for _ in range(100):
    study = run_experiment()

# NOTE(review): source indentation was lost; the summary is assumed to be
# printed once, after all runs have completed.
print(f'Best param: {study.best_params}')  # E.g. {'x': 2.0, 'y': 3.0}
print(f'Best value: {study.best_value}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment