@KilianFt
Created November 20, 2023 09:22
wandb sweep with d3rlpy
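# Streams d3rlpy training metrics to Weights & Biases through a custom
# LoggerAdapter and runs a Bayesian wandb sweep over DQN hyperparameters
# on the offline cartpole dataset.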
from typing import Dict, Any
import torch
import wandb
import numpy as np
import d3rlpy
from d3rlpy.datasets import get_cartpole
from d3rlpy.algos import DQNConfig
from d3rlpy.metrics import TDErrorEvaluator, EnvironmentEvaluator
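
# d3rlpy invokes these LoggerAdapter hooks during fit(); instead of writing
# files, this adapter forwards every metric to the active wandb run.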
class WandbAdapter(d3rlpy.logging.LoggerAdapter):
    def write_params(self, params: Dict[str, Any]) -> None:
        # save dictionary as json file
        # with open("params.json", "w") as f:
        #     f.write(json.dumps(params, default=default_json_encoder, indent=2))
        pass

    def before_write_metric(self, epoch: int, step: int) -> None:
        pass

    def write_metric(self, epoch: int, step: int, name: str, value: float) -> None:
        # with open(f"{name}.csv", "a") as f:
        #     print(f"{epoch},{step},{value}", file=f)
        wandb.run.log({name: value}, step=step)

    def after_write_metric(self, epoch: int, step: int) -> None:
        pass

    def save_model(self, epoch: int, algo: Any) -> None:
        # algo.save(f"model_{epoch}.d3")
        pass

    def close(self) -> None:
        pass
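
# fit() takes a LoggerAdapterFactory rather than a bare adapter; the factory
# is asked to create one adapter per experiment.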
class WandbAdapterFactory(d3rlpy.logging.LoggerAdapterFactory):
    def create(self, experiment_name: str) -> WandbAdapter:
        return WandbAdapter()
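
# Entry point for each sweep trial: wandb.agent calls main() and the sampled
# hyperparameters are exposed through wandb.config.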
def main():
    run = wandb.init(project='cartpole', tags=['demo'])
    config = wandb.config

    dataset, env = get_cartpole()

    # where to save the "normal" files
    experiment_folder = 'logs/cartpole/'
    logger_adapter = d3rlpy.logging.CombineAdapterFactory(
        [
            d3rlpy.logging.FileAdapterFactory(root_dir=experiment_folder),
            WandbAdapterFactory(),
        ]
    )

    dqn = DQNConfig(batch_size=config.batch_size, learning_rate=config.lr).create(device="cpu")
    dqn.build_with_dataset(dataset)

    td_error_evaluator = TDErrorEvaluator(episodes=dataset.episodes)
    env_evaluator = EnvironmentEvaluator(env)

    hist = dqn.fit(
        dataset,
        n_steps=2000,
        n_steps_per_epoch=1000,
        evaluators={
            'td_error': td_error_evaluator,
            'environment': env_evaluator,
        },
        logger_adapter=logger_adapter,
    )

    # you can log anything else with wandb.log({'name': value})
    # or set a config parameter with wandb.config['param'] = value
    # if you have any summary results
    wandb.run.summary["overall_td"] = np.mean([h[1]["td_error"] for h in hist])
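
# Bayesian sweep over batch size and learning rate, maximizing the
# "environment" metric logged by the EnvironmentEvaluator above.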
if __name__ == '__main__':
    random_seed = 100
    torch.manual_seed(random_seed)

    sweep_configuration = {
        "method": "bayes",
        "name": "sweep",
        "metric": {"goal": "maximize", "name": "environment"},
        "parameters": {
            "env_name": {"value": "cartpole"},
            "batch_size": {"values": [32, 64]},
            "lr": {"max": 0.001, "min": 0.0001},
        },
    }

    sweep_id = wandb.sweep(sweep=sweep_configuration, project="cartpole")
    wandb.agent(sweep_id, function=main, count=2)
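Note: wandb.agent runs main for two sweep trials here (count=2); raising count or starting additional agents with the same sweep_id continues exploring the search space.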