Created
August 4, 2022 21:11
-
-
Save lebrice/661d24d00f71fa033e9ca74c4a89e359 to your computer and use it in GitHub Desktop.
Idea for a new use for typing.Unpack: Annotate the signature of a callable.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

from typing import Callable, TypedDict

from pytorch_lightning import Trainer
from typing_extensions import Unpack, ParamSpec
# Option A: TypedDict
# --> lots of code duplication!
class TrainerConfig(TypedDict, total=False): | |
gpus: int | None | |
accelerator: str | None | |
strategy: str | None | |
devices: int | None | |
callbacks: list | |
max_epochs: int | None | |
limit_train_batches: int | None | |
limit_val_batches: int | None | |
profiler: str | None | |
... | |
def add_trainer_args(
    **trainer_kwargs: Unpack[TrainerConfig],
):
    """Example of a function that receives a partial signature for the Trainer __init__.

    The keyword arguments are typed through ``TrainerConfig`` and forwarded
    unchanged to ``Trainer``.

    Args:
        **trainer_kwargs: Any subset of the keys declared on
            ``TrainerConfig``.
    """
    ...  # placeholder left in the original sketch for the rest of the example
    trainer = Trainer(**trainer_kwargs)
# Option B: ParamSpec
# --> Always needs *args, as well as an otherwise unused argument to bind the type to Trainer:
# Captures the parameters of the callable passed as ``_trainer_type``.
_TrainerSig = ParamSpec("_TrainerSig")


def add_trainer_args(
    _trainer_type: type[Callable[_TrainerSig, Trainer]] = Trainer,
    *trainer_args: _TrainerSig.args,
    **trainer_kwargs: _TrainerSig.kwargs,
):
    """Build a trainer by forwarding all extra arguments to ``_trainer_type``.

    NOTE(review): the original docstring said this "adds arguments to the
    argument parser for the Trainer"; the code shown here only constructs
    the trainer, so that intent is presumably left for a fuller version.

    Args:
        _trainer_type: Callable used to create the trainer. It exists so
            that ``_TrainerSig`` can be bound to its signature; defaults to
            ``Trainer``. Because it is the first positional parameter, the
            first positional argument of a call is bound to it.
        *trainer_args: Positional arguments forwarded to ``_trainer_type``.
        **trainer_kwargs: Keyword arguments forwarded to ``_trainer_type``.
    """
    trainer = _trainer_type(*trainer_args, **trainer_kwargs)
# Rough proposed solution:
def add_trainer_args(**trainer_kwargs: Unpack[Trainer]):
    """Sketch of the proposed syntax: type ``**kwargs`` from a callable.

    ``Unpack[Trainer]`` is the idea being proposed, not an established use
    of ``Unpack``: the keyword arguments would be checked against the
    signature of ``Trainer`` itself. The annotation is never evaluated at
    runtime because the file uses ``from __future__ import annotations``.

    Args:
        **trainer_kwargs: Keyword arguments forwarded unchanged to
            ``Trainer``.
    """
    trainer = Trainer(**trainer_kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment