Created June 13, 2024 07:55
-
-
Save sappho192/eb12a5644f6bc46eaa4c11e022c6aabe to your computer and use it in GitHub Desktop.
Optuna + Comet ML example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
---
# Config for NeMo parameter optimization using Optuna and Comet ML.
# The script loads this file with OmegaConf.load('config.optuna.yaml') and
# reads it as config.optuna.* and config.comet.*, so the nesting below matters.
optuna:
  study_name: "<STUDY_NAME>"
  db:
    # Only "postgres" is accepted; any other value makes the script exit.
    type: postgres
    psql:
      host: "<PSQL_HOST>"
      port: 5432
      # Despite the key name, this value is used as the *database* name:
      # it becomes the path of the postgresql:// storage URL.
      table: "<OPTUNA_TABLE>"
      id: "<PSQL_ID>"   # database user name (user part of the storage URL)
      pw: "<PSQL_PW>"   # database password; do not commit real credentials
comet:
  api_key: "<COMET_API_KEY>"
  workspace: "<COMET_WORKSPACE>"
  project_name: "<COMET_PROJECT_NAME>"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys
from urllib.parse import quote_plus

from comet_ml import Experiment
import optuna
from omegaconf import OmegaConf
from loguru import logger
class OptunaConfig:
    """Optuna study settings read from the ``optuna`` section of the config.

    Only a PostgreSQL storage backend is supported. Any other
    ``optuna.db.type`` logs an error and terminates the process.
    """

    optuna_study_name: str
    optuna_db_type: str  # value of optuna.db.type; only 'postgres' is accepted
    psql_host: str
    psql_port: int
    psql_table: str  # used as the database name in the storage URL
    psql_id: str
    psql_pw: str

    def __init__(self, config_optuna: OmegaConf):
        """Read the study name and PostgreSQL connection settings.

        Args:
            config_optuna: Loaded config (result of ``OmegaConf.load``) that
                contains an ``optuna`` section.

        Raises:
            SystemExit: If ``optuna.db.type`` is not ``'postgres'``.
        """
        optuna_section = config_optuna.optuna
        self.optuna_study_name = optuna_section.study_name
        self.optuna_db_type = optuna_section.db.type
        if self.optuna_db_type != 'postgres':
            logger.error(f'Unsupported DB: {self.optuna_db_type}')
            # sys.exit is always available; the bare exit() builtin is
            # injected by the `site` module and may be missing (python -S,
            # frozen executables). Same SystemExit and exit code as before.
            sys.exit(-1)
        psql = optuna_section.db.psql
        self.psql_host = psql.host
        self.psql_port = psql.port
        self.psql_table = psql.table
        self.psql_id = psql.id
        self.psql_pw = psql.pw
class CometConfig:
    """Comet ML connection settings read from the ``comet`` section of the config."""

    api_key: str       # passed to Experiment(api_key=...)
    workspace: str     # passed to Experiment(workspace=...)
    project_name: str  # passed to Experiment(project_name=...)

    def __init__(self, config_optuna: OmegaConf):
        """Copy ``comet.api_key``, ``comet.workspace`` and ``comet.project_name``.

        Args:
            config_optuna: Loaded config (result of ``OmegaConf.load``) that
                contains a ``comet`` section.
        """
        comet_section = config_optuna.comet
        self.api_key = comet_section.api_key
        self.workspace = comet_section.workspace
        self.project_name = comet_section.project_name
def run_experiment(config_optuna: OptunaConfig,
                   config_comet: CometConfig,
                   n_trials: int = 1):
    """Run one Comet-tracked batch of Optuna trials against the shared study.

    A fresh Comet ``Experiment`` is created and published through the
    module-level ``experiment`` global, which ``objective`` reads to log
    parameters and metrics. The study is created in, or reloaded from
    (``load_if_exists=True``), the PostgreSQL database in ``config_optuna``.

    Args:
        config_optuna: Study name and PostgreSQL connection settings.
        config_comet: Comet API key, workspace and project name.
        n_trials: Number of trials to run in this experiment (default 1).

    Returns:
        The Optuna study after optimization.
    """
    global experiment
    experiment = Experiment(
        api_key=config_comet.api_key,
        project_name=config_comet.project_name,
        workspace=config_comet.workspace,
    )
    # Percent-encode the credentials. Characters such as '@', ':' or '/' in a
    # user name or password would otherwise break the storage URL. str()
    # guards against numeric values coming from YAML.
    user = quote_plus(str(config_optuna.psql_id))
    password = quote_plus(str(config_optuna.psql_pw))
    storage = (
        f'postgresql://{user}:{password}'
        f'@{config_optuna.psql_host}:{config_optuna.psql_port}'
        f'/{config_optuna.psql_table}'
    )
    try:
        study = optuna.create_study(
            direction='minimize',
            study_name=config_optuna.optuna_study_name,
            storage=storage,
            load_if_exists=True,
        )
        study.optimize(objective, n_trials=n_trials)
    finally:
        # Always close the Comet run, even if study creation or a trial raises.
        experiment.end()
    return study
def objective(trial: optuna.trial.Trial):
    """Optuna objective: a 2-D paraboloid whose minimum is at (x, y) = (2, 3).

    The sampled values and the resulting score are logged to the Comet
    experiment held in the module-level ``experiment`` global, which is set
    by ``run_experiment``.

    Args:
        trial: The Optuna trial used to sample ``x`` and ``y``.

    Returns:
        The squared distance of (x, y) from (2, 3).
    """
    sampled = {
        'x': trial.suggest_float('x', -10, 10, step=0.01),
        'y': trial.suggest_float('y', -5, 5, step=0.1),
    }
    for param_name, param_value in sampled.items():
        experiment.log_parameter(param_name, param_value)
    score = (sampled['x'] - 2) ** 2 + (sampled['y'] - 3) ** 2
    experiment.log_metric('result', score)
    return score
def main(config_path: str = 'config.optuna.yaml'):
    """Load the config, run ten single-trial experiments and print the best result.

    Args:
        config_path: YAML config to load (defaults to ``config.optuna.yaml``).
    """
    config = OmegaConf.load(config_path)
    config_optuna = OptunaConfig(config)
    config_comet = CometConfig(config)
    # Use a local name for the study. The global `experiment` is reserved for
    # the Comet Experiment that `objective` logs to.
    for _ in range(10):
        # Each call reloads the same persisted study (load_if_exists=True), so
        # the last returned study reflects all trials stored so far.
        study = run_experiment(config_optuna=config_optuna, config_comet=config_comet)
    print(f'Best param: {study.best_params}')  # e.g. close to {'x': 2.0, 'y': 3.0}
    print(f'Best value: {study.best_value}')


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment