Created September 5, 2020 01:03
-
-
Save fakufaku/95be06e60b0e418c2cefebd21d30c848 to your computer and use it in GitHub Desktop.
Type dispatch in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import inspect
from typing import Callable, Dict, List
def match_signature(signature, args, kwargs):
    """Return True if a call with ``args``/``kwargs`` fits ``signature``'s annotations.

    The arguments are first bound to the function's parameters exactly as a
    real call would bind them (``inspect.Signature.bind``). Surplus
    positional arguments, unknown or duplicate keywords, and missing
    required parameters therefore all make the match fail.

    Every bound argument whose parameter carries an annotation must then
    have exactly that type (``type(value) is annotation``, so a subclass
    such as ``bool`` does not match ``int``). For ``*args``/``**kwargs``
    parameters, each collected value is checked against the annotation.
    Unannotated parameters accept any value, parameters left to their
    defaults are not checked, and the ``return`` annotation is ignored.

    NOTE: annotations are compared as stored. String annotations (e.g.
    under ``from __future__ import annotations``) are not resolved and
    will never match.

    Args:
        signature: the candidate function whose parameters are checked.
        args: positional arguments of the attempted call (any sequence).
        kwargs: keyword arguments of the attempted call (a mapping).

    Returns:
        True if the call can be dispatched to ``signature``, otherwise False.
    """
    sig = inspect.signature(signature)
    try:
        bound = sig.bind(*args, **kwargs)
    except TypeError:
        # Wrong arity or unknown/duplicate keyword: this overload cannot
        # accept the call at all.
        return False

    annotations = getattr(signature, "__annotations__", {})
    for name, value in bound.arguments.items():
        if name not in annotations:
            continue  # unannotated parameter: accepts anything
        expected = annotations[name]
        kind = sig.parameters[name].kind
        if kind is inspect.Parameter.VAR_POSITIONAL:
            values = value  # tuple of the extra positional arguments
        elif kind is inspect.Parameter.VAR_KEYWORD:
            values = value.values()  # dict of the extra keyword arguments
        else:
            values = (value,)
        if any(type(v) is not expected for v in values):
            return False
    return True
class typedispatch:
    """Decorator that overloads a function on the runtime types of its arguments.

    Every function decorated with ``@typedispatch`` is appended to a
    registry shared by all instances (``typedispatch.functions``). Calling
    the decorated name tries the registered overloads in registration
    order. It invokes the first one for which ``match_signature`` accepts
    the arguments, and raises ``ValueError`` if none does.

    Overloads are grouped under "<module>.<qualname>". Same-named functions
    in different modules or classes therefore keep separate overload lists.
    The instance copies the wrapped function's metadata (``__name__``,
    ``__doc__``, ``__wrapped__``, ...) and binds like a plain function when
    used on a method.
    """

    # Class-wide registry: "<module>.<qualname>" -> overloads in registration order.
    functions: Dict[str, List[Callable]] = {}

    def __init__(self, func):
        functools.update_wrapper(self, func)
        # Bare function name, kept as a public attribute for existing users.
        self.name = func.__name__
        # Registry key. It is fully qualified so that e.g. `f` in two modules,
        # or `A.run` vs `B.run`, are dispatched independently.
        self.key = f"{func.__module__}.{func.__qualname__}"
        self.functions.setdefault(self.key, []).append(func)

    def __get__(self, instance, owner=None):
        # Make the decorator usable on methods: when accessed through an
        # instance, pass that instance as the first positional argument.
        if instance is None:
            return self
        return functools.partial(self, instance)

    def __call__(self, *args, **kwargs):
        for func in self.functions[self.key]:
            if match_signature(func, args, kwargs):
                return func(*args, **kwargs)
        raise ValueError("Matching type signature not found")
if __name__ == "__main__":
    # Two overloads of `f`, selected by the runtime types of the arguments.
    @typedispatch
    def f(x: int, y: int) -> int:
        return x + y

    @typedispatch
    def f(name: str, age: int) -> str:
        return name + f" ({age} y.o.)"

    # Two overloads of `sub`: integer subtraction vs. string concatenation.
    @typedispatch
    def sub(x: int, y: int) -> int:
        return x - y

    @typedispatch
    def sub(a: str, b: str) -> str:
        return a + b

    # Exercise the dispatcher so running the script shows which overload ran.
    print(f(1, 2))  # 3
    print(f("Alice", 30))  # Alice (30 y.o.)
    print(f(name="Bob", age=42))  # Bob (42 y.o.)
    print(sub(5, 3))  # 2
    print(sub("foo", "bar"))  # foobar

    # No overload takes (float, int): the dispatcher raises ValueError.
    try:
        f(1.5, 2)
    except ValueError as err:
        print("no overload for (float, int):", err)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment