
@yukw777
Created February 9, 2021 14:57
Hydra Compose API Unit Tests Example
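import pytest
from hydra import compose, initialize  # Hydra >= 1.1; Hydra 1.0 exposes these under hydra.experimental

# train_main is the project's training entry point; its import path is not shown in the original.


@pytest.mark.parametrize("network_size", ["small", "big", "huge"])
def test_train_network_size(network_size):
    # config_path is resolved relative to this test file
    with initialize(config_path="../leela_zero_pytorch/conf"):
        cfg = compose(
            config_name="config",
            overrides=[
                # "+" appends the network config group, which the primary config does not select
                f"+network={network_size}",
                # point every split at the small test fixture and keep batches tiny
                "data.train_data_dir=tests/test-data",
                "data.train_dataloader_conf.batch_size=2",
                "data.val_data_dir=tests/test-data",
                "data.val_dataloader_conf.batch_size=2",
                "data.test_data_dir=tests/test-data",
                "data.test_dataloader_conf.batch_size=2",
                # PyTorch Lightning's fast_dev_run runs only a batch per loop, so the test stays fast
                "+pl_trainer.fast_dev_run=true",
            ],
        )
        train_main(cfg)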
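The overrides above follow two rules: `key=value` replaces a key that already exists in the composed config, while `+key=value` adds a key that is not there yet. The snippet below is a toy, stdlib-only sketch of those two rules applied to a nested dict, written to illustrate the test's override list. It is not Hydra's implementation: `apply_overrides` and the base config are hypothetical, and config groups are ignored, so `+network=small` simply stores the string `"small"` rather than loading a network config file.

```python
import copy
from typing import Any, Dict, List


def _parse_value(raw: str) -> Any:
    """Convert a small subset of override values: booleans, ints, else strings."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return int(raw)
    except ValueError:
        return raw


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Return a copy of cfg with 'a.b=value' and '+a.b=value' overrides applied.

    Toy illustration only (hypothetical helper): no config groups, lists,
    quoting, or interpolation, unlike Hydra's real override grammar.
    """
    cfg = copy.deepcopy(cfg)
    for override in overrides:
        append = override.startswith("+")
        body = override[1:] if append else override
        key, _, raw = body.partition("=")
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            # "+" may create missing parents; a plain override must find them
            node = node.setdefault(part, {}) if append else node[part]
        if not append and leaf not in node:
            raise KeyError(f"cannot override missing key '{key}'; use +{key}=...")
        if append and leaf in node:
            raise KeyError(f"key '{key}' already exists; drop the leading '+'")
        node[leaf] = _parse_value(raw)
    return cfg


# Hypothetical base config standing in for conf/config.yaml
base = {
    "data": {
        "train_data_dir": None,
        "train_dataloader_conf": {"batch_size": 32},
    },
    "pl_trainer": {"max_epochs": 10},
}

cfg = apply_overrides(base, [
    "+network=small",
    "data.train_data_dir=tests/test-data",
    "data.train_dataloader_conf.batch_size=2",
    "+pl_trainer.fast_dev_run=true",
])
print(cfg["data"]["train_dataloader_conf"]["batch_size"])  # 2
```

Dropping the `+` from `pl_trainer.fast_dev_run=true` makes the sketch raise `KeyError`, which mirrors why the test needs the prefix for keys the primary config does not define.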