@Microsheep
Last active March 6, 2023 05:02
W&B config for dataclasses (Python 3.7+)
from datetime import datetime
from dataclasses import dataclass, field, asdict
from typing import Callable, Optional, Dict, Any

import torch


def transform_dict(config_dict: Dict, expand: bool = True):
    """
    General function to transform any dictionary into a wandb-acceptable config format.
    (This is mostly due to datatypes that cannot fit into the YAML format, which makes wandb angry.)
    The expand argument expands iterables into dictionaries so that these configs
    can be compared across runs.
    """
    ret: Dict[str, Any] = {}
    for k, v in config_dict.items():
        if v is None or isinstance(v, (int, float, str)):
            ret[k] = v
        elif isinstance(v, (list, tuple, set)):
            # Need to check whether each item in the iterable is YAML-friendly
            t = transform_dict(dict(enumerate(v)), expand)
            # Transform back into an iterable if expand is False
            ret[k] = t if expand else [t[i] for i in range(len(v))]
        elif isinstance(v, dict):
            ret[k] = transform_dict(v, expand)
        else:
            # Transform to a YAML-friendly (str) format.
            # Need to handle classes, callables, and object instances.
            # Custom classes might not have a useful __repr__, so __name__ is
            # preferred when it exists.
            vname = v.__name__ if hasattr(v, '__name__') else v.__class__.__name__
            ret[k] = f"{v.__module__}:{vname}"
    return ret
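
As a quick illustration (the sample values below are mine, not part of the gist): a list is expanded into an index-keyed dictionary, and a class is rendered as a "module:Name" string.

# A minimal sketch of transform_dict's behavior; the input values here are
# illustrative only.
sample = {
    "lr": 0.001,
    "layers": [64, 128],
    "optimizer": torch.optim.Adam,  # a class, not an instance
}
print(transform_dict(sample))
# -> something like {'lr': 0.001, 'layers': {0: 64, 1: 128},
#                    'optimizer': 'torch.optim.adam:Adam'}
print(transform_dict(sample, expand=False))
# -> something like {'lr': 0.001, 'layers': [64, 128],
#                    'optimizer': 'torch.optim.adam:Adam'}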
def dfac_cur_time():
    # Default factory for the current timestamp (used below in ExperimentConfig)
    return datetime.now().strftime("%Y%m%d-%H%M%S")
@dataclass
class ExperimentConfig:
    # GPU setting
    gpu_device_id: str = "1"
    # Random seed: set to None to generate a new seed
    random_seed: Optional[int] = None
    # Logging related
    tensorboard_log_root: str = "/tmp/XXX/tb/"
    wandb_dir: str = "/tmp/XXX/wandb/"
    cur_time: str = field(default_factory=dfac_cur_time)
    # WandB settings
    wandb_repo: str = "XXX"
    wandb_project: str = "XXX"
    wandb_group: str = "XXX"
    # Training settings
    batch_size: int = 64
    num_epochs: int = 10
    preprocess_funcs: Optional[Callable] = None
    dataset_sampler: Optional[torch.utils.data.sampler.Sampler] = None
    model: Optional[torch.nn.Module] = None
    # Arguments passed to the model class; declared as a field so that asdict()
    # (and therefore to_dict) picks them up when logging to wandb
    model_args: Optional[Dict[str, Any]] = None
    optimizer: torch.optim.Optimizer = torch.optim.Adam
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None  # pylint: disable=protected-access

    def to_dict(self, expand: bool = True):
        return transform_dict(asdict(self), expand)
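
For example (output abridged and illustrative), serializing a default config leaves plain values untouched while turning the optimizer class into a string:

config = ExperimentConfig()
print(config.to_dict())
# -> something like:
# {'gpu_device_id': '1', 'random_seed': None, 'batch_size': 64, ...,
#  'model': None, 'optimizer': 'torch.optim.adam:Adam', ...}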
# This is only meant to show how it would work in our use case
import wandb

from experiment import ExperimentConfig, run_experiment

if __name__ == "__main__":
    # Set up the experiment config
    config = ExperimentConfig()

    # Decide the model and structure settings
    # (CNNModel, mypreprocess and mydataset_sampler are user-defined)
    config.model = CNNModel  # Class
    config.model_args = {
        "model_structure": [(128, 7), (256, 5), (128, 3)],
    }
    config.preprocess_funcs = mypreprocess  # Callable
    config.dataset_sampler = mydataset_sampler(alpha=0.1)  # Object instance

    # Set up WandB
    wandb.init(
        entity=config.wandb_repo, project=config.wandb_project,
        name=config.cur_time, group=config.wandb_group,
        dir=config.wandb_dir, sync_tensorboard=True, config=config.to_dict(),
    )

    # Run the experiment
    training_history, test_report = run_experiment(config)
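
The gist does not include run_experiment itself. A hypothetical sketch of how it might consume the config, to show why storing classes and callables (rather than instances) works:

import torch

# Hypothetical sketch (run_experiment is not shown in the gist): the config
# stores classes and callables, which only get instantiated here, inside the
# experiment.
def run_experiment(config):
    if config.random_seed is not None:
        torch.manual_seed(config.random_seed)
    model = config.model(**(config.model_args or {}))  # class -> instance
    optimizer = config.optimizer(model.parameters())   # class -> instance
    history = {"epochs": config.num_epochs, "batch_size": config.batch_size}
    # ... the actual training/evaluation loop would go here ...
    return history, {}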