Skip to content

Instantly share code, notes, and snippets.

@Aldlevine
Last active September 5, 2023 02:36
Show Gist options
  • Save Aldlevine/898a1d04bba27bd7c89371462f2e0132 to your computer and use it in GitHub Desktop.
Save Aldlevine/898a1d04bba27bd7c89371462f2e0132 to your computer and use it in GitHub Desktop.
Python MultiDispatch
from typing import Any
from multi_dispatch import MultiDispatch
# Demo of MultiDispatch on module-level functions.  Every `func` below is an
# overload of one dispatcher; overloads are tried in the order they are
# declared, and each returns a label mirroring its own signature so the demo
# output shows which overload was selected.
multi = MultiDispatch().overload


@multi
def func() -> tuple[str, str]:
    return "EMPTY", ""


@multi
def func(__a: int) -> tuple[str, int]:
    return "__a: int", __a + 10


@multi
def func(__a: str) -> tuple[str, str]:
    return "__a: str", __a * 2


@multi
def func(__a: str, *__extra: None) -> tuple[str, str]:
    # Declared before the `*__extra: Any` overload so that calls whose extra
    # positionals are all None select this one first.
    return "__a: str, *__extra: None", f"{__a}, {__extra}"


@multi
def func(__a: str, *__extra: Any) -> tuple[str, str]:
    return "__a: str, *__extra: Any", f"{__a}, {__extra}"


@multi
def func(__a: str, **kwargs: Any) -> tuple[str, dict[str, Any]]:
    return "__a: str, **kwargs: Any", {"__a": __a, **kwargs}


@multi
def func(__a: int, *, b: bool) -> tuple[str, dict[str, Any]]:
    # Label now matches this overload's actual signature (was "__a: str, b: bool").
    return "__a: int, *, b: bool", {"__a": __a, "b": b}


# fmt: off
print(func())                        # >> ('EMPTY', '')
print(func(10))                      # >> ('__a: int', 20)
print(func("Hello world. "))         # >> ('__a: str', 'Hello world. Hello world. ')
print(func("Hello", world=True))     # >> ('__a: str, **kwargs: Any', {'__a': 'Hello', 'world': True})
print(func("a", "b", "c", "d"))      # >> ('__a: str, *__extra: Any', "a, ('b', 'c', 'd')")
print(func("a", None, None, None))   # >> ('__a: str, *__extra: None', 'a, (None, None, None)')
print(func(1, b=True))               # >> ('__a: int, *, b: bool', {'__a': 1, 'b': True})
# func({"this": "fails"})  # >> TypeError: MultiDispatch func() arguments do not match any overload
# fmt: on
from multi_dispatch import MultiDispatch
# Demo of MultiDispatch on instance methods: `greet` has three overloads that
# share one dispatcher, and richer overloads reuse simpler ones by calling
# `self.greet(...)` again (which re-enters the dispatcher).
multi = MultiDispatch().overload


class Greeter:
    # Overload 1: no arguments besides self.
    @multi
    def greet(self) -> str:
        return "Hello"

    # Overload 2: a name; prefixes the bare salutation.
    @multi
    def greet(self, name: str) -> str:
        salutation = self.greet()
        return "{}, {}.".format(salutation, name)

    # Overload 3: a name and an age; appends a remark to the named greeting.
    @multi
    def greet(self, name: str, age: float) -> str:
        named_greeting = self.greet(name)
        return "{} I hear you are {} years old?".format(named_greeting, age)


inst = Greeter()

# fmt: off
print(inst.greet())                                        # >> Hello
print(inst.greet("Danley"))                                # >> Hello, Danley.
# 175 is an int; the dispatcher accepts ints where `float` is annotated.
print(inst.greet("Staniel", 175))                          # >> Hello, Staniel. I hear you are 175 years old?
print(inst.greet(name="Queztalcoatl", age=float("-inf")))  # >> Hello, Queztalcoatl. I hear you are -inf years old?
# inst.greet(1, 2, 3, 4)  # >> TypeError: MultiDispatch Greeter.greet() arguments do not match any overload
# fmt: on
from functools import wraps
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Callable, cast
__all__ = ["MultiDispatch"]
class MultiDispatch:
__overload_registry: dict[str, list[Callable[..., Any]]]
__dispatch_registry: dict[str, Callable[..., Any]]
def __init__(self) -> None:
self.__overload_registry = {}
self.__dispatch_registry = {}
def __key_name(self, __fn: Callable[..., Any]) -> str:
return f"{__fn.__module__}.{__fn.__qualname__}"
def _is_assignable(self, value: Any, typ: Any) -> bool:
"""Override to provide custom assignability criteria
Args:
value: an argument to check for assignability
typ: a type to check against
Returns:
True if assignable
"""
# Any or no type should match anything
if typ is Any or typ is Parameter.empty:
return True
# Because None is a valid type parameter, but isn't a class
if typ is None:
return value is None
# Because ints should be assignable to floats
if typ is float and isinstance(value, int):
return True
# instances should be assignable to classes they're instances of
try:
if isinstance(value, typ):
return True
except:
...
return False
def __call__(self, __fn: Callable[..., Any], *__args: Any, **__kwargs: Any) -> Any:
"""Iterates through overloads registered to `__fn` and attempts to dispatch.
Args:
__fn: The function to dispatch for.
Raises:
`TypeError`: No valid overload
Returns:
The return value of the dispatched overload.
"""
for fn in self.__overload_registry[self.__key_name(__fn)]:
sig = signature(fn)
try:
bound_args = sig.bind(*__args, **__kwargs)
# type check args
params_list = list(sig.parameters.values())
param_index = 0
for arg in __args:
if param_index >= len(params_list):
raise TypeError()
param = params_list[param_index]
if param.kind in (param.KEYWORD_ONLY, param.VAR_KEYWORD):
raise TypeError()
if not self._is_assignable(arg, param.annotation):
raise TypeError()
if not param.kind == param.VAR_POSITIONAL:
param_index += 1
# do I have **kwargs?
has_var_kwarg: bool = False
var_kwarg_type: type | None = None
for param in params_list:
if param.kind == param.VAR_KEYWORD:
has_var_kwarg = True
var_kwarg_type = param.annotation
# type check kwargs
params = sig.parameters
for kw, arg in __kwargs.items():
if kw in params:
if not self._is_assignable(arg, params[kw].annotation):
raise TypeError()
continue
if has_var_kwarg:
if not self._is_assignable(arg, var_kwarg_type):
raise TypeError()
continue
raise TypeError()
# success!
return fn(*bound_args.args, **bound_args.kwargs)
except TypeError:
...
raise TypeError(
f"MultiDispatch {__fn.__qualname__}() arguments do not match any overload"
)
def register[**TT, T](self, fn: Callable[TT, T]) -> Callable[TT, T]:
"""Register an overload with the dispatcher
Args:
fn: The overload to register
Returns:
The dispatcher for `fn`
"""
key_name = self.__key_name(fn)
self.__overload_registry.setdefault(key_name, []).append(fn)
result = self.__dispatch_registry.get(key_name)
if result is None:
@wraps(fn)
def __dispatch(*args: TT.args, **kwargs: TT.kwargs) -> T:
return self(fn, *args, **kwargs)
result = self.__dispatch_registry[key_name] = __dispatch
return cast(Callable[TT, Any], result)
@property
def overload(self):
"""The overload decorator. When type checking this is `typing.overload`.
At runtime this is `self.register`. This allows type checkers to treat it
as the standard `typing.overload`.
It's not perfect as declarations will complain
"marked as overload but includes implementation" and
"marked as overload but no implementation provided".
It works correctly at runtime, and call site type checking
(at least with pyright/pylance) works.
Returns:
The overload decorator.
Intentionally untyped to force type checkers to infer.
"""
if not TYPE_CHECKING:
return self.register
from typing import overload
return overload
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment