Skip to content

Instantly share code, notes, and snippets.

@ItsDrike
Last active October 4, 2021 06:50
Show Gist options
  • Save ItsDrike/afc168f5fa592cb50c82a8ade7218c7f to your computer and use it in GitHub Desktop.
Save ItsDrike/afc168f5fa592cb50c82a8ade7218c7f to your computer and use it in GitHub Desktop.
Python Function Overloads
#!/usr/bin/env python3
import inspect
import types
from typing import Any, Callable, Hashable, Optional, TypeVar, Union, get_args, get_origin
# Sentinel marking "no value present". A dedicated object is used instead of
# None because None can be a legitimate stored value; always compare with `is`.
_MISSING = object()
# Callable-bound type variable used to annotate the @overload decorator, so
# static type checkers see it returning the same function type it was given.
T = TypeVar("T", bound=Callable)
class NoMatchingOverload(TypeError):
    """
    Raised when an overloaded method is called with arguments that do
    not fit any of the overloads registered for it.
    """
class MultipleMatchingOverloads(TypeError):
    """
    Raised when the arguments of a call to an overloaded method fit more
    than one of its overloads, so no single overload can be selected.
    """
class OverloadList(list):
    """
    Dedicated list type that holds the overloads collected for one name.

    It adds nothing to ``list``; it exists only so that a collection of
    overloads can be told apart from an ordinary list value by type.
    """
class OverloadDict(dict):
    """
    Dictionary that can keep several values under a single key, as long
    as every one of them has a truthy ``__overload__`` attribute.

    Such values are gathered into an OverloadList stored under the key;
    all other values behave exactly like in a regular dict.
    """

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """
        Store ``value`` under ``key``, collecting overload-marked values.

        - New key: an overload-marked value is wrapped in a fresh
          OverloadList, any other value is stored as-is.
        - Key holding an OverloadList: an overload-marked value is
          appended, any other value raises ValueError.
        - Key holding a regular value: a regular value replaces it, an
          overload-marked value raises ValueError.
        """
        existing = self.get(key, _MISSING)
        marked = getattr(value, "__overload__", False)

        # First assignment to this key
        if existing is _MISSING:
            if marked:
                super().__setitem__(key, OverloadList([value]))
            else:
                super().__setitem__(key, value)
            return

        # The key already collects overloads
        if isinstance(existing, OverloadList):
            if marked:
                existing.append(value)
                return
            raise ValueError(
                "Can't override existing overloaded value with "
                "non-overloaded value (forgot @overload?)"
            )

        # The key holds a regular (non-overloaded) value
        if marked:
            raise ValueError(
                "Can't set override value for a key which already "
                "contains non-overloaded value (forgot @overload?)"
            )
        super().__setitem__(key, value)
class BoundOverloadDispatcher:
    """
    Callable object returned in place of an overloaded method when it is
    accessed on an instance. When called, it decides which overload to
    use based on the arguments of that call.
    """

    def __init__(self, instance: object, owner_cls: type[object], name: str, overload_list: "OverloadList"):
        """
        Capture the call context and precompute the overload signatures.

        :param instance: Object the overloaded method was accessed on.
        :param owner_cls: Class used as the starting point of the super()
            fallback lookup in __call__.
        :param name: Attribute name of the overloaded method.
        :param overload_list: The overload functions to choose from.
        """
        self.instance = instance
        self.owner_cls = owner_cls
        self.name = name
        self.overload_list = overload_list
        self.signatures = [inspect.signature(f) for f in overload_list]

    def __call__(self, *args, **kwargs):
        """
        Call the overload whose signature matches the passed arguments.

        - If several overloads match, MultipleMatchingOverloads is raised.
        - If none match, the attribute with the same name is looked up
          through super() (next in line after the owner class) and called
          instead; if there is no such attribute, NoMatchingOverload is
          raised.
        """
        try:
            f = self.best_match(*args, **kwargs)
        except NoMatchingOverload:
            pass
        else:
            # Invoked outside the try body, so a NoMatchingOverload raised
            # by the overload itself is not mistaken for a dispatch miss.
            return f(self.instance, *args, **kwargs)

        # No matching overload was found in the owner class,
        # try to check the next in line
        super_instance = super(self.owner_cls, self.instance)
        super_call = getattr(super_instance, self.name, _MISSING)
        if super_call is _MISSING:
            raise NoMatchingOverload(
                f"No overload of {self.name!r} matches the given arguments"
            )
        return super_call(*args, **kwargs)  # type: ignore

    def best_match(self, *args, **kwargs):
        """
        Return the single overload that accepts the given arguments.

        An overload is a candidate when the arguments (with the captured
        instance prepended) bind to its signature and, after defaults are
        applied, every bound value satisfies the parameter's type hint.

        :raises NoMatchingOverload: If no overload matches.
        :raises MultipleMatchingOverloads: If more than one overload matches.
        """
        matching_functions = []
        for f, sig in zip(self.overload_list, self.signatures):
            try:
                bound_args = sig.bind(self.instance, *args, **kwargs)
            except TypeError:
                continue  # missing/extra/unexpected args or kwargs
            bound_args.apply_defaults()
            if self._signature_matches(sig, bound_args):
                matching_functions.append(f)

        if not matching_functions:
            raise NoMatchingOverload(
                f"No overload of {self.name!r} matches the given arguments"
            )
        if len(matching_functions) > 1:
            raise MultipleMatchingOverloads(
                f"{len(matching_functions)} overloads of {self.name!r} "
                "match the given arguments"
            )
        return matching_functions[0]

    @staticmethod
    def _type_hint_matches(obj, hint):
        """
        Check if the object satisfies the given type hint.

        Supported hints: missing annotation and ``Any`` (always match),
        ``None``, plain classes, ``Union``/``Optional``/``X | Y`` (any
        member may match) and subscripted generics such as ``list[int]``
        (only the container type is checked, not its items).

        Hints that can't be checked with isinstance (for example string
        annotations or type variables) are treated as matching, because
        there is nothing reliable to reject the value on.
        """
        if hint is inspect.Parameter.empty or hint is Any:
            return True
        if hint is None:
            return obj is None

        origin = get_origin(hint)
        # types.UnionType (PEP 604 `X | Y`) only exists on Python 3.10+,
        # fall back to Union so the comparison is a no-op on older versions
        if origin is Union or origin is getattr(types, "UnionType", Union):
            return any(
                BoundOverloadDispatcher._type_hint_matches(obj, member)
                for member in get_args(hint)
            )
        if origin is not None:
            # isinstance refuses parameterized generics, check the bare type
            hint = origin

        try:
            return isinstance(obj, hint)
        except TypeError:
            return True

    @classmethod
    def _signature_matches(cls, sig: inspect.Signature, bound_args: inspect.BoundArguments) -> bool:
        """
        Check if all of the type hints of the signature match the bound arguments.

        For ``*args`` and ``**kwargs`` parameters the annotation describes
        each individual value, so every collected value is checked
        separately instead of the packed tuple/dict.
        """
        for name, arg in bound_args.arguments.items():
            param = sig.parameters[name]
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                values = arg
            elif param.kind is inspect.Parameter.VAR_KEYWORD:
                values = arg.values()
            else:
                values = (arg,)
            if not all(cls._type_hint_matches(value, param.annotation) for value in values):
                return False
        return True
class OverloadDescriptor:
    """
    A descriptor in place of the overloaded methods that is initialized
    from the metaclass with the list of all overloads for given function.
    Once we try to access this overloaded function, we return a
    BoundOverloadDispatcher that will decide which of the overloads should
    be picked.

    We're using a descriptor here to be able to capture the instance along
    with the attempt to access the overloaded function. This is important
    because we then use this instance in the BoundOverloadDispatcher when
    we don't find the correct overloads to check for them with super().
    """

    def __init__(self, overload_list: "OverloadList") -> None:
        """
        The descriptor is initialized from the metaclass and receives
        a list of all overload functions.

        :raises TypeError: If overload_list is not an OverloadList.
        :raises ValueError: If overload_list is empty.
        """
        if not isinstance(overload_list, OverloadList):
            raise TypeError("Must use OverloadList.")
        if not overload_list:
            raise ValueError("The overload_list can't be empty")
        self.overload_list = overload_list

    def __set_name__(self, owner: type[object], name: str) -> None:
        """
        The descriptor protocol adds this method to make it simple
        to obtain the name of the attribute set to this descriptor.
        We also remember the owning class, it is later handed to the
        BoundOverloadDispatcher together with the name.
        """
        self.owner = owner
        self.name = name

    def __repr__(self):
        """This will be the repr of all overloaded functions."""
        return f"{self.__class__.__qualname__}({self.overload_list!r})"

    def __get__(self, instance: object, owner: Optional[type[object]] = None):
        """
        This method gets called whenever the overloaded method is accessed.

        This mimics the default python behavior where accessing class.function
        would give you the function object, but accessing instance.function
        will give you a bound method that stores the function object and
        auto-passes the self argument once it's called. In our case, accessing
        the overloaded function from a class gives you this descriptor and
        accessing it from an instance will return a BoundOverloadDispatcher.
        """
        # If the descriptor is accessed from the class directly, rather than
        # from an initialized object, we return this descriptor (self)
        if instance is None:
            return self

        # The class recorded in __set_name__ is passed on (not the `owner`
        # argument), it is the class that this descriptor was defined in.
        # TODO: Consider using a dict cache with a composite hash of the
        # values passed into the initialization of BoundOverloadDispatcher
        # to avoid having to initialize this class every time it is accessed
        return BoundOverloadDispatcher(
            instance, self.owner,
            self.name, self.overload_list
        )
class OverloadMeta(type):
    """A metaclass that allows a class to have overloads for its methods."""

    @classmethod
    def __prepare__(mcls, name: str, bases: tuple[type, ...], **kwargs) -> OverloadDict:
        """
        This is the method which returns the default empty dictionary
        which will then be used for running exec on the class body as
        'locals' dict.

        We override this method to return our custom dictionary, that
        will be able to support setting multiple functions with the
        same name if they have function.__overload__ set to True.

        Keyword arguments given in the class statement are passed here
        as well as to __new__, so they have to be accepted (and are
        ignored), otherwise defining such a class would fail.
        """
        return OverloadDict()

    def __new__(mcls, name: str, bases, namespace: OverloadDict, **kwargs):  # type: ignore
        """
        Override the class creation and change all captured overload lists
        in the given namespace from __prepare__ to OverloadDescriptors,
        which when accessed will return a BoundOverloadDispatcher that is
        able to figure out which overload to use depending on the passed
        call arguments and the signatures of the overloaded functions.
        """
        overload_namespace = {
            key: OverloadDescriptor(val) if isinstance(val, OverloadList) else val
            for key, val in namespace.items()
        }
        return super().__new__(mcls, name, bases, overload_namespace, **kwargs)
def overload(f: T) -> T:
    """
    Mark a function as one of several overloads sharing the same name.

    Inside a class created with the OverloadMeta metaclass, methods marked
    with this decorator may be defined repeatedly under one name, with a
    different number of parameters or different type hints for them. When
    the method is called, the overload fitting the passed arguments is
    picked.

    The decorator has no dispatch logic of its own, it only sets the
    ``__overload__`` flag on the function and returns the same function.
    Everything else is handled by the metaclass, which treats flagged
    functions differently from regular methods.
    """
    setattr(f, "__overload__", True)
    return f
class Overloadable(metaclass=OverloadMeta):
    """
    Convenience base class: inheriting from it gives a class the
    OverloadMeta metaclass, and with it support for method overloads,
    without having to spell out the metaclass.
    """
class Example(Overloadable):
    """Demonstration class with three overloads of `bar` and a regular method."""

    @overload
    def bar(self, x: int):  # type: ignore # noqa: F811
        """Overload of `bar` taking a single int."""
        print(f"Called Example.bar int overload: {x=!r}")

    @overload
    def bar(self, x: str):  # type: ignore # noqa: F811
        """Overload of `bar` taking a single str."""
        print(f"Called Example.bar str overload: {x=!r}")

    @overload
    def bar(self, x: int, y: int):  # type: ignore # noqa: F811
        """Overload of `bar` taking two ints."""
        print(f"Called Example.bar two argument overload: {x=!r} {y=!r}")

    def foobar(self, x: str):
        """Regular method without overloads, defined next to the overloaded one."""
        print(f"Called Example.foobar regular method: {x=!r}")
if __name__ == "__main__":
    # Small demonstration, run only when the file is executed as a script
    foo = Example()

    # One call per overload of `bar`
    foo.bar(1)  # type: ignore
    foo.bar("hi")  # type: ignore
    foo.bar(1, 8)  # type: ignore

    # A regular method defined in the same class
    foo.foobar("hello")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment