Skip to content

Instantly share code, notes, and snippets.

@yukw777
Created August 27, 2020 16:17
Show Gist options
  • Save yukw777/fdc4f6d3cda338c8fd63f525c1a4daf8 to your computer and use it in GitHub Desktop.
Save yukw777/fdc4f6d3cda338c8fd63f525c1a4daf8 to your computer and use it in GitHub Desktop.
PyTorch Lightning + Hydra Training Script Comparison
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from hydra.utils import instantiate
logger = logging.getLogger(__name__)


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> Trainer:
    """Train a model where the module, datamodule and logger are all
    instantiated directly from the Hydra config.

    Returns the fitted ``Trainer``; runs the test loop afterwards when
    ``cfg.train.run_test`` is set.
    """
    logger.info(f"Training with the following config:\n{OmegaConf.to_yaml(cfg)}")

    # Hydra builds the LightningModule and LightningDataModule from config.
    model = instantiate(cfg.network, cfg.train)
    datamodule = instantiate(cfg.data)

    # Fall back to Lightning's default logger (True) when none is configured.
    pl_logger = instantiate(cfg.logger) if "logger" in cfg else True
    trainer = Trainer(**cfg.pl_trainer, logger=pl_logger)

    trainer.fit(model, datamodule)
    if cfg.train.run_test:
        trainer.test(datamodule=datamodule)
    return trainer


if __name__ == "__main__":
    main()
import logging
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from hydra.utils import instantiate
from leela_zero_pytorch.network import NetworkLightningModule
from leela_zero_pytorch.dataset import Dataset
logger = logging.getLogger(__name__)


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> Trainer:
    """Train leela-zero-pytorch, building the DataLoaders by hand from config.

    Returns the fitted ``Trainer``; runs the test loop afterwards when
    ``cfg.train.run_test`` is set.
    """
    logger.info(f"Training with the following config:\n{cfg.pretty()}")
    module = NetworkLightningModule(cfg.network, cfg.train)

    # Fall back to Lightning's default logger (True) when none is configured.
    trainer_logger = instantiate(cfg.logger) if "logger" in cfg else True
    trainer = Trainer(**cfg.pl_trainer, logger=trainer_logger)

    # Hydra changes the working directory per run, so dataset paths from the
    # config must be resolved back to absolute paths.
    trainer.fit(
        module,
        train_dataloader=DataLoader(
            Dataset.from_data_dir(
                hydra.utils.to_absolute_path(cfg.dataset.train.dir_path), transform=True
            ),
            shuffle=True,
            batch_size=cfg.dataset.train.batch_size,
            num_workers=cfg.dataset.train.num_workers,
        ),
        val_dataloaders=DataLoader(
            Dataset.from_data_dir(
                hydra.utils.to_absolute_path(cfg.dataset.val.dir_path)
            ),
            batch_size=cfg.dataset.val.batch_size,
            num_workers=cfg.dataset.val.num_workers,
        ),
    )
    if cfg.train.run_test:
        trainer.test(
            test_dataloaders=DataLoader(
                Dataset.from_data_dir(
                    hydra.utils.to_absolute_path(cfg.dataset.test.dir_path)
                ),
                # FIX: was cfg.dataset.train.batch_size — the test loader now
                # uses the test split's own batch size, consistent with its
                # dir_path and num_workers settings.
                batch_size=cfg.dataset.test.batch_size,
                num_workers=cfg.dataset.test.num_workers,
            )
        )
    return trainer


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment