Skip to content

Instantly share code, notes, and snippets.

@omry
Created September 17, 2020 22:32
Show Gist options
  • Save omry/4f970ef7041732ba923d03f9fff33757 to your computer and use it in GitHub Desktop.
Save omry/4f970ef7041732ba923d03f9fff33757 to your computer and use it in GitHub Desktop.
Recursive instantiation usage prototype
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass, field
from typing import Any, List
from omegaconf import MISSING, II, OmegaConf
import hydra
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
# library code
class Optimizer:
    """Base optimizer; stores the learning rate.

    Args:
        lr: Learning rate.
    """

    def __init__(self, lr: float) -> None:
        self.lr = lr


# Demonstrates two different optimizer implementations.
class Adam(Optimizer):
    """Optimizer configured with a learning rate and a ``beta`` hyperparameter.

    Args:
        lr: Learning rate (stored by ``Optimizer.__init__``).
        beta: Beta hyperparameter; stored as ``self.beta``.
    """

    def __init__(self, lr: float, beta: float) -> None:
        super().__init__(lr)
        # Store the argument itself (not the builtin ``float`` type).
        self.beta = beta


class SGD(Optimizer):
    """Optimizer configured with a learning rate only.

    Args:
        lr: Learning rate (stored by ``Optimizer.__init__``).
    """

    def __init__(self, lr: float) -> None:
        super().__init__(lr)
class Dataset:
    """Holds a data location and a batch size.

    Args:
        path: Location of the data, stored as ``self.path``.
        batch_size: Batch size, stored as ``self.batch_size``.
    """

    def __init__(self, path: str, batch_size: int) -> None:
        # Tuple assignment binds left to right: path first, then batch_size.
        self.path, self.batch_size = path, batch_size
class Trainer:
    """Trainer assembled from an optimizer, a dataset and a batch size.

    At the time of writing ``instantiate`` was not recursive, so ``optimizer``
    and ``dataset`` arrive as config nodes rather than constructed objects.
    Recursive instantiation is expected to turn them into the real objects.

    Args:
        optimizer: The optimizer (currently its config node).
        dataset: The dataset (currently its config node).
        batch_size: Batch size, stored as ``self.batch_size``.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        dataset: Dataset,
        batch_size: int,
    ) -> None:
        # Echo the inputs so it is visible whether they are config nodes or
        # real objects.
        received = (
            ("provided optimizer :", optimizer),
            ("provided dataset :", dataset),
        )
        for label, value in received:
            print(label, value)
        self.optimizer = optimizer
        self.dataset = dataset
        self.batch_size = batch_size
# Config hierarchy (it will be possible to code-generate this in the future).
# Recursive instantiation is not supported yet, so this cannot be committed.
@dataclass
class OptimizerConf:
    """Structured config for the base ``my_app.Optimizer``.

    ``_target_`` names the class that ``instantiate`` should construct.
    ``lr`` is left as ``MISSING``, so a value has to be provided before use.
    """

    _target_: str = "my_app.Optimizer"
    lr: float = MISSING
@dataclass
class AdamConf(OptimizerConf):
    """Structured config for ``my_app.Adam``.

    Stored as the ``adam`` option of the ``optimizer`` config group.
    ``lr`` and ``beta`` are ``MISSING`` and must be filled in by composition,
    e.g. from ``user_config``.
    """

    _target_: str = "my_app.Adam"
    # Re-declared from OptimizerConf with the same default.
    lr: float = MISSING
    beta: float = MISSING
@dataclass
class SGDConf(OptimizerConf):
    """Structured config for ``my_app.SGD``.

    Stored as the ``sgd`` option of the ``optimizer`` config group.
    ``lr`` is ``MISSING`` and must be filled in by composition.
    """

    _target_: str = "my_app.SGD"
    # Re-declared from OptimizerConf with the same default.
    lr: float = MISSING
@dataclass
class DatasetConf:
    """Structured config for ``my_app.Dataset``.

    Both fields are ``MISSING``; ``user_config`` supplies ``path`` and
    interpolates ``batch_size`` from ``trainer.batch_size``.
    """

    _target_: str = "my_app.Dataset"
    path: str = MISSING
    batch_size: int = MISSING
@dataclass
class TrainerConf:
    """Structured config for ``my_app.Trainer``.

    ``optimizer`` is left ``MISSING``; the concrete optimizer is selected
    through the ``optimizer`` config group during composition. ``dataset`` has
    a single option, so it is inlined with its defaults.
    """

    _target_: str = "my_app.Trainer"
    batch_size: int = MISSING
    # Not populated here; the right optimizer is chosen with config composition.
    optimizer: OptimizerConf = MISSING
    # Only one dataset option exists, so it is inlined. A default_factory is
    # required: a (non-frozen) dataclass instance is unhashable, and Python 3.11+
    # rejects it as a plain field default. Before 3.11 one instance would have
    # been shared by every TrainerConf.
    dataset: DatasetConf = field(default_factory=DatasetConf)
@dataclass
class Config:
    """Top-level application config.

    ``defaults`` drives composition. The ``adam`` option of the ``optimizer``
    group is composed first, then the ``user_config`` file fills in the
    remaining values.
    """

    # default_factory gives each Config its own TrainerConf. It is also
    # required on Python 3.11+, where unhashable dataclass instances are
    # rejected as plain field defaults.
    trainer: TrainerConf = field(default_factory=TrainerConf)
    defaults: List[Any] = field(
        default_factory=lambda: [
            # By default, compose the adam optimizer.
            {"optimizer": "adam"},
            # Populate the rest from user_config (a YAML file in this example).
            "user_config",
        ]
    )
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
# Register every optimizer option in the "optimizer" group. Each option is
# placed at trainer.optimizer in the composed config.
for _opt_name, _opt_node in (("adam", AdamConf), ("sgd", SGDConf)):
    cs.store(
        group="optimizer",
        name=_opt_name,
        node=_opt_node,
        package="trainer.optimizer",
    )
# Remove the loop variables so the module namespace is unchanged.
del _opt_name, _opt_node
@hydra.main(config_name="config")
def my_app(cfg: Config) -> None:
    """Print the composed config as YAML, then instantiate the trainer from it.

    At the time of writing ``instantiate`` was not recursive. As a result, the
    nested ``optimizer`` and ``dataset`` reach ``Trainer`` as config nodes
    rather than actual objects.
    """
    config_yaml = OmegaConf.to_yaml(cfg)
    print(config_yaml)
    # The resulting instance is not used further in this prototype.
    _trainer = instantiate(cfg.trainer)


if __name__ == "__main__":
    my_app()
# @package _global_
trainer:
  batch_size: 32
  # In a more interesting scenario the optimizer would also be composed from
  # multiple YAML files.
  optimizer:
    lr: 0.1
    beta: 0.9
  dataset:
    path: /foo/bar
    # Interpolation: reuse the trainer's batch size.
    batch_size: ${trainer.batch_size}
@omry
Copy link
Author

omry commented Sep 17, 2020

Example output:

trainer:
  _target_: my_app.Trainer
  batch_size: 32
  optimizer:
    _target_: my_app.Adam
    lr: 0.1
    beta: 0.9
  dataset:
    _target_: my_app.Dataset
    path: /foo/bar
    batch_size: ${trainer.batch_size}

provided optimizer : {'_target_': 'my_app.Adam', 'lr': 0.1, 'beta': 0.9}
provided dataset : {'_target_': 'my_app.Dataset', 'path': '/foo/bar', 'batch_size': '${trainer.batch_size}'}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment