Skip to content

Instantly share code, notes, and snippets.

@rmorshea
Created February 13, 2023 03:03
Show Gist options
  • Save rmorshea/f6ca988f4ec109c571e1d1ac4f738578 to your computer and use it in GitHub Desktop.
Save rmorshea/f6ca988f4ec109c571e1d1ac4f738578 to your computer and use it in GitHub Desktop.
A way to define groups of sessions that share a common set of CLI parameters.
from __future__ import annotations
import argparse
import inspect
import functools
from typing import Any, Callable, TypeVar, get_type_hints
import nox
from nox.sessions import Session
# --------------------------------------------------------------------------------------
# Session Groups
# --------------------------------------------------------------------------------------
# Type variable bound to "callable returning None"; it lets
# ``SessionGroup.session`` be annotated as returning the same callable type
# it was given.
Func = TypeVar("Func", bound=Callable[..., None])
class SessionGroup:
    """A named group of nox sessions that share one set of CLI parameters.

    Groups can be nested with :meth:`group`. A nested group reuses its
    parent's argument parser and parameter registry, and its full name is the
    parent's name joined to its own with ``-``.
    """

    def __init__(self, name: str = "", parent: SessionGroup | None = None):
        """Create a group called ``name``, optionally nested under ``parent``."""
        self.name = _join_names(parent.name if parent else "", name)
        self._parent = parent
        # Children reuse the parent's parser and registry so a parameter
        # declared anywhere in the hierarchy is checked against all others.
        self._parser = parent._parser if parent else argparse.ArgumentParser(self.name)
        self._sessions: list[Session] = []
        self._param_info: _ParamInfo = parent._param_info if parent else _ParamInfo()

    def session(self, func: Func) -> Func:
        """Register ``func`` as a nox session belonging to this group.

        The first parameter of ``func`` receives the nox session object.
        Every other *annotated* parameter becomes a ``--<name>`` command line
        option that is parsed from ``session.posargs`` when the session runs.
        The session is named ``<group name>-<function name>`` and tagged with
        the name of this group and of each ancestor group.

        Args:
            func: The session function to register.

        Returns:
            ``func`` itself, unchanged. The wrapper that parses the arguments
            is what gets registered with nox.

        Raises:
            TypeError: If ``func`` takes no parameters, if one of its
                parameters cannot be passed by keyword, or if a parameter
                clashes with a same-named parameter declared earlier with a
                different type or default.
        """
        session_name = _join_names(self.name, func.__name__)

        defaults: dict[str, Any] = {}
        for param in inspect.signature(func).parameters.values():
            if param.kind not in (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            ):
                raise TypeError(
                    f"Parameter {param.name!r} of {func} is not a keyword argument"
                )
            defaults[param.name] = param.default
        if not defaults:
            raise TypeError(f"{func} must accept the session as its first parameter")

        hints = get_type_hints(func)
        # Neither the return annotation nor the session parameter is a CLI
        # option. Both are removed by name so that a missing annotation on
        # either one cannot raise KeyError or drop the wrong parameter.
        hints.pop("return", None)
        hints.pop(next(iter(defaults)), None)

        # Add all the rest as CLI parameters (once per parameter name).
        for param_name, param_type in hints.items():
            default = defaults[param_name]
            if not self._param_info.check_and_add(
                session_name, param_name, param_type, default
            ):
                continue
            # NOTE(review): argparse calls ``type`` on the raw string, so a
            # ``bool`` annotation treats every non-empty value (even "False")
            # as true. Left unchanged here to preserve current CLI behaviour.
            options: dict[str, Any] = {"type": param_type}
            if default is not inspect.Parameter.empty:
                options["default"] = default
            self._parser.add_argument(f"--{param_name}", **options)

        @nox.session(name=session_name, tags=self._group_tags())
        @functools.wraps(func)
        def wrapper(session: Session) -> None:
            args, _ = self._parser.parse_known_args(session.posargs)
            # The parser is shared by the whole hierarchy, so forward only
            # the options this particular function declared.
            func(session, **{k: v for k, v in vars(args).items() if k in hints})

        return func

    def group(self, name: str) -> SessionGroup:
        """Return a new group called ``name`` nested under this one."""
        return SessionGroup(name, self)

    def _group_tags(self) -> list[str]:
        """Return the names of this group and its ancestors, outermost first."""
        tags: list[str] = []
        current: SessionGroup | None = self
        while current is not None:
            # Unnamed groups contribute nothing; a child with an empty name
            # has the same full name as its parent, hence the duplicate check.
            if current.name and current.name not in tags:
                tags.append(current.name)
            current = current._parent
        tags.reverse()
        return tags
class _ParamInfo:
def __init__(self):
self._param_owners: dict[str, set[Any]] = {}
self._param_types_and_defaults: dict[str, tuple[Any, Any]] = {}
def check_and_add(self, owner: Any, name: str, type: Any, default: Any) -> None:
"""Return if parameter already exists, add if not, and raise on conflict"""
if name not in self._param_types_and_defaults:
self._param_owners[name] = {owner}
self._param_types_and_defaults[name] = (type, default)
return True
if self._param_types_and_defaults[name] != (type, default):
existing = self._param_types_and_defaults[name]
raise TypeError(
f"Parameter {name!r} alread declared as a {existing[0]} with default "
f"{existing[1]!r} by {', '.join(map(repr, owner))}."
)
if owner not in self._param_owners:
self._param_owners[name].add(owner)
return False
def _join_names(*names: str) -> str:
return "-".join(n.replace("_", "-") for n in names if n)
# --------------------------------------------------------------------------------------
# Session Definitions
# --------------------------------------------------------------------------------------
# Top-level group named "check".
check = SessionGroup("check")
# Group nested under ``check``; its full name is "check-python".
check_python = check.group("python")
@check_python.session
def suite(session: Session, no_pytest_cov: bool = False) -> None:
    """Example session in the ``check-python`` group; prints ``no_pytest_cov``."""
    print(no_pytest_cov)
@check.session
def thing(session: Session, no_pytest_cov: bool = False) -> None:
    """Example session in the ``check`` group; prints ``no_pytest_cov``."""
    print(no_pytest_cov)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment