Created
February 13, 2023 03:03
-
-
Save rmorshea/f6ca988f4ec109c571e1d1ac4f738578 to your computer and use it in GitHub Desktop.
A way to define groups of nox sessions that share a common set of CLI parameters.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import argparse | |
import inspect | |
import functools | |
from typing import Any, Callable, TypeVar, get_type_hints | |
import nox | |
from nox.sessions import Session | |
# -------------------------------------------------------------------------------------- | |
# Session Groups | |
# -------------------------------------------------------------------------------------- | |
Func = TypeVar("Func", bound=Callable[..., None])


def _parse_bool(value: str) -> bool:
    """Parse a command-line string into a bool for a ``bool``-annotated parameter.

    ``argparse`` with ``type=bool`` would call ``bool("false")``, which is
    ``True`` for every non-empty string, so boolean options need this parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean word.
    """
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


class SessionGroup:
    """A named group of nox sessions that share one set of CLI parameters.

    Groups nest via :meth:`group`. A child's name is its parent's name joined
    with its own. The child also reuses its parent's argument parser and
    parameter registry. As a result, a parameter declared by any session in
    the tree is parsed the same way from ``session.posargs`` for every session
    in that tree.
    """

    def __init__(self, name: str = "", parent: SessionGroup | None = None):
        self.name = _join_names(parent.name if parent else "", name)
        self._parent = parent
        # Nested groups reuse the root's parser and registry so that every
        # session in the tree accepts the same command-line options.
        self._parser = parent._parser if parent else argparse.ArgumentParser(self.name)
        self._sessions: list[Session] = []
        self._param_info: _ParamInfo = parent._param_info if parent else _ParamInfo()

    def session(self, func: Func) -> Func:
        """Register ``func`` as a nox session named ``<group name>-<func name>``.

        The first parameter of ``func`` receives the nox session. Every other
        annotated parameter becomes a ``--<name>`` option on the shared parser.
        The annotation is used as the argparse ``type``; ``bool`` annotations
        accept ``--name`` alone (True) or ``--name true/false``. A parameter's
        Python default becomes the option default. When the session runs, the
        options are parsed from ``session.posargs`` (unknown arguments are
        ignored) and passed to ``func`` as keyword arguments. Unannotated
        parameters are not exposed on the CLI and keep their Python defaults.

        The session is tagged with the name of this group and of every named
        ancestor group. For example, ``nox -t check`` selects every session
        nested under ``check``.

        Returns:
            ``func`` itself, unchanged.

        Raises:
            TypeError: if ``func`` takes no parameters, if a parameter after the
                first is not positional-or-keyword or keyword-only, or if a
                parameter conflicts with one already declared in this group tree.
        """
        session_name = _join_names(self.name, func.__name__)

        params = list(inspect.signature(func).parameters.values())
        if not params:
            raise TypeError(f"{func} must accept the nox session as its first argument")
        # The first parameter receives the nox Session. ``wrapper`` passes it
        # positionally, so its kind does not matter.
        cli_params = params[1:]
        for param in cli_params:
            if param.kind not in (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            ):
                raise TypeError(
                    f"Parameter {param.name!r} of {func} is not a keyword argument"
                )

        # Select CLI parameters by signature rather than by position in the
        # hints dict. Otherwise an unannotated session parameter, or a missing
        # return annotation, would shift which entries get dropped.
        hints = get_type_hints(func)
        cli_types = {p.name: hints[p.name] for p in cli_params if p.name in hints}
        defaults = {p.name: p.default for p in cli_params}

        for name, annotation in cli_types.items():
            default = defaults[name]
            # The parser is shared across the group tree, so only the first
            # declaration adds the option. Conflicting redeclarations raise.
            if not self._param_info.check_and_add(session_name, name, annotation, default):
                continue
            options: dict[str, Any] = {"type": annotation}
            if annotation is bool:
                # ``--flag`` alone means True; ``--flag false`` means False.
                options.update(type=_parse_bool, nargs="?", const=True)
            if default is not inspect.Parameter.empty:
                options["default"] = default
            self._parser.add_argument(f"--{name}", **options)

        @nox.session(name=session_name, tags=self._tags())
        @functools.wraps(func)
        def wrapper(session: Session) -> None:
            namespace = vars(self._parser.parse_known_args(session.posargs)[0])
            func(session, **{name: namespace[name] for name in cli_types})

        return func

    def _tags(self) -> list[str]:
        """Return the names of this group and its named ancestors, outermost first."""
        tags = []
        group = self
        while group is not None:
            if group.name:
                tags.append(group.name)
            group = group._parent
        tags.reverse()
        return tags

    def group(self, name: str) -> SessionGroup:
        """Return a child group that shares this group's parser and parameter registry."""
        return SessionGroup(name, self)
class _ParamInfo: | |
def __init__(self): | |
self._param_owners: dict[str, set[Any]] = {} | |
self._param_types_and_defaults: dict[str, tuple[Any, Any]] = {} | |
def check_and_add(self, owner: Any, name: str, type: Any, default: Any) -> None: | |
"""Return if parameter already exists, add if not, and raise on conflict""" | |
if name not in self._param_types_and_defaults: | |
self._param_owners[name] = {owner} | |
self._param_types_and_defaults[name] = (type, default) | |
return True | |
if self._param_types_and_defaults[name] != (type, default): | |
existing = self._param_types_and_defaults[name] | |
raise TypeError( | |
f"Parameter {name!r} alread declared as a {existing[0]} with default " | |
f"{existing[1]!r} by {', '.join(map(repr, owner))}." | |
) | |
if owner not in self._param_owners: | |
self._param_owners[name].add(owner) | |
return False | |
def _join_names(*names: str) -> str: | |
return "-".join(n.replace("_", "-") for n in names if n) | |
# -------------------------------------------------------------------------------------- | |
# Session Definitions | |
# -------------------------------------------------------------------------------------- | |
check = SessionGroup("check")
# Nested group: its sessions are named "check-python-<function name>".
check_python = SessionGroup("python", parent=check)


@check_python.session
def suite(session: Session, no_pytest_cov: bool = False) -> None:
    """Example session in the ``check-python`` group; prints ``no_pytest_cov``."""
    print(no_pytest_cov)


@check.session
def thing(session: Session, no_pytest_cov: bool = False) -> None:
    """Example session in the ``check`` group; prints ``no_pytest_cov``."""
    print(no_pytest_cov)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment