Skip to content

Instantly share code, notes, and snippets.

@sappho192
Created June 13, 2024 07:55
Show Gist options
  • Save sappho192/eb12a5644f6bc46eaa4c11e022c6aabe to your computer and use it in GitHub Desktop.
Save sappho192/eb12a5644f6bc46eaa4c11e022c6aabe to your computer and use it in GitHub Desktop.
Optuna+Comet ML example
---
# Config for NeMo parameter optimization using Optuna and Comet ML.
optuna:
  study_name: "<STUDY_NAME>"
  db:
    type: postgres  # only "postgres" is supported by OptunaConfig
    psql:
      host: "<PSQL_HOST>"
      port: 5432
      table: "<OPTUNA_TABLE>"  # used as the database name in the storage URL
      id: "<PSQL_ID>"
      pw: "<PSQL_PW>"
comet:
  api_key: "<COMET_API_KEY>"
  workspace: "<COMET_WORKSPACE>"
  project_name: "<COMET_PROJECT_NAME>"
from comet_ml import Experiment
import optuna
from omegaconf import OmegaConf
from loguru import logger
class OptunaConfig:
    """Optuna study settings read from the ``optuna`` section of the config.

    Only a PostgreSQL storage backend (``optuna.db.type == 'postgres'``) is
    supported; any other value logs an error and terminates the process.
    """

    optuna_study_name: str
    optuna_db_type: str  # the attribute actually assigned (was annotated as ``optuna_db``)
    psql_host: str
    psql_port: int
    psql_table: str  # used as the database name when the storage URL is built
    psql_id: str
    psql_pw: str

    def __init__(self, config_optuna: "OmegaConf"):
        """Copy the Optuna / PostgreSQL settings out of *config_optuna*.

        Args:
            config_optuna: Loaded config exposing ``optuna.study_name`` and
                ``optuna.db.type`` / ``optuna.db.psql.*`` via attribute access.

        Raises:
            SystemExit: (exit status -1) if ``optuna.db.type`` is not
                ``'postgres'``.
        """
        optuna_cfg = config_optuna.optuna
        self.optuna_study_name = optuna_cfg.study_name
        self.optuna_db_type = optuna_cfg.db.type

        if self.optuna_db_type != 'postgres':
            import sys

            logger.error(f'Unsupported DB: {self.optuna_db_type}')
            # sys.exit is always available; the bare ``exit`` builtin is only
            # injected by the ``site`` module (missing under ``python -S``).
            sys.exit(-1)

        psql = optuna_cfg.db.psql
        self.psql_host = psql.host
        self.psql_port = psql.port
        self.psql_table = psql.table
        self.psql_id = psql.id
        self.psql_pw = psql.pw
class CometConfig:
    """Comet ML connection settings taken from the ``comet`` config section."""

    api_key: str
    workspace: str
    project_name: str

    def __init__(self, config_optuna: OmegaConf):
        """Copy ``comet.api_key``, ``comet.workspace`` and ``comet.project_name``.

        Args:
            config_optuna: Loaded config exposing a ``comet`` section via
                attribute access.
        """
        section = config_optuna.comet
        self.api_key = section.api_key
        self.workspace = section.workspace
        self.project_name = section.project_name
def run_experiment(config_optuna: OptunaConfig,
                   config_comet: CometConfig,
                   n_trials: int = 1,
                   direction: str = 'minimize'):
    """Run one Optuna optimisation round, logged to a fresh Comet experiment.

    Creates the study in PostgreSQL, or reattaches to it if it already exists
    (``load_if_exists=True``). It then runs ``n_trials`` trials of
    ``objective``. The Comet experiment is always ended, even when study
    creation or optimisation raises.

    Args:
        config_optuna: Study name and PostgreSQL connection settings.
        config_comet: Comet API key, workspace and project name.
        n_trials: Number of trials to run in this round (default 1).
        direction: Optimisation direction passed to ``optuna.create_study``
            (default ``'minimize'``).

    Returns:
        The ``optuna`` study after optimisation.
    """
    from urllib.parse import quote_plus

    # ``objective`` logs through this module-level handle, so it must be
    # published as a global before the study starts calling it.
    global experiment
    experiment = Experiment(
        api_key=config_comet.api_key,
        project_name=config_comet.project_name,
        workspace=config_comet.workspace
    )
    try:
        # Percent-encode credentials so characters such as '@', ':' or '/'
        # in the user name or password cannot corrupt the storage URL.
        user = quote_plus(str(config_optuna.psql_id))
        password = quote_plus(str(config_optuna.psql_pw))
        storage = (
            f'postgresql://{user}:{password}'
            f'@{config_optuna.psql_host}:{config_optuna.psql_port}'
            f'/{config_optuna.psql_table}'
        )
        study = optuna.create_study(
            direction=direction,
            study_name=config_optuna.optuna_study_name,
            storage=storage,
            load_if_exists=True
        )
        study.optimize(objective, n_trials=n_trials)
    finally:
        # Close the Comet run even if optimisation failed.
        experiment.end()
    return study
def objective(trial: optuna.trial.Trial):
    """Score a trial by the squared distance of its (x, y) sample from (2, 3).

    Samples ``x`` in [-10, 10] (step 0.01) and ``y`` in [-5, 5] (step 0.1).
    It logs both parameters and the resulting ``'result'`` metric to the
    module-level Comet ``experiment``, then returns the score.
    """
    # ``experiment`` is only read here, so it resolves to the module global
    # set by run_experiment without needing a ``global`` declaration.
    params = {
        'x': trial.suggest_float('x', -10, 10, step=0.01),
        'y': trial.suggest_float('y', -5, 5, step=0.1),
    }
    for name, value in params.items():
        experiment.log_parameter(name, value)

    dx = params['x'] - 2
    dy = params['y'] - 3
    score = dx ** 2 + dy ** 2
    experiment.log_metric('result', score)
    return score
def main(config_path: str = 'config.optuna.yaml', n_runs: int = 10):
    """Load the config, run ``n_runs`` optimisation rounds and print the best result.

    Every round reattaches to the same persisted study, so the reported best
    parameters and value cover all trials stored in it.

    Args:
        config_path: Path of the YAML config to load (default
            ``'config.optuna.yaml'``).
        n_runs: Number of ``run_experiment`` rounds to execute (default 10).
    """
    config = OmegaConf.load(config_path)
    config_optuna = OptunaConfig(config)
    config_comet = CometConfig(config)

    # Keep the study in a local: the global ``experiment`` is the Comet
    # handle that ``objective`` logs to and must not be overwritten.
    study = None
    for _ in range(n_runs):
        study = run_experiment(config_optuna=config_optuna, config_comet=config_comet)

    if study is None:
        logger.warning(f'No optimisation rounds were run (n_runs={n_runs})')
        return
    print(f'Best param: {study.best_params}')  # E.g. {'x': 2.002108042}
    print(f'Best value: {study.best_value}')
# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment