Created
July 12, 2024 12:50
-
-
Save CGamesPlay/d3f72fca787b16b879efb07f1ba46d7a to your computer and use it in GitHub Desktop.
Fully-typed Python decorator for functions, methods, staticmethods, and classmethods.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Showcase a fully-typed decorator that can be applied to functions, methods,
staticmethods, and classmethods. When combining the decorator with staticmethod
or classmethod, it's important to put the decorator directly above the
staticmethod or classmethod decorator.

This example uses a few heuristics, which are not 100% accurate, to type
staticmethod and classmethod. Specifically, if the first argument to a method
is a type object, the decorator always assumes that the method is a
classmethod; and if the first argument to a method is an instance of the class
it's called on, the decorator always assumes that the method is a regular
method. In both cases, the problem is only with the types; the runtime
behavior is correct.

It's possible to fix the staticmethod issue by simply using the included
staticdecorator. To fix regular functions being interpreted as classmethods,
it's necessary to remove the first overload from the decorator function, and
then force all classmethods to use the included classdecorator.
""" | |
import functools
import types
import unittest
from typing import (
    Any,
    Callable,
    Concatenate,
    Generic,
    Never,
    ParamSpec,
    TypeVar,
    overload,
    reveal_type,
)
P = ParamSpec("P") | |
BoundP = ParamSpec("BoundP") | |
R = TypeVar("R") | |
S = TypeVar("S") | |
class Decorator(Generic[S, P, BoundP, R]): | |
def __init__(self, f: Callable[P, R]): | |
self.f = f | |
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: | |
return self.f(*args, **kwargs) | |
@overload | |
def __get__( | |
self, instance: S, owner: type | |
) -> "Decorator[S, BoundP, BoundP, R]": ... | |
@overload | |
def __get__(self, instance: Any, owner: type) -> "Decorator[S, P, BoundP, R]": ... | |
def __get__(self, instance: Any, owner: Any = None) -> Any: | |
# Overload 1 is for bound methods; overload 2 is for unbound functions. | |
# | |
# We special case support for staticmethod and classmethod here. | |
if isinstance(self.f, staticmethod) or isinstance(self.f, classmethod): | |
return self.f.__get__(instance, owner) | |
if instance is None: | |
return self | |
return Decorator(types.MethodType(self.f, instance)) | |
@overload
def decorator(f: Callable[Concatenate[type[S], P], R]) -> Decorator[Never, P, P, R]: ...
@overload
def decorator(  # pyright: ignore[reportOverlappingOverload]
    f: Callable[Concatenate[S, P], R],
) -> Decorator[S, Concatenate[S, P], P, R]: ...
@overload
def decorator(  # pyright: ignore[reportOverlappingOverload]
    f: Callable[P, R],
) -> Decorator[Never, P, P, R]: ...
def decorator(f: Any) -> Any:
    """Wrap ``f`` in a ``Decorator``; the overloads choose its static type.

    The overloads are tried in order and act as heuristics:

    1. If the first parameter is a ``type``, ``f`` is assumed to be a
       classmethod. The ``cls`` parameter is dropped from the visible
       signature, and ``S`` is ``Never`` so the result is never treated as
       bound again. A plain function whose first parameter is a type object
       is mis-typed by this rule.
    2. Otherwise, if ``f`` has at least one positional parameter, it is
       assumed to be a method whose first parameter is bound when the
       wrapper is accessed through an instance (dot notation).
    3. The fallback covers the remaining callables, such as functions that
       take no parameters.

    At runtime every case is handled identically: ``f`` is simply wrapped.
    """
    wrapper = Decorator(f)
    return wrapper
def staticdecorator(f: Callable[P, R]) -> Decorator[Never, P, P, R]:
    """Wrap ``f`` in a ``Decorator`` that is never treated as bound.

    This is the explicit alternative to ``decorator`` for staticmethods. It
    avoids the heuristic that treats a first parameter of the owning class's
    type as ``self``. Because ``S`` is ``Never``, the wrapper's static type
    keeps ``f``'s full signature whether it is accessed on the class or on an
    instance.
    """
    wrapper: Decorator[Never, P, P, R] = Decorator(f)
    return wrapper
def classdecorator(
    f: Callable[Concatenate[type[S], P], R],
) -> Decorator[Never, P, P, R]:
    """Wrap a classmethod in a ``Decorator`` typed without its ``cls`` parameter.

    This is the explicit alternative to the first (classmethod) overload of
    ``decorator``. The declared return type exposes only ``P``, the parameters
    after ``cls``, with ``S`` set to ``Never`` so the result is never treated
    as bound again.
    """
    # Decorator(f) would be inferred with the full Concatenate[type[S], P]
    # signature; the ignore lets the narrower, already-bound return type stand.
    return Decorator(f)  # type: ignore
@decorator
def func() -> None:
    """Demo: a decorated module-level function that takes no parameters."""
    message = "in func()"
    print(message)
@decorator
def func_param(val: int) -> None:
    """Demo: a decorated module-level function with one ``int`` parameter."""
    message = f"in func_param({val})"
    print(message)
@decorator
def func_typevar(val: type) -> None:
    """Demo of a known mis-typing by ``decorator``.

    Because the first parameter is a ``type``, ``decorator`` selects its
    classmethod overload and drops ``val`` from the static signature. At
    runtime ``val`` must still be passed.
    """
    # Report the argument, consistent with the other demo functions.
    print(f"in func_typevar({val})")
class Class:
    """Demo class applying ``decorator`` to every kind of method.

    For staticmethods and classmethods, ``decorator`` is placed directly above
    ``@staticmethod`` / ``@classmethod`` so that it wraps the descriptor object.
    """

    @decorator
    def method(self) -> None:
        """Regular method with no extra parameters."""
        assert isinstance(self, Class)
        print("in Class.method()")

    @decorator
    def method_param(self, val: int) -> None:
        """Regular method with one extra ``int`` parameter."""
        assert isinstance(self, Class)
        print(f"in Class.method_param({val})")

    @decorator
    @staticmethod
    def static_method() -> None:
        """Static method with no parameters."""
        print("in Class.static_method()")

    @decorator
    @staticmethod
    def static_method_param(val: int) -> None:
        """Static method with one ``int`` parameter."""
        print(f"in Class.static_method_param({val})")

    @decorator
    @staticmethod
    def static_method_typevar(val: "Class") -> None:
        """Static method whose first parameter is a ``Class`` instance.

        Known mis-typing: ``decorator`` assumes this is a regular method, so
        accessing it through an instance is typed as if ``val`` were already
        bound. At runtime ``val`` must still be passed explicitly.
        """
        # Fixed copy-paste error: this previously reported itself as
        # static_method_param.
        print(f"in Class.static_method_typevar({val})")

    @decorator
    @classmethod
    def class_method(cls) -> None:
        """Class method with no extra parameters."""
        print("in Class.class_method()")

    @decorator
    @classmethod
    def class_method_param(cls, val: int) -> None:
        """Class method with one extra ``int`` parameter."""
        print(f"in Class.class_method_param({val})")
class TestCases(unittest.TestCase):
    """Call every decorated demo callable at runtime.

    The ``reveal_type`` calls are aimed at a static type checker. Checking this
    file shows the type inferred for each expression, which demonstrates how
    ``decorator`` types each kind of callable.
    """

    def test_func(self) -> None:
        """Module-level function without parameters."""
        reveal_type(func)
        func()

    def test_func_param(self) -> None:
        """Module-level function with one parameter."""
        reveal_type(func_param)
        func_param(1)

    def test_method(self) -> None:
        """Regular method, called unbound (explicit self) and bound."""
        reveal_type(Class.method)
        reveal_type(Class().method)
        receiver = Class()
        Class.method(receiver)
        Class().method()

    def test_method_param(self) -> None:
        """Regular method with a parameter, called unbound and bound."""
        reveal_type(Class.method_param)
        reveal_type(Class().method_param)
        receiver = Class()
        Class.method_param(receiver, 1)
        Class().method_param(1)

    def test_static_method(self) -> None:
        """Static method, accessed through the class and through an instance."""
        reveal_type(Class.static_method)
        reveal_type(Class().static_method)
        Class.static_method()
        Class().static_method()

    def test_static_method_param(self) -> None:
        """Static method with a parameter, via the class and via an instance."""
        reveal_type(Class.static_method_param)
        reveal_type(Class().static_method_param)
        Class.static_method_param(1)
        Class().static_method_param(1)

    def test_class_method(self) -> None:
        """Class method, accessed through the class and through an instance."""
        reveal_type(Class.class_method)
        reveal_type(Class().class_method)
        Class.class_method()
        Class().class_method()

    def test_class_method_param(self) -> None:
        """Class method with a parameter, via the class and via an instance."""
        reveal_type(Class.class_method_param)
        reveal_type(Class().class_method_param)
        Class.class_method_param(1)
        Class().class_method_param(1)

    def test_typing_failures(self) -> None:
        """Calls that run correctly but that the typing heuristics reject.

        The ``# type: ignore`` comments silence the resulting checker errors.
        """
        # A ``type`` first parameter makes ``decorator`` assume a classmethod.
        reveal_type(func_typevar)
        # A ``Class`` first parameter makes ``decorator`` assume a regular method.
        reveal_type(Class().static_method_typevar)
        func_typevar(int)  # type: ignore
        Class().static_method_typevar(Class())  # type: ignore
if __name__ == "__main__":
    # Executing the file directly runs the demo test suite above.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thought about this a bit more. The limitation I mentioned is unavoidable in general.
I don't think it's possible to correctly statically analyze self-binding without breaking currently valid behaviors in Python. The problem is that the signature of the method changes dynamically at runtime. Consider the following example (code sample in the pyright playground):
This is a runtime TypeError in a fully typed Python program. Which of these lines should the type checker reject?