Created
November 20, 2023 09:22
-
-
Save KilianFt/005afbb811bf91e58354b6fb1cc2c6fb to your computer and use it in GitHub Desktop.
Example: running a Weights & Biases hyperparameter sweep with d3rlpy offline RL training.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, Any | |
import torch | |
import wandb | |
import numpy as np | |
import d3rlpy | |
from d3rlpy.datasets import get_cartpole | |
from d3rlpy.algos import DQNConfig | |
from d3rlpy.metrics import TDErrorEvaluator, EnvironmentEvaluator | |
class WandbAdapter(d3rlpy.logging.LoggerAdapter):
    """d3rlpy logger adapter that forwards per-step training metrics to wandb.

    Only ``write_metric`` does real work; the other hooks are required by the
    ``LoggerAdapter`` interface and are intentional no-ops here — parameter
    and model persistence is handled by the ``FileAdapter`` that is combined
    with this adapter in ``main()``.
    """

    def write_params(self, params: Dict[str, Any]) -> None:
        """No-op: hyperparameters are already recorded in ``wandb.config``."""
        pass

    def before_write_metric(self, epoch: int, step: int) -> None:
        pass

    def write_metric(self, epoch: int, step: int, name: str, value: float) -> None:
        """Log one scalar metric to the active wandb run, keyed by step."""
        # Use the public wandb.log API and guard against being called when no
        # run is active (wandb.run is None), which would raise AttributeError
        # with the previous wandb.run.log(...) call.
        if wandb.run is not None:
            wandb.log({name: value}, step=step)

    def after_write_metric(self, epoch: int, step: int) -> None:
        pass

    def save_model(self, epoch: int, algo: Any) -> None:
        """No-op: model checkpoints are written by the file adapter instead."""
        pass

    def close(self) -> None:
        pass
class WandbAdapterFactory(d3rlpy.logging.LoggerAdapterFactory):
    """Factory that d3rlpy invokes once per experiment to build the adapter."""

    def create(self, experiment_name: str) -> "WandbAdapter":
        # Fixed return annotation: this factory returns a WandbAdapter, not a
        # d3rlpy.logging.FileAdapter as the original annotation claimed.
        return WandbAdapter()
def main():
    """Run one sweep trial: train an offline DQN on cartpole, report to wandb."""
    wandb.init(project='cartpole', tags=['demo'])
    config = wandb.config

    dataset, env = get_cartpole()

    # Local log directory for the "normal" (file-based) training artifacts.
    experiment_folder = 'logs/cartpole/'
    adapter_factories = [
        d3rlpy.logging.FileAdapterFactory(root_dir=experiment_folder),
        WandbAdapterFactory(),
    ]
    logger_adapter = d3rlpy.logging.CombineAdapterFactory(adapter_factories)

    # Hyperparameters come from the sweep via wandb.config.
    algo = DQNConfig(
        batch_size=config.batch_size,
        learning_rate=config.lr,
    ).create(device="cpu")
    algo.build_with_dataset(dataset)

    evaluators = {
        'td_error': TDErrorEvaluator(episodes=dataset.episodes),
        'environment': EnvironmentEvaluator(env),
    }
    hist = algo.fit(
        dataset,
        n_steps=2000,
        n_steps_per_epoch=1000,
        evaluators=evaluators,
        logger_adapter=logger_adapter,
    )

    # Anything extra can be logged with wandb.log({'name': value}) or stored
    # via wandb.config['param'] = value. Record a run-level summary here.
    td_errors = [epoch_metrics["td_error"] for _, epoch_metrics in hist]
    wandb.run.summary["overall_td"] = np.mean(td_errors)
if __name__ == '__main__':
    # Fix the torch RNG so every sweep trial starts from identical weights.
    torch.manual_seed(100)

    # Bayesian sweep maximizing the environment-evaluation metric.
    sweep_config = {
        "method": "bayes",
        "name": "sweep",
        "metric": {"goal": "maximize", "name": "environment"},
        "parameters": {
            "env_name": {"value": "cartpole"},
            "batch_size": {"values": [32, 64]},
            "lr": {"min": 0.0001, "max": 0.001},
        },
    }

    sweep_id = wandb.sweep(sweep=sweep_config, project="cartpole")
    wandb.agent(sweep_id, function=main, count=2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment