
@lebrice
Created August 4, 2022 21:11
Idea for a new use of typing.Unpack: annotating **kwargs with the signature of a callable.
from __future__ import annotations
from pytorch_lightning import Trainer
from typing_extensions import Unpack, ParamSpec
from typing import Callable, TypedDict

# Option A: TypedDict
# --> Lots of code duplication: every Trainer.__init__ parameter (and its type)
#     has to be copied by hand and kept in sync with Lightning.
class TrainerConfig(TypedDict, total=False):
    gpus: int | None
    accelerator: str | None
    strategy: str | None
    devices: int | None
    callbacks: list
    max_epochs: int | None
    limit_train_batches: int | None
    limit_val_batches: int | None
    profiler: str | None
    # ... (and so on for the remaining Trainer arguments)


def add_trainer_args(
    **trainer_kwargs: Unpack[TrainerConfig],
):
    """Example of a function that receives a partial signature for the Trainer __init__."""
    ...
    trainer = Trainer(**trainer_kwargs)
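A runnable sketch of Option A, for trying the pattern without Lightning installed. FakeTrainer and FakeTrainerConfig are hypothetical stand-ins with three made-up parameters, not Lightning's real Trainer signature. Unpack is imported only under TYPE_CHECKING, which works because postponed annotations are never evaluated at runtime; the static checking itself assumes a type checker that supports Unpack on **kwargs (PEP 692).

```python
from __future__ import annotations

from typing import TYPE_CHECKING, TypedDict

if TYPE_CHECKING:
    # Only the type checker needs Unpack; with postponed evaluation of
    # annotations it is never looked up at runtime.
    from typing_extensions import Unpack


class FakeTrainer:
    """Hypothetical stand-in for pytorch_lightning.Trainer (not its real signature)."""

    def __init__(self, accelerator: str = "cpu", max_epochs: int = 1, profiler: str | None = None) -> None:
        self.accelerator = accelerator
        self.max_epochs = max_epochs
        self.profiler = profiler


class FakeTrainerConfig(TypedDict, total=False):
    # Each field duplicates a FakeTrainer.__init__ parameter by hand.
    accelerator: str
    max_epochs: int
    profiler: str | None


def make_trainer(**trainer_kwargs: Unpack[FakeTrainerConfig]) -> FakeTrainer:
    # A PEP 692-aware checker flags unknown keys or wrongly typed values at
    # call sites; at runtime the kwargs are simply forwarded.
    return FakeTrainer(**trainer_kwargs)


trainer = make_trainer(accelerator="gpu", max_epochs=3)
print(trainer.accelerator, trainer.max_epochs)  # prints: gpu 3
```

Renaming a parameter in FakeTrainer without updating FakeTrainerConfig is exactly the drift that Option A's comment warns about.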

# Option B: ParamSpec
# --> Always needs *args, as well as an extra argument whose only job is to bind
#     the ParamSpec to Trainer's signature:
_TrainerSig = ParamSpec("_TrainerSig")


def add_trainer_args(
    # The Trainer class is itself a Callable[..., Trainer], so it can bind _TrainerSig.
    _trainer_type: Callable[_TrainerSig, Trainer] = Trainer,
    *trainer_args: _TrainerSig.args,
    **trainer_kwargs: _TrainerSig.kwargs,
):
    """
    Adds arguments to the argument parser for the Trainer.
    """
    trainer = _trainer_type(*trainer_args, **trainer_kwargs)
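A similar runnable sketch of Option B, again with a hypothetical FakeTrainer stand-in. In this variant the factory is a required first argument and the return type is a TypeVar, so the helper (named build here, also hypothetical) works for any callable. The ParamSpec P ties *args/**kwargs to the factory's parameters without a hand-written field list.

```python
from __future__ import annotations

from typing import Callable, TypeVar

try:
    from typing import ParamSpec  # Python 3.10+
except ImportError:  # older Pythons
    from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")


class FakeTrainer:
    """Hypothetical stand-in for pytorch_lightning.Trainer."""

    def __init__(self, accelerator: str = "cpu", max_epochs: int = 1) -> None:
        self.accelerator = accelerator
        self.max_epochs = max_epochs


def build(factory: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    # P is solved from `factory`, so a type checker validates *args/**kwargs
    # against FakeTrainer.__init__ and infers the return type as FakeTrainer.
    return factory(*args, **kwargs)


trainer = build(FakeTrainer, max_epochs=5)
```

The cost noted in the original remains: the class has to be passed in (or defaulted) just so the ParamSpec has something to bind to.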
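
# Proposed solution (roughly): let Unpack accept a callable (here the Trainer class),
# so that **trainer_kwargs takes on Trainer's keyword parameters.
# (Not valid today: PEP 692 only allows Unpack[SomeTypedDict] on **kwargs.)
def add_trainer_args(**trainer_kwargs: Unpack[Trainer]):
    trainer = Trainer(**trainer_kwargs)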
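Until something like Unpack[Trainer] exists, one workaround is a decorator that uses a ParamSpec to copy another callable's signature onto a wrapper. This is a sketch of that pattern, not part of the proposal above; copy_signature and FakeTrainer are hypothetical names. It only changes static types: inside the body trainer_kwargs is still an untyped dict, nothing is checked at runtime, and the copied signature also permits positional arguments that this **kwargs-only wrapper would reject.

```python
from __future__ import annotations

from typing import Any, Callable, TypeVar

try:
    from typing import ParamSpec  # Python 3.10+
except ImportError:  # older Pythons
    from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def copy_signature(source: Callable[P, Any]) -> Callable[[Callable[..., R]], Callable[P, R]]:
    """Make type checkers treat the decorated function as taking `source`'s parameters."""

    def decorator(func: Callable[..., R]) -> Callable[P, R]:
        # No runtime change: the function is returned as-is, only its static type differs.
        return func

    return decorator


class FakeTrainer:
    """Hypothetical stand-in for pytorch_lightning.Trainer."""

    def __init__(self, accelerator: str = "cpu", max_epochs: int = 1) -> None:
        self.accelerator = accelerator
        self.max_epochs = max_epochs


@copy_signature(FakeTrainer)
def add_trainer_args(**trainer_kwargs: Any) -> FakeTrainer:
    # Call sites are checked against FakeTrainer.__init__ (accelerator, max_epochs).
    return FakeTrainer(**trainer_kwargs)


trainer = add_trainer_args(max_epochs=2)
```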