Skip to content

Instantly share code, notes, and snippets.

@LiutongZhou
Last active May 13, 2024 15:23
Show Gist options
  • Save LiutongZhou/1d9f8b1231acd8173b30db3e322c67cc to your computer and use it in GitHub Desktop.
Save LiutongZhou/1d9f8b1231acd8173b30db3e322c67cc to your computer and use it in GitHub Desktop.
Universal Decorator
"""Universal Decorator

Universal decorators can decorate functions, classes, bound methods
(class method / instance method) referenced outside of a class definition,
and descriptors (class methods, static methods) defined inside a class definition.

Subclass ``UniversalDecoratorBase`` and implement ``__call__`` to write one.
``UniversalContextDecorator`` turns a context manager into a universal
decorator, and ``mute`` (silences standard output) is a ready-made example.
"""
from __future__ import annotations

import inspect
import logging
import os
from abc import ABC, abstractmethod
from contextlib import redirect_stdout
from dataclasses import dataclass, field, replace
from functools import update_wrapper
from typing import Any, Optional, Type, TypeVar
T = TypeVar("T")  # generic type
# Public API: the base class to inherit from plus the ready-made decorators.
__all__ = ["UniversalDecoratorBase", "UniversalContextDecorator", "mute"]

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# getLogger() hands back the same cached logger on re-import (e.g.
# importlib.reload), so attach the stream handler only once; a second handler
# would print every record twice.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("[Line: %(lineno)d] - %(message)s"))
    logger.addHandler(handler)
@dataclass
class UniversalDecoratorBase(ABC):
    """Base Class for Universal Decorators

    Inherit this base class to create universal decorators.

    Unlike ordinary decorators that normally only decorate callables,
    universal decorators can decorate functions, classes, bound methods
    (class method / instance method) referenced outside of class definition
    and descriptors (class methods, static methods, instance methods etc.)
    defined inside class definition.

    When the decorated object is called with (*args, **kwargs), the __get__ and __call__ methods
    will be called properly based on the type of the original object to decorate.

    If to_decorate is a descriptor (classmethod, staticmethod), instance method defined inside class
    definition, or Class.method wrapped outside of class definition, the __get__ method will be
    called.

    If to_decorate is a function, a class, or instance.bound_method wrapped outside of class
    definition, the __call__ method will be called.

    Subclasses may themselves be dataclasses that declare extra fields
    (decorator parameters); ``__get__`` carries those fields over to the
    bound copies it creates.

    Parameters
    ----------
    to_decorate : Any
        the original object to decorate. A Callable or descriptor.

    Examples
    --------
    >>> from time import sleep, time
    >>> class timeit(UniversalDecoratorBase):
    ...     def __call__(self, *args, **kwargs):
    ...         tic = time()
    ...         result = self.to_decorate(*args, **kwargs)
    ...         toc = time()
    ...         print(f"Time taken: {toc - tic:0.1f} seconds")
    ...         return result
    >>> class SuperNeuralNetwork:
    ...     @timeit
    ...     @classmethod
    ...     def from_pretrained(cls, path):
    ...         print(f"Loading model from {path}")
    ...         sleep(0.1)
    ...         return cls()
    ...     @timeit
    ...     def forward(self, x):
    ...         print("Forward pass")
    ...         sleep(0.1)
    ...         return -x
    >>> model = SuperNeuralNetwork.from_pretrained("path/to/model")
    Loading model from path/to/model
    Time taken: 0.1 seconds
    >>> model.forward(1)
    Forward pass
    Time taken: 0.1 seconds
    -1
    """

    to_decorate: Any
    # internal use only, to keep track of the object that called the descriptor:
    # the instance, or the owner class when the attribute is read on the class
    _bound_to: Any = None
    # internal use only, to keep track of the name of the decorated object, set by __set_name__
    _name: Optional[str] = field(default=None, init=False)

    def __post_init__(self):
        """Copy ``__name__``, ``__doc__``, ``__wrapped__`` etc. from a callable ``to_decorate``."""
        if callable(self.to_decorate):
            update_wrapper(self, self.to_decorate)
        if self._name is not None:
            self.__name__ = self._name

    def __hash__(self) -> int:
        """Hash like the wrapped object.

        ``@dataclass`` generates ``__eq__`` and would otherwise set ``__hash__``
        to ``None``, making decorated functions unusable as dict keys or set
        members. Equal decorators have equal ``to_decorate``, so this stays
        consistent with the generated ``__eq__``.
        """
        return hash(self.to_decorate)

    def __getattr__(self, name: str) -> Any:
        """Forward attributes not found on the decorator to ``to_decorate``.

        Raises
        ------
        AttributeError
            If the attribute is missing on ``to_decorate``, or if
            ``to_decorate`` itself has not been set yet (e.g. on the bare
            instance ``copy.copy`` creates before restoring its state).
            Forwarding in that case would recurse forever.
        """
        if name == "to_decorate":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'to_decorate'"
            )
        return getattr(self.to_decorate, name)

    def __get__(self, instance: Optional[T], owner: Type[T]) -> UniversalDecoratorBase:
        """Routing the . operator to __get__

        Original to_decorate is a descriptor (staticmethod, classmethod), method function
        defined inside class, or a bound method (instance method / class method) wrapped
        outside of class definition.

        Returns
        -------
        UniversalDecoratorBase
            A new decorator of the same type (keeping any extra dataclass
            fields) that wraps the resolved callable. Its ``_bound_to`` is
            ``instance``, or ``owner`` when the attribute is read on the class.
        """
        _to_decorate = self.to_decorate
        # if the original to_decorate is a descriptor, get the actual callable
        while inspect.ismethoddescriptor(_to_decorate):
            _to_decorate = _to_decorate.__get__(instance, owner)
        # If it is a method, bind it to the instance or class. A parameter named
        # "self"/"cls" is the only hint left once e.g. `Class.static_method` has
        # already been unwrapped to a plain function outside the class body.
        _signature = inspect.signature(_to_decorate)
        if "self" in _signature.parameters or "cls" in _signature.parameters:
            _to_decorate = _to_decorate.__get__(instance, owner)
        # Compare against None: `instance or owner` would swap a falsy instance
        # (e.g. an empty container) for its class.
        bound_to = owner if instance is None else instance
        # replace() re-creates type(self) and carries over extra init fields a
        # dataclass subclass declares (decorator parameters).
        decorated_callable = replace(
            self, to_decorate=_to_decorate, _bound_to=bound_to
        )
        if self._name:
            decorated_callable.__name__ = decorated_callable._name = self._name
        return decorated_callable

    def __set__(self, instance, value):
        """Support changing the original to_decorate via object.decorated = value

        This descriptor lives on the class, so the new value is seen through
        the class and through every one of its instances.
        """
        self.to_decorate = value

    def __set_name__(self, owner, name: str):
        """Called on class definition when the descriptor is assigned to a class attribute."""
        self.__name__ = self._name = name

    def __delete__(self, instance):
        """Handle ``del instance.attr`` by deleting the attribute from the class.

        The deletion happens on the first class in ``type(instance).__mro__``
        that defines the attribute, so it also works through subclass
        instances. The attribute disappears for every instance.

        Raises
        ------
        AttributeError
            If the descriptor was never named by ``__set_name__`` (e.g. it was
            assigned to the class after the class was created), or no class in
            the MRO defines the attribute.
        """
        if self._name is None:
            raise AttributeError(
                f"cannot delete unnamed {type(self).__name__!r} descriptor"
            )
        for klass in type(instance).__mro__:
            if self._name in vars(klass):
                delattr(klass, self._name)
                return
        raise AttributeError(self._name)

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Implement your decoration logic here.

        to_decorate is a function (static method included), class,
        or bounded method (instance / class method).

        This default implementation does three things:

        * logs (at DEBUG level) which of those cases applies;
        * checks that the object the call was routed through (``_bound_to``)
          is consistent with that case;
        * returns ``to_decorate(*args, **kwargs)``.

        Subclasses can reuse it via ``super().__call__(*args, **kwargs)``.

        Raises
        ------
        NotImplementedError
            If ``to_decorate`` is none of the supported kinds, or if it was
            reached through an object it is not bound to.
        """
        to_decorate = self.to_decorate
        # _called_from is None if decorating a function, class,
        # or instance.bound_method outside of class definition
        _called_from = self._bound_to
        if inspect.ismethod(to_decorate):  # bound method: class method or instance method
            if inspect.isclass(to_decorate.__self__):
                logger.debug(
                    f"decorating a class method: {to_decorate=} "
                    f"of class: {(class_method_owner := to_decorate.__self__)=}"
                )
                # reaching the class method through a subclass is fine as well
                if inspect.isclass(_called_from) and issubclass(
                    _called_from, class_method_owner
                ):
                    logger.debug("class method called from Class")
                elif isinstance(_called_from, class_method_owner):
                    logger.debug("class method called from instance")
                elif _called_from is None:
                    logger.debug(
                        "decorating an instance.class_method wrapped outside of class definition "
                        "and called from an instance"
                    )
                else:
                    raise NotImplementedError(
                        f"class method {to_decorate!r} accessed through unrelated "
                        f"object {_called_from!r}"
                    )
            else:
                logger.debug(
                    f"decorating an instance method: {to_decorate=} "
                    f"bound to instance: {(instance_method_owner := to_decorate.__self__)=}"
                )
                if _called_from is instance_method_owner:
                    logger.debug("decorating an instance method called from an instance")
                elif _called_from is None:
                    logger.debug(
                        "decorating an instance.method referenced outside of class definition "
                        "and called from an instance"
                    )
                else:
                    raise NotImplementedError(
                        f"instance method {to_decorate!r} accessed through another "
                        f"object {_called_from!r}"
                    )
        elif inspect.isclass(to_decorate):
            logger.debug(f"decorating a class: {to_decorate=}")
        elif inspect.isfunction(to_decorate):
            if _called_from is None:
                logger.debug(f"decorating a function {to_decorate=}")
            else:
                logger.debug(
                    f"decorating a static method "
                    f"or Class.instance_method called from {_called_from=}"
                )
        else:
            raise NotImplementedError(
                f"{type(self).__name__} cannot call objects of type "
                f"{type(to_decorate).__name__}"
            )
        # every supported case ends in the same plain call
        return to_decorate(*args, **kwargs)
class UniversalContextDecorator(UniversalDecoratorBase):
    """Base class / mixin letting a context manager double as a universal decorator.

    Subclasses supply ``__enter__`` and ``__exit__``; calling the decorated
    object then runs the wrapped callable inside that context.

    Examples
    --------
    see the `mute` class below
    """

    def _recreate_cm(self):
        """Return the context manager to enter for one decorated call.

        The default hands back ``self``, so every call enters this very
        instance. A subclass that is a one-shot context manager can override
        this to produce a fresh one per call. This mirrors
        ``contextlib.ContextDecorator._recreate_cm`` (CPython issue #11647).
        """
        return self

    def __call__(self, *args, **kwargs):
        """Run ``to_decorate(*args, **kwargs)`` inside ``self._recreate_cm()``."""
        context = self._recreate_cm()
        with context:
            return self.to_decorate(*args, **kwargs)
class mute(UniversalContextDecorator):
    """A universal decorator that mutes the standard output temporarily.

    Anything written to ``sys.stdout`` while muted goes to ``os.devnull``.
    It can also be used directly as a context manager, e.g. ``with mute(None):``.

    Examples
    --------
    >>> class SomeClass:
    ...     @mute
    ...     def some_process(self):
    ...         print("This message won't be printed.")
    >>> SomeClass().some_process()
    >>> with mute(None):
    ...     print("This message won't be printed either")
    """

    def __enter__(self):
        # The output is thrown away, so it must never fail to encode whatever
        # the locale encoding is: use UTF-8 and replace unencodable characters.
        devnull = open(os.devnull, "w", encoding="utf-8", errors="replace")
        redirect = redirect_stdout(devnull)
        redirect.__enter__()
        # Keep one (devnull, redirect) pair per active entry. A recursive
        # function decorated with @mute re-enters this same instance, and
        # single attributes would be overwritten by the inner call, breaking
        # the outer __exit__.
        vars(self).setdefault("_mute_stack", []).append((devnull, redirect))

    def __exit__(self, exc_type, exc_val, exc_traceback):
        devnull, redirect = vars(self)["_mute_stack"].pop()
        try:
            redirect.__exit__(exc_type, exc_val, exc_traceback)
        finally:
            devnull.close()
# NOTE(review): `pytest` is used below but never imported in the visible
# source; presumably this suite lives in (or is copied into) a test module
# that imports pytest and these decorators -- confirm.
class TestUniversalDecorator:
    """pytest Test Suite for Universal Decorator"""

    @pytest.fixture(scope="class")
    def custom_decorator(self):
        """A universal decorator that simply defers to the base ``__call__``."""

        class custom_decorator(UniversalDecoratorBase):
            def __call__(self, *args, **kwargs):
                return super().__call__(*args, **kwargs)

        return custom_decorator

    @pytest.fixture
    def MyClass(self, custom_decorator):
        """A freshly defined class for every test.

        Function-scoped on purpose: several tests replace or delete attributes
        of this class. With a shared class-scoped fixture those changes would
        leak into later tests and make the suite order-dependent.
        """

        class MyClass:
            @custom_decorator
            def instance_method(self, x) -> int:
                """My instance method"""
                logger.debug(self)
                logger.debug("instance_method called")
                return x

            def instance_method2(self, x):
                logger.debug("instance_method2 called")
                return x

            def instance_method3(self, x):
                logger.debug("instance_method3 called")
                return x

            @custom_decorator
            @classmethod
            def class_method(cls, x):
                logger.debug("class_method called")
                return cls.__name__ + f"-{x=}"

            @classmethod
            def class_method2(cls, x):
                logger.debug(f"class_method2 called {cls=}")
                return cls.__name__ + f"-{x=}"

            @classmethod
            def class_method3(cls, x):
                logger.debug(f"class_method3 called {cls=}")
                return cls.__name__ + f"-{x=}"

            @custom_decorator
            @staticmethod
            def static_method(x):
                logger.debug("static_method called")
                return x

            @staticmethod
            def static_method2(x):
                logger.debug("static_method2 called")
                return x

            @staticmethod
            def static_method3(x):
                logger.debug("static_method3 called")
                return x

            @custom_decorator
            @classmethod
            def class_method_for_deletion(cls):
                return "Delete me"

        return MyClass

    def test_decorating_class(self):
        class custom_decorator(UniversalDecoratorBase):
            def __call__(self, *args, **kwargs):
                # give the decorated (empty) class an __init__ recording its arguments
                def __init__(self, *args, **kwargs):
                    self.args = args
                    for key, value in kwargs.items():
                        setattr(self, key, value)

                self.to_decorate.__init__ = __init__
                return super().__call__(*args, **kwargs)

        @custom_decorator
        class MyClass:
            """Empty Class"""

        my_object = MyClass(1, 2, 3, foo="bar")  # type: ignore[call-arg]
        assert my_object.args == (1, 2, 3)  # type: ignore[attr-defined]
        assert my_object.foo == "bar"  # type: ignore[attr-defined]
        assert my_object.__class__.__name__ == "MyClass"
        assert MyClass.__name__ == "MyClass"
        print(MyClass)

    def test_decorating_function(self):
        class custom_decorator(UniversalDecoratorBase):
            def __call__(self, *args, **kwargs):
                first, *rest = args
                args = (first + 1, *rest)
                return super().__call__(*args, **kwargs)

        @custom_decorator
        def my_function(x):
            return x

        assert my_function(1) == 2

    def test_decorating_instance_method_within_class_definition(self, MyClass):
        logger.info("Testing decorating class.instance_method inside class definition")
        logger.info(f"{MyClass.instance_method=}")
        my_object = MyClass()
        logger.info(f"{my_object.instance_method=}")
        assert my_object.instance_method(x=1) == 1
        assert my_object.instance_method(x=2) == 2
        assert MyClass.instance_method(my_object, x=3) == 3

    def test_deleting_decorated_method(self, MyClass):
        my_object = MyClass()
        del my_object.instance_method
        assert not hasattr(my_object, "instance_method")
        assert hasattr(MyClass, "class_method_for_deletion")
        assert MyClass.class_method_for_deletion() == "Delete me"
        del MyClass.class_method_for_deletion
        assert not hasattr(MyClass, "class_method_for_deletion")

    def test_decorating_instance_method_outside_class_definition(
        self, custom_decorator, MyClass
    ):
        logger.info("Testing decorating class.instance_method outside class definition")
        MyClass.instance_method2 = custom_decorator(MyClass.instance_method2)
        logger.info(MyClass.instance_method2)
        my_object = MyClass()
        assert my_object.instance_method2(x=1) == 1
        assert my_object.instance_method2(x=2) == 2
        assert MyClass.instance_method2(my_object, x=3) == 3
        logger.info(
            "Testing decorating object.instance_method outside class definition"
        )
        my_object.instance_method3 = custom_decorator(my_object.instance_method3)
        assert my_object.instance_method3(x=1) == 1
        assert my_object.instance_method3(x=2) == 2
        assert MyClass.instance_method3(my_object, x=3) == 3

    def test_decorating_class_method_within_class_definition(self, MyClass):
        logger.debug(f"{MyClass.class_method=}")
        logger.debug(f"{MyClass().class_method=}")
        assert MyClass.class_method(x=1) == "MyClass-x=1"
        assert MyClass().class_method(x=1) == "MyClass-x=1"

    def test_decorating_class_method_outside_class_definition(
        self, custom_decorator, MyClass
    ):
        MyClass.class_method2 = custom_decorator(MyClass.class_method2)
        assert MyClass.class_method2(x=1) == "MyClass-x=1"
        assert MyClass().class_method2(x=1) == "MyClass-x=1"
        my_class = MyClass()
        my_class.class_method3 = custom_decorator(my_class.class_method3)
        assert MyClass.class_method3(x=1) == "MyClass-x=1"
        assert my_class.class_method3(x=1) == "MyClass-x=1"

    def test_decorating_static_method_within_class_definition(self, MyClass):
        logger.debug(f"{MyClass.static_method=}")
        logger.debug(f"{MyClass().static_method=}")
        assert MyClass.static_method(x=1) == 1
        assert MyClass().static_method(x=1) == 1

    def test_decorating_static_method_outside_class_definition(
        self, custom_decorator, MyClass
    ):
        MyClass.static_method2 = custom_decorator(MyClass.static_method2)
        assert MyClass.static_method2(x=1) == 1
        assert MyClass().static_method2(x=1) == 1
        my_class = MyClass()
        my_class.static_method3 = custom_decorator(my_class.static_method3)
        assert MyClass.static_method3(x=1) == 1
        assert my_class.static_method3(x=1) == 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment