Last active
September 5, 2023 02:36
-
-
Save Aldlevine/898a1d04bba27bd7c89371462f2e0132 to your computer and use it in GitHub Desktop.
Python MultiDispatch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any | |
from multi_dispatch import MultiDispatch | |
multi = MultiDispatch().overload

# Overloads are tried in registration (top-to-bottom) order and the first one
# whose signature binds the call and whose annotations accept every argument
# wins. Order therefore matters: the `*__extra: None` overload must precede the
# more general `*__extra: Any` one, or it could never be selected.
# Each overload returns a label naming its own signature, so the printed
# output shows which overload was dispatched.


@multi
def func() -> tuple[str, str]:
    return "EMPTY", ""


@multi
def func(__a: int) -> tuple[str, int]:
    return "__a: int", __a + 10


@multi
def func(__a: str) -> tuple[str, str]:
    return "__a: str", __a * 2


@multi
def func(__a: str, *__extra: None) -> tuple[str, str]:
    return "__a: str, *__extra: None", f"{__a}, {__extra}"


@multi
def func(__a: str, *__extra: Any) -> tuple[str, str]:
    return "__a: str, *__extra: Any", f"{__a}, {__extra}"


@multi
def func(__a: str, **kwargs: Any) -> tuple[str, dict[str, Any]]:
    return "__a: str, **kwargs: Any", {"__a": __a, **kwargs}


@multi
def func(__a: int, *, b: bool) -> tuple[str, dict[str, Any]]:
    # Label now matches the actual signature (it previously said `__a: str`).
    return "__a: int, *, b: bool", {"__a": __a, "b": b}


# fmt: off
print(func())                        # >> ('EMPTY', '')
print(func(10))                      # >> ('__a: int', 20)
print(func("Hello world. "))         # >> ('__a: str', 'Hello world. Hello world. ')
print(func("Hello", world=True))     # >> ('__a: str, **kwargs: Any', {'__a': 'Hello', 'world': True})
print(func("a", "b", "c", "d"))      # >> ('__a: str, *__extra: Any', "a, ('b', 'c', 'd')")
print(func("a", None, None, None))   # >> ('__a: str, *__extra: None', 'a, (None, None, None)')
print(func(1, b=True))               # >> ('__a: int, *, b: bool', {'__a': 1, 'b': True})
# func({"this": "fails"})            # >> TypeError: MultiDispatch func() arguments do not match any overload
# fmt: on
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from multi_dispatch import MultiDispatch | |
multi = MultiDispatch().overload


class Greeter:
    """Demo class whose ``greet`` method is overloaded purely by arity."""

    @multi
    def greet(self) -> str:
        """Return the bare salutation."""
        salutation = "Hello"
        return salutation

    @multi
    def greet(self, name: str) -> str:
        """Greet ``name``, reusing the zero-argument overload as the prefix."""
        prefix = self.greet()
        return f"{prefix}, {name}."

    @multi
    def greet(self, name: str, age: float) -> str:
        """Greet ``name`` and remark on ``age``.

        The dispatcher treats an ``int`` as assignable to ``float``, so whole
        numbers such as ``175`` also select this overload.
        """
        opening = self.greet(name)
        return f"{opening} I hear you are {age} years old?"
greeter = Greeter()

# Each call is routed to the `greet` overload whose signature accepts the
# supplied positional/keyword arguments.
# fmt: off
print(greeter.greet())                                        # >> Hello
print(greeter.greet("Danley"))                                # >> Hello, Danley.
print(greeter.greet("Staniel", 175))                          # >> Hello, Staniel. I hear you are 175 years old?
print(greeter.greet(name="Queztalcoatl", age=float("-inf")))  # >> Hello, Queztalcoatl. I hear you are -inf years old?
# greeter.greet(1, 2, 3, 4)                                   # >> TypeError: MultiDispatch Greeter.greet() arguments do not match any overload
# fmt: on
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import wraps
from inspect import BoundArguments, Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable, cast
__all__ = ["MultiDispatch"] | |
class MultiDispatch: | |
__overload_registry: dict[str, list[Callable[..., Any]]] | |
__dispatch_registry: dict[str, Callable[..., Any]] | |
def __init__(self) -> None: | |
self.__overload_registry = {} | |
self.__dispatch_registry = {} | |
def __key_name(self, __fn: Callable[..., Any]) -> str: | |
return f"{__fn.__module__}.{__fn.__qualname__}" | |
def _is_assignable(self, value: Any, typ: Any) -> bool: | |
"""Override to provide custom assignability criteria | |
Args: | |
value: an argument to check for assignability | |
typ: a type to check against | |
Returns: | |
True if assignable | |
""" | |
# Any or no type should match anything | |
if typ is Any or typ is Parameter.empty: | |
return True | |
# Because None is a valid type parameter, but isn't a class | |
if typ is None: | |
return value is None | |
# Because ints should be assignable to floats | |
if typ is float and isinstance(value, int): | |
return True | |
# instances should be assignable to classes they're instances of | |
try: | |
if isinstance(value, typ): | |
return True | |
except: | |
... | |
return False | |
def __call__(self, __fn: Callable[..., Any], *__args: Any, **__kwargs: Any) -> Any: | |
"""Iterates through overloads registered to `__fn` and attempts to dispatch. | |
Args: | |
__fn: The function to dispatch for. | |
Raises: | |
`TypeError`: No valid overload | |
Returns: | |
The return value of the dispatched overload. | |
""" | |
for fn in self.__overload_registry[self.__key_name(__fn)]: | |
sig = signature(fn) | |
try: | |
bound_args = sig.bind(*__args, **__kwargs) | |
# type check args | |
params_list = list(sig.parameters.values()) | |
param_index = 0 | |
for arg in __args: | |
if param_index >= len(params_list): | |
raise TypeError() | |
param = params_list[param_index] | |
if param.kind in (param.KEYWORD_ONLY, param.VAR_KEYWORD): | |
raise TypeError() | |
if not self._is_assignable(arg, param.annotation): | |
raise TypeError() | |
if not param.kind == param.VAR_POSITIONAL: | |
param_index += 1 | |
# do I have **kwargs? | |
has_var_kwarg: bool = False | |
var_kwarg_type: type | None = None | |
for param in params_list: | |
if param.kind == param.VAR_KEYWORD: | |
has_var_kwarg = True | |
var_kwarg_type = param.annotation | |
# type check kwargs | |
params = sig.parameters | |
for kw, arg in __kwargs.items(): | |
if kw in params: | |
if not self._is_assignable(arg, params[kw].annotation): | |
raise TypeError() | |
continue | |
if has_var_kwarg: | |
if not self._is_assignable(arg, var_kwarg_type): | |
raise TypeError() | |
continue | |
raise TypeError() | |
# success! | |
return fn(*bound_args.args, **bound_args.kwargs) | |
except TypeError: | |
... | |
raise TypeError( | |
f"MultiDispatch {__fn.__qualname__}() arguments do not match any overload" | |
) | |
def register[**TT, T](self, fn: Callable[TT, T]) -> Callable[TT, T]: | |
"""Register an overload with the dispatcher | |
Args: | |
fn: The overload to register | |
Returns: | |
The dispatcher for `fn` | |
""" | |
key_name = self.__key_name(fn) | |
self.__overload_registry.setdefault(key_name, []).append(fn) | |
result = self.__dispatch_registry.get(key_name) | |
if result is None: | |
@wraps(fn) | |
def __dispatch(*args: TT.args, **kwargs: TT.kwargs) -> T: | |
return self(fn, *args, **kwargs) | |
result = self.__dispatch_registry[key_name] = __dispatch | |
return cast(Callable[TT, Any], result) | |
@property | |
def overload(self): | |
"""The overload decorator. When type checking this is `typing.overload`. | |
At runtime this is `self.register`. This allows type checkers to treat it | |
as the standard `typing.overload`. | |
It's not perfect as declarations will complain | |
"marked as overload but includes implementation" and | |
"marked as overload but no implementation provided". | |
It works correctly at runtime, and call site type checking | |
(at least with pyright/pylance) works. | |
Returns: | |
The overload decorator. | |
Intentionally untyped to force type checkers to infer. | |
""" | |
if not TYPE_CHECKING: | |
return self.register | |
from typing import overload | |
return overload |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment