Skip to content

Instantly share code, notes, and snippets.

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = ["Peter Yu <2057325+yukw777@users.noreply.github.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "~3.9"
@yukw777
yukw777 / hydra_compose_api_unit_test_ex.py
Created February 9, 2021 14:57
Hydra Compose API Unit Tests Example
@pytest.mark.parametrize("network_size", ["small", "big", "huge"])
def test_train_network_size(network_size):
with initialize(config_path="../leela_zero_pytorch/conf"):
cfg = compose(
config_name="config",
overrides=[
f"+network={network_size}",
"data.train_data_dir=tests/test-data",
"data.train_dataloader_conf.batch_size=2",
"data.val_data_dir=tests/test-data",
@yukw777
yukw777 / pl_metrics_ex.py
Created February 9, 2021 14:23
PyTorch Lightning Metrics Example
class NetworkLightningModule(..., pl.LightningModule):
def __init__(self, ...):
super().__init__(...)
self.save_hyperparameters()
# metrics
self.train_accuracy = pl.metrics.Accuracy()
self.val_accuracy = pl.metrics.Accuracy()
self.test_accuracy = pl.metrics.Accuracy()
@yukw777
yukw777 / pl_logging_ex.py
Last active February 9, 2021 14:22
PyTorch Lightning Logging Example
def training_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
planes, target_move, target_val = batch
pred_move, pred_val = self(planes)
mse_loss, cross_entropy_loss, loss = self.loss(
pred_move, pred_val, target_move, target_val
)
self.log("train_loss", loss, prog_bar=True)
self.log_dict(
{
"train_mse_loss": mse_loss,
@yukw777
yukw777 / new.py
Created August 27, 2020 16:17
PyTorch Lightning + Hydra Training Script Comparison
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from hydra.utils import instantiate
# Module-level logger named after this module (via __name__).
logger = logging.getLogger(__name__)
@yukw777
yukw777 / hydra_compose_api_unit_test_ex.py
Created August 27, 2020 15:43
Hydra Compose API Unit Tests Example
@pytest.mark.parametrize("network_size", ["small", "big", "huge"])
def test_train_network_size(monkeypatch, tmp_path, capsys, network_size):
with initialize(config_path="../leela_zero_pytorch/conf"):
cfg = compose(
config_name="config",
overrides=[
f"+network={network_size}",
"data.train_data_dir=tests/test-data",
"data.train_dataloader_conf.batch_size=2",
"data.val_data_dir=tests/test-data",
@yukw777
yukw777 / hydra_obj_inst_ex.py
Created August 27, 2020 15:41
Hydra Object Instantiation Example
@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> Trainer:
logger.info(f"Training with the following config:\n{OmegaConf.to_yaml(cfg)}")
network = instantiate(cfg.network, cfg.train)
@yukw777
yukw777 / big.yaml
Last active August 27, 2020 15:40
Hydra Object Instantiation Configuration Example
# @package _group_
_target_: leela_zero_pytorch.network.NetworkLightningModule
network_conf:
residual_channels: 128
residual_layers: 6
@yukw777
yukw777 / hydra_package_directive_ex.yaml
Last active August 27, 2020 16:12
Hydra Package Directive Example
# @package _group_
_target_: leela_zero_pytorch.network.NetworkLightningModule
network_conf:
residual_channels: 32
residual_layers: 8
@yukw777
yukw777 / pl_data_module_ex.py
Created August 27, 2020 15:35
PyTorch Lightning Data Module Example
class DataModule(pl.LightningDataModule):
def __init__(
self,
train_data_dir: str,
val_data_dir: str,
test_data_dir: str,
train_dataloader_conf: Optional[DictConfig] = None,
val_dataloader_conf: Optional[DictConfig] = None,
test_dataloader_conf: Optional[DictConfig] = None,
):