Last active: March 6, 2023 05:02
-
-
Save Microsheep/11edda9dee7c1ba0c099709eb7f8bea7 to your computer and use it in GitHub Desktop.
W&B config for dataclasses (Python 3.7+)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from datetime import datetime
from typing import Any, Callable, Dict, Optional, Type

import torch
def transform_dict(config_dict: Dict, expand: bool = True) -> Dict[Any, Any]:
    """
    Transform a dictionary into a format that wandb config accepts.

    wandb stores its config as YAML, so values that YAML cannot represent
    (classes, callables, arbitrary object instances) are converted to a
    ``"<module>:<name>"`` string. ``None``, ints (including bools), floats and
    strings are kept as-is; dicts are transformed recursively.

    Args:
        config_dict: Mapping to transform. It is only read, never modified.
        expand: If True, lists/tuples/sets/frozensets are expanded into dicts
            keyed by element index (0, 1, ...) so individual entries can be
            compared across runs. If False, they are returned as lists.

    Returns:
        A new dict with the same keys as ``config_dict`` and YAML-friendly values.
    """
    ret: Dict[Any, Any] = {}
    for k, v in config_dict.items():
        if v is None or isinstance(v, (int, float, str)):
            # Already YAML-friendly (bool is covered as a subclass of int).
            ret[k] = v
        elif isinstance(v, (list, tuple, set, frozenset)):
            # Recurse element-wise so nested values become YAML-friendly too.
            t = transform_dict(dict(enumerate(v)), expand)
            # dicts keep insertion order, so values() restores element order.
            ret[k] = t if expand else list(t.values())
        elif isinstance(v, dict):
            ret[k] = transform_dict(v, expand)
        else:
            # Classes and functions carry their own __name__. For instances, use
            # the class name: a custom class's __repr__ may not be informative.
            vname = v.__name__ if hasattr(v, "__name__") else v.__class__.__name__
            # Instances of built-in C types (bytes, complex, range, ...) have no
            # __module__ attribute of their own; use their type's module instead
            # of raising AttributeError.
            vmodule = getattr(v, "__module__", type(v).__module__)
            ret[k] = f"{vmodule}:{vname}"
    return ret
def dfac_cur_time():
    """Default factory returning the current local time as ``YYYYmmdd-HHMMSS``."""
    now = datetime.now()
    # The format spec in an f-string delegates to datetime.__format__ -> strftime.
    return f"{now:%Y%m%d-%H%M%S}"
@dataclass
class ExperimentConfig:
    """
    Experiment settings that can be logged as a wandb config via ``to_dict``.

    Class-, callable- and object-valued fields (preprocessing, sampler, model,
    optimizer, scheduler) are rendered as ``"<module>:<name>"`` strings by
    ``transform_dict`` when converted.
    """

    # GPU Setting
    gpu_device_id: str = "1"
    # Random Seed: Set to None to create new Seed
    random_seed: Optional[int] = None
    # Logging Related
    tensorboard_log_root: str = "/tmp/XXX/tb/"
    wandb_dir: str = "/tmp/XXX/wandb/"
    # Timestamp string produced when each config instance is created.
    cur_time: str = field(default_factory=dfac_cur_time)
    # WandB setting
    wandb_repo: str = "XXX"
    wandb_project: str = "XXX"
    wandb_group: str = "XXX"
    # Training Settings
    batch_size: int = 64
    num_epochs: int = 10
    preprocess_funcs: Optional[Callable] = None
    dataset_sampler: Optional[torch.utils.data.sampler.Sampler] = None
    # NOTE(review): the usage example assigns `CNNModel` here, which looks like a
    # class rather than an instance; confirm which the annotation should describe.
    model: Optional[torch.nn.Module] = None
    # The default is the optimizer class itself, not an instance.
    optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None  # pylint: disable=protected-access
    # Extra model settings, e.g. {"model_structure": [(128, 7), ...]}. Declared as
    # a field so it is included in to_dict(); placed last to keep the positional
    # order of the fields above unchanged.
    model_args: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self, expand: bool = True) -> Dict[Any, Any]:
        """
        Return this config as a wandb-friendly dict (see ``transform_dict``).

        Field values are read directly instead of via ``dataclasses.asdict``,
        which deep-copies every value. For a model instance or another
        heavyweight object, that copy can be slow or fail, and ``transform_dict``
        only needs to read the values. Field values that are themselves
        dataclass instances are still expanded with ``asdict``.

        Args:
            expand: Forwarded to ``transform_dict``; when True, sequences are
                expanded into index-keyed dicts for cross-run comparison.
        """
        raw: Dict[str, Any] = {}
        for fld in fields(self):
            value = getattr(self, fld.name)
            # Mirror asdict for nested dataclass instances (but not classes).
            if is_dataclass(value) and not isinstance(value, type):
                value = asdict(value)
            raw[fld.name] = value
        return transform_dict(raw, expand)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is only meant to show how it will work in our use case
import wandb

from experiment import ExperimentConfig, run_experiment

if __name__ == "__main__":
    # Setup Experiment Config
    config = ExperimentConfig()
    # Decide Model and structure settings
    # NOTE: CNNModel, mypreprocess and mydataset_sampler are not defined in this
    # example; supply your project-specific definitions.
    config.model = CNNModel
    # NOTE(review): to_dict() only includes declared dataclass fields, so
    # model_args is logged only if ExperimentConfig declares it as a field.
    config.model_args = {
        "model_structure": [(128, 7), (256, 5), (128, 3)]
    }
    config.preprocess_funcs = mypreprocess  # Callable
    config.dataset_sampler = mydataset_sampler(alpha=0.1)  # Object
    # Setup WandB; to_dict() converts classes/callables/objects to strings.
    wandb.init(
        entity=config.wandb_repo, project=config.wandb_project,
        name=config.cur_time, group=config.wandb_group,
        dir=config.wandb_dir, sync_tensorboard=True, config=config.to_dict(),
    )
    # Run Experiment
    training_history, test_report = run_experiment(config)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment