Created
August 13, 2024 11:51
-
-
Save lunaluxie/dbcb9da392e18679fd3852ef34579722 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import wraps | |
import inspect | |
from typing import Callable | |
from trycast import isassignable | |
def overload(function) -> Callable:
    """Combine same-named functions with different signatures into one dispatcher.

    Every time a function is decorated, the overloads already registered
    under the same name (in the namespace where the function is defined) are
    carried over. The returned wrapper tries each overload, newest first, and
    calls the first one whose parameters match the supplied arguments.

    Matching is deliberately strict: the call must supply exactly one value
    per declared parameter (default values are not taken into account), and
    every annotated parameter is type-checked with ``isassignable``.
    Un-annotated parameters accept any value.

    Args:
        function (Callable): the function being decorated

    Raises:
        TypeError: If none of the overloads matched the given arguments when
            calling the function

    Returns:
        Callable: The overloaded function
    """
    verbose = False  # you can change this for debugging.

    # Save the functions to a list (newest first) so that re-using the name
    # does not lose the earlier definitions and dispatch can reach them.
    function.overloads = [function]

    # Look the previous definition up in the namespace the decorated function
    # lives in. Using this module's own globals() only works while decorator
    # and decorated functions share a module, and silently drops the earlier
    # overloads as soon as ``overload`` is imported from somewhere else.
    namespace = getattr(function, "__globals__", None)
    if namespace is None:
        # Callables without __globals__ fall back to the historical lookup.
        namespace = globals()
    old_function = namespace.get(function.__name__)
    function.overloads.extend(getattr(old_function, "overloads", ()))

    # matching:
    # step 1 - check if the number of arguments match
    # step 2 - pair the given parameters to the function parameters
    # step 3 - check type of each pair.
    def match(params, args, kwargs, verbose=False) -> bool:
        """Return True when ``args``/``kwargs`` fit ``params`` by count, name and type."""
        supplied = len(args) + len(kwargs)
        if len(params) != supplied:
            if verbose:
                print(
                    f"Expected {len(params)} arguments but got {supplied}"
                )
            return False

        # Parameters are removed from this copy as soon as they are paired.
        unmatched = dict(params)

        # Positional arguments pair with the leading parameters, in order.
        # The length check above guarantees args is never longer than params.
        for value, (name, param) in zip(args, params.items()):
            del unmatched[name]
            if param.annotation is inspect.Parameter.empty:
                # No annotation: nothing to check, assume the type is fine.
                continue
            if not isassignable(value, param.annotation):
                if verbose:
                    print(f"Expected {param.annotation} but got {type(value)}")
                return False

        # Keyword arguments must name one of the parameters still unpaired.
        for name, value in kwargs.items():
            param = unmatched.pop(name, None)
            if param is None:
                if verbose:
                    print(f"Unexpected keyword argument {name}")
                return False
            if param.annotation is inspect.Parameter.empty:
                continue
            if not isassignable(value, param.annotation):
                if verbose:
                    # Report the type of the offending value, not of its name.
                    print(f"Expected {param.annotation} but got {type(value)}")
                return False

        if unmatched:
            if verbose:
                print("Still remaining params, ", unmatched)
            return False
        # Only a call that paired every parameter is a match.
        return True

    # The wraps decorator copies the name, docstring and other special
    # attributes (including __dict__) of the decorated function.
    @wraps(function)
    def wrapper(*args, **kwargs):
        if verbose:
            print("Dispatching to one of ", function.overloads)

        # Try each candidate in turn and eagerly call the first one that
        # accepts the arguments. Signatures are inspected lazily so nothing
        # is computed for candidates that come after the match.
        for i, candidate in enumerate(function.overloads):
            params = inspect.signature(candidate).parameters
            if verbose:
                print(f"Trying overload ({i}) {candidate} with params {params}")
            if match(params, args, kwargs, verbose=verbose):
                return candidate(*args, **kwargs)

        raise TypeError("None of the overloads matched the given arguments")

    # Expose the overload list on the wrapper explicitly: the next decoration
    # of the same name reads it from there.
    wrapper.overloads = function.overloads
    return wrapper
class Email:
    """Wrapper around an e-mail address string.

    The address is stored unchanged on ``email`` (no validation is done) and
    ``str()`` of the instance returns it.
    """

    def __init__(self, email: str) -> None:
        self.email = email

    def __str__(self) -> str:
        address = self.email
        return address
class PhoneNumber:
    """Wrapper around a phone number string.

    The number is stored unchanged on ``phone_number`` (no validation is
    done) and ``str()`` of the instance returns it.
    """

    def __init__(self, phone_number: str) -> None:
        self.phone_number = phone_number

    def __str__(self) -> str:
        number = self.phone_number
        return number
class SSN:
    """Wrapper around a social security number string.

    The number is stored unchanged on ``ssn`` (no validation is done) and
    ``str()`` of the instance returns it.
    """

    def __init__(self, ssn: str) -> None:
        self.ssn = ssn

    def __str__(self) -> str:
        number = self.ssn
        return number
class PrimaryKey:
    """Wrapper around an integer primary key.

    The key is stored unchanged on ``pk``; ``str()`` of the instance gives
    the key converted to text.
    """

    def __init__(self, pk: int) -> None:
        self.pk = pk

    def __str__(self) -> str:
        key = self.pk
        return str(key)
@overload
def get_user(email: Email):
    """``get_user`` overload taking an ``Email``: print it and hand it back."""
    print(f"Email: {email}")
    return email
@overload
def get_user(phone_number: PhoneNumber):
    """``get_user`` overload taking a ``PhoneNumber``: print it and hand it back."""
    print(f"Phone: {phone_number}")
    return phone_number
@overload
def get_user(ssn: SSN):
    """``get_user`` overload taking an ``SSN``: print it and hand it back."""
    print(f"SSN: {ssn}")
    return ssn
class PositiveInteger():
    """A number wrapper whose value must be strictly greater than zero.

    Subtraction produces a new ``PositiveInteger`` and therefore raises
    ``ValueError`` whenever the result would be zero or negative. Operands
    may be other ``PositiveInteger`` instances (anything exposing ``value``)
    or plain ``int`` objects.
    """

    def __init__(self, value: int):
        """Store ``value``.

        Raises:
            ValueError: if ``value`` is zero or negative.
        """
        if value <= 0:
            raise ValueError("Value must be positive.")
        self.value = value

    @staticmethod
    def _value_of(operand):
        """Return the plain number behind ``operand``, or ``NotImplemented``."""
        if isinstance(operand, int):
            return operand
        return getattr(operand, "value", NotImplemented)

    def __get__(self, instance=None, owner=None):
        # Defining __get__ makes instances descriptors, and Python calls it
        # with (instance, owner) when one is stored as a class attribute.
        # Both arguments are optional so the historical zero-argument call
        # ``obj.__get__()`` keeps working and attribute access no longer
        # fails with a TypeError.
        return self.value

    def __repr__(self) -> str:
        return f"PositiveInteger({self.value!r})"

    def __str__(self) -> str:
        return f"{self.value}"

    def __sub__(self, other):
        amount = self._value_of(other)
        if amount is NotImplemented:
            # Let Python try the reflected operation / raise TypeError.
            return NotImplemented
        return PositiveInteger(self.value - amount)

    def __rsub__(self, other):
        # Reached for ``plain_number - PositiveInteger``.
        return PositiveInteger(other - self.value)

    def __gt__(self, other):
        amount = self._value_of(other)
        if amount is NotImplemented:
            return NotImplemented
        return self.value > amount

    def __lt__(self, other):
        amount = self._value_of(other)
        if amount is NotImplemented:
            return NotImplemented
        return self.value < amount
class Character():
    """A character whose ``health`` and ``max_health`` both start at 100."""

    def __init__(self):
        starting_points = 100
        self.health = PositiveInteger(starting_points)
        self.max_health = PositiveInteger(starting_points)

    def __str__(self) -> str:
        # Every character is displayed with the same fixed label.
        return "Character"
@overload
def attack(target: Character, power: PositiveInteger):
    """Unblocked ``attack`` overload: take ``power`` off the target's health and report it."""
    target.health -= power
    print(
        f"Attacked {target} for {power} damage.",
        f" New health: {target.health}/{target.max_health}.",
        sep="\n",
    )
@overload
def attack(target: Character, power: PositiveInteger, blocked: PositiveInteger):
    """Blocked ``attack`` overload: only the part of ``power`` above ``blocked`` lands."""
    # Guard clause: a block at least as strong as the attack absorbs it all.
    if not power > blocked:
        print("Attack was blocked")
        return

    damage = power - blocked
    target.health -= damage
    print(f"Attacked {target} for {damage} damage (blocked {blocked}).")
    print(f" New health: {target.health}/{target.max_health}.")
if __name__ == "__main__":
    # Demo 1: the number of arguments decides which ``attack`` overload runs.
    target = Character()

    attack(target, power=PositiveInteger(10))
    # Attacked Character for 10 damage.
    # New health: 90/100.

    attack(target, power=PositiveInteger(10), blocked=PositiveInteger(8))
    # Attacked Character for 2 damage (blocked 8).
    # New health: 88/100.

    attack(target, power=PositiveInteger(10), blocked=PositiveInteger(11))
    # Attack was blocked

    # Demo 2: the argument's type decides which ``get_user`` overload runs.
    # prints, in order:
    #   Email: test@example.com
    #   Phone: 123-456-789
    #   SSN: 123-45-6789
    for identifier in (
        Email("test@example.com"),
        PhoneNumber("123-456-789"),
        SSN("123-45-6789"),
    ):
        get_user(identifier)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment