
@anatoly-kussul
Last active March 1, 2024 16:27
Python sync-async decorator factory
import inspect
from contextlib import contextmanager
from functools import wraps


class SyncAsyncDecoratorFactory:
    """
    Factory that creates a decorator which can wrap either a coroutine function or a regular function.
    To return a value from the wrapper, call self._return(value).
    If you need to modify args or kwargs, yield (args, kwargs) from the wrapper.
    """
    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        # Bare usage (@Decorator without parentheses): initialise with defaults and decorate immediately
        if len(args) == 1 and not kwargs and (inspect.iscoroutinefunction(args[0]) or inspect.isfunction(args[0])):
            instance.__init__()
            return instance(args[0])
        return instance

    class ReturnValue(Exception):
        def __init__(self, return_value):
            self.return_value = return_value

    @contextmanager
    def wrapper(self, *args, **kwargs):
        raise NotImplementedError

    @classmethod
    def _return(cls, value):
        raise cls.ReturnValue(value)

    def __call__(self, func):
        @wraps(func)
        def call_sync(*args, **kwargs):
            try:
                with self.wrapper(*args, **kwargs) as new_args:
                    if new_args:
                        args, kwargs = new_args
                    # Use the closed-over func rather than self.func, so one decorator
                    # instance can safely wrap more than one function
                    return func(*args, **kwargs)
            except self.ReturnValue as r:
                return r.return_value

        @wraps(func)
        async def call_async(*args, **kwargs):
            try:
                with self.wrapper(*args, **kwargs) as new_args:
                    if new_args:
                        args, kwargs = new_args
                    return await func(*args, **kwargs)
            except self.ReturnValue as r:
                return r.return_value

        # Kept for subclasses that read it; it points at the most recently decorated function
        self.func = func
        return call_async if inspect.iscoroutinefunction(func) else call_sync
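The docstring says arguments can be rewritten by yielding (args, kwargs) from the wrapper. A minimal sketch of that, assuming a condensed copy of the factory (the bare-decorator `__new__` shortcut is left out to keep it short) and a hypothetical `UpperCaseArgs` subclass that upper-cases string arguments before the wrapped sync or async function sees them:

```python
import asyncio
import inspect
from contextlib import contextmanager
from functools import wraps


class SyncAsyncDecoratorFactory:
    # Condensed copy of the factory; the no-parentheses __new__ shortcut is omitted
    class ReturnValue(Exception):
        def __init__(self, return_value):
            self.return_value = return_value

    @contextmanager
    def wrapper(self, *args, **kwargs):
        raise NotImplementedError

    @classmethod
    def _return(cls, value):
        raise cls.ReturnValue(value)

    def __call__(self, func):
        @wraps(func)
        def call_sync(*args, **kwargs):
            try:
                with self.wrapper(*args, **kwargs) as new_args:
                    if new_args:
                        args, kwargs = new_args
                    return func(*args, **kwargs)
            except self.ReturnValue as r:
                return r.return_value

        @wraps(func)
        async def call_async(*args, **kwargs):
            try:
                with self.wrapper(*args, **kwargs) as new_args:
                    if new_args:
                        args, kwargs = new_args
                    return await func(*args, **kwargs)
            except self.ReturnValue as r:
                return r.return_value

        return call_async if inspect.iscoroutinefunction(func) else call_sync


class UpperCaseArgs(SyncAsyncDecoratorFactory):
    """Hypothetical example: upper-cases every str argument before the call."""

    @contextmanager
    def wrapper(self, *args, **kwargs):
        new_args = tuple(a.upper() if isinstance(a, str) else a for a in args)
        new_kwargs = {k: v.upper() if isinstance(v, str) else v for k, v in kwargs.items()}
        # Yielding a (args, kwargs) pair makes call_sync/call_async use these instead
        yield new_args, new_kwargs


@UpperCaseArgs()
def greet(name, punctuation="!"):
    return f"Hello, {name}{punctuation}"


@UpperCaseArgs()
async def greet_async(name):
    return f"Hi, {name}"


print(greet("world"))                     # Hello, WORLD!
print(asyncio.run(greet_async("async")))  # Hi, ASYNC
```

Yielding nothing (a bare `yield`) leaves the original arguments untouched, since `new_args` is then None.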
@falkben

falkben commented Aug 3, 2020

Any chance you could show an example with the self._return being used? Since this is being used as a decorator, my thinking is that you still need to return func, so maybe you add some attribute to func in your wrapper and then return that with self._return(func_w_attribute). But then I don't know how to easily pull out that attribute. But I think I'm probably missing something.

@anatoly-kussul
Author

anatoly-kussul commented Aug 3, 2020

@falkben You just call it with the desired return value inside the wrapper context manager.

It is used something like this:

import asyncio
import logging
from contextlib import contextmanager


class ExceptionLogger(SyncAsyncDecoratorFactory):
    def __init__(self, default=None):
        self.default_value = default

    @contextmanager
    def wrapper(self, *args, **kwargs):
        try:
            yield
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logging.error('Error log')
            self._return(self.default_value)

Then you can use this decorator as follows:

@ExceptionLogger
def sync_foo():
    print('in sync')
    _ = 1/0


@ExceptionLogger(default='default value')
def sync_foo_with_default():
    print('in sync with default')
    _ = 1/0


@ExceptionLogger
async def async_foo():
    print('in async')
    _ = 1/0


@ExceptionLogger(default='default value')
async def async_foo_with_default():
    print('in async with default')
    _ = 1/0


async def test_async():
    result_3 = await async_foo()
    print(f'Async foo result - {result_3}')
    result_4 = await async_foo_with_default()
    print(f'Async foo result default - {result_4}')


def example():
    result_1 = sync_foo()
    print(f'Sync foo result - {result_1}')
    result_2 = sync_foo_with_default()
    print(f'Sync foo with default result - {result_2}')
    asyncio.run(test_async())

which gives the following output:

>>> example()
in sync
ERROR:root:Error log
Sync foo result - None
in sync with default
ERROR:root:Error log
Sync foo with default result - default value
in async
ERROR:root:Error log
Async foo result - None
in async with default
ERROR:root:Error log
Async foo result default - default value
>>>

@falkben

falkben commented Aug 4, 2020

I see, thanks! I misunderstood and thought it was returning data into the function it was decorating, which would have been quite interesting.

@falkben

falkben commented Oct 23, 2020

First off, just want to say this is a really clever use of exceptions to break out of the context manager. Also, I'm not sure I ever would have figured out how __new__ and __init__ could be used to allow optional arguments to the decorator, so thank you so much for creating this gist.
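Stripped of the sync/async parts, the `__new__` trick can be seen in isolation in this minimal sketch (the `Tagged` decorator and its `tag` parameter are hypothetical). It relies on a rule of the Python data model: when `__new__` returns something that is not an instance of the class, `__init__` is skipped, which is what lets a bare `@Tagged` decorate immediately.

```python
import inspect
from functools import wraps


class Tagged:
    """Hypothetical decorator usable as both @Tagged and @Tagged(tag=...)."""

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        if len(args) == 1 and not kwargs and inspect.isfunction(args[0]):
            # Bare @Tagged: Python passed the function itself, so initialise
            # with defaults and return the decorated function right away.
            # That is not a Tagged instance, so Python skips __init__.
            instance.__init__()
            return instance(args[0])
        # @Tagged(...): return the instance; Python then calls __init__(*args, **kwargs)
        return instance

    def __init__(self, tag="default"):
        self.tag = tag

    def __call__(self, func):
        @wraps(func)
        def inner(*args, **kwargs):
            return f"{self.tag}: {func(*args, **kwargs)}"
        return inner


@Tagged
def plain():
    return 1


@Tagged(tag="custom")
def custom():
    return 2


print(plain())   # default: 1
print(custom())  # custom: 2
```

One limitation of this check: a decorator whose only positional parameter is itself meant to be a function cannot be told apart from bare usage.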

By the way, I've recently needed this for decorating methods on classes, and it needed a slight modification.

When a method of a class gets decorated with your version of a decorator, you get a TypeError because the wrapper gets called with two arguments named "self" (one of them in the kwargs).

I've gotten around this as follows:

def wrapper(dec_self, *args, self=None, **kwargs):

Now, the wrapper can still access instance variables of the object whose method was decorated (e.g. self.attrib1) and can still access its own attributes (e.g. dec_self._return).

And as long as your decorator implementation doesn't expect something in self, you can use the same factory class without having to worry too much, since self, from the decorated function, gets passed in as a kwarg.

Again, thanks!

@falkben

falkben commented Oct 27, 2020

Another thing to be aware of with this implementation: the wrapper context manager is a sync function (not an async context manager, which do exist). What this means, I think, is that when this decorates an async function, everything in your wrapper runs synchronously on the event loop's thread.

If you make a blocking database call, external HTTP request, or file access inside the decorator (wrapper here), it could stall your event loop.

I think that if you were doing this sort of thing in your decorator, you could possibly, inside of call_async, run the context manager in a ThreadPoolExecutor to avoid blocking the event loop.
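A minimal sketch of that idea, assuming the blocking work lives in a generator-based context manager: the hypothetical helper `call_with_threaded_cm` runs the context manager's `__enter__` and `__exit__` in the default executor via `loop.run_in_executor`, so only the awaited coroutine runs on the event loop thread. `time.sleep` stands in for blocking I/O, and a ticker task shows the loop keeps running meanwhile. Because the generator is resumed from worker threads, this assumes the wrapper does not depend on thread-local state.

```python
import asyncio
import time
from contextlib import contextmanager


@contextmanager
def blocking_wrapper(events):
    # Hypothetical wrapper whose setup and teardown block (time.sleep stands in for blocking I/O)
    time.sleep(0.05)
    events.append("enter")
    yield
    time.sleep(0.05)
    events.append("exit")


async def call_with_threaded_cm(cm, coro_func, *args, **kwargs):
    """Hypothetical helper: run cm.__enter__/__exit__ in the default executor, await coro_func on the loop."""
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, cm.__enter__)
    try:
        result = await coro_func(*args, **kwargs)
    except BaseException as exc:
        # Let the context manager see the exception; it may suppress it or raise another one
        suppressed = await loop.run_in_executor(None, cm.__exit__, type(exc), exc, exc.__traceback__)
        if not suppressed:
            raise
        return None
    await loop.run_in_executor(None, cm.__exit__, None, None, None)
    return result


async def main():
    events = []
    ticks = 0

    async def ticker():
        # Counts how often the event loop gets to run other tasks
        nonlocal ticks
        while True:
            ticks += 1
            await asyncio.sleep(0.005)

    ticker_task = asyncio.create_task(ticker())

    async def work():
        events.append("body")
        return 42

    result = await call_with_threaded_cm(blocking_wrapper(events), work)
    ticker_task.cancel()
    try:
        await ticker_task
    except asyncio.CancelledError:
        pass
    return result, events, ticks


result, events, ticks = asyncio.run(main())
print(result, events)  # 42 ['enter', 'body', 'exit']
print(f"ticker ran {ticks} times while the wrapper blocked")
```

Since an exception raised by `__exit__` propagates out of the awaited executor future, a `_return`-style ReturnValue raised by the wrapper would still reach the caller.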

Of course, if you are okay with duplicating code, you could also use the async context manager instead, but in that case, I think I'd just skip the context manager.
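For the async-context-manager route, a minimal sketch using `contextlib.asynccontextmanager` (the `AsyncAuditLog` class and its `async_wrapper` method are hypothetical names). It only handles coroutine functions; supporting sync functions as well would mean keeping a separate sync wrapper next to it, which is the code duplication trade-off.

```python
import asyncio
from contextlib import asynccontextmanager
from functools import wraps


class AsyncAuditLog:
    """Hypothetical decorator: records calls around a coroutine using an async context manager."""

    def __init__(self):
        self.events = []

    @asynccontextmanager
    async def async_wrapper(self, *args, **kwargs):
        # Awaitable, non-blocking work (e.g. an async DB client) could go here
        await asyncio.sleep(0)
        self.events.append(("before", args))
        try:
            yield
        finally:
            await asyncio.sleep(0)
            self.events.append(("after", args))

    def __call__(self, func):
        @wraps(func)
        async def call_async(*args, **kwargs):
            async with self.async_wrapper(*args, **kwargs):
                return await func(*args, **kwargs)
        return call_async


audit = AsyncAuditLog()


@audit
async def add(a, b):
    return a + b


print(asyncio.run(add(1, 2)))  # 3
print(audit.events)            # [('before', (1, 2)), ('after', (1, 2))]
```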

@NicoAdrian

NicoAdrian commented Nov 30, 2021

I've been trying to use your code to avoid code duplication between sync and async function decorators, but couldn't get it to work.
I'm trying to implement a time-bounded LRU cache decorator (similar to functools.lru_cache, but with keys that expire after a given time). Below is my attempt:

from functools import wraps
import asyncio
from contextlib import contextmanager
import collections
import time


class SyncAsyncDecoratorFactory:
    @contextmanager
    def wrapper(self, *args, **kwargs):
        raise NotImplementedError("Please call this method from subclasses")

    def __call__(self, func):
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with self.wrapper(*args, **kwargs) as res:
                if res is None:
                    return func(*args, **kwargs)

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            with self.wrapper(*args, **kwargs) as res:
                return await func(*args, **kwargs)

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper


class TimedLRU(SyncAsyncDecoratorFactory):
    def __init__(self, max_size=128, max_age=30):
        super().__init__()
        self.max_size = max_size
        self.max_age = max_age
        self._cache = collections.OrderedDict()
        self._sentinel = object()

    @contextmanager
    def wrapper(self, *args, **kwargs):
        k = args + (self._sentinel,) + tuple(sorted(kwargs.items()))
        if k in self._cache:
            print("hit")
            self._cache.move_to_end(k)
            res, ts = self._cache[k]
            if time.time() - ts <= self.max_age:
                yield
                return res
        res = yield
        self._cache[k] = (res, time.time())
        if len(self._cache) > self.max_size:
            self._cache.popitem(0)
        return res


@TimedLRU()
def foobar(s):
    print("in decorated function")
    return "foo %s bar" % s


a = foobar("hey")
print(a)
b = foobar("hey")
print(b)

Any help on how to achieve this? The value stored in _cache is None, and "in decorated function" is printed twice instead of once.

@anatoly-kussul
Author

anatoly-kussul commented Dec 7, 2021

@NicoAdrian

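Hey, sorry for the late response.
The reason you are not getting what you want is that you can't directly return a value from a context manager: @contextmanager resumes the generator with next(), so res = yield always gets None, and whatever the generator returns afterwards is discarded. Also, a bare yield on a cache hit still lets the with-block body run, which is why the function is called twice.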
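A small standalone check of that behaviour, using only the standard library (the `capture` context manager is a hypothetical name): the value bound by `as` is whatever the generator yields, `yield` itself evaluates to None when the with-block exits normally, and the generator's return value goes nowhere.

```python
from contextlib import contextmanager

seen = []


@contextmanager
def capture():
    # The with-block body runs while the generator is paused here.
    received = yield "entered"
    # On a normal exit, contextmanager resumes the generator with next(),
    # so `received` is always None: the body's result never reaches it.
    seen.append(received)
    # This return value only ends the generator; contextmanager discards it.
    return "discarded"


with capture() as value:
    result = 6 * 7  # computed in the body, invisible to capture()

print(value)  # entered
print(seen)   # [None]
```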

Here is a slightly modified version that does what you want. It's probably not the best solution, just my first take, but it should give you the idea:

from functools import wraps
import asyncio
from contextlib import contextmanager
import collections
import time
import inspect

class SyncAsyncDecoratorFactory:
    """
    Factory creates decorator which can wrap either a coroutine or function.
    To return something from wrapper use self._return
    If you need to modify args or kwargs, you can yield them from wrapper
    """
    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        # This is for using decorator without parameters
        if len(args) == 1 and not kwargs and (inspect.iscoroutinefunction(args[0]) or inspect.isfunction(args[0])):
            instance.__init__()
            return instance(args[0])
        return instance

    class ReturnValue(Exception):
        def __init__(self, return_value):
            self.return_value = return_value

    @contextmanager
    def wrapper(self, *args, **kwargs):
        raise NotImplementedError

    @classmethod
    def _return(cls, value):
        raise cls.ReturnValue(value)

    def __call__(self, func):
        @wraps(func)
        def call_sync(*args, **kwargs):
            try:
                with self.wrapper(*args, **kwargs) as new_args:
                    if new_args:
                        args, kwargs = new_args
                    self._return(self.func(*args, **kwargs))
            except self.ReturnValue as r:
                return r.return_value

        @wraps(func)
        async def call_async(*args, **kwargs):
            try:
                with self.wrapper(*args, **kwargs) as new_args:
                    if new_args:
                        args, kwargs = new_args
                    self._return(await self.func(*args, **kwargs))
            except self.ReturnValue as r:
                return r.return_value

        self.func = func
        return call_async if inspect.iscoroutinefunction(func) else call_sync


class TimedLRU(SyncAsyncDecoratorFactory):
    def __init__(self, max_size=128, max_age=30):
        super().__init__()
        self.max_size = max_size
        self.max_age = max_age
        self._cache = collections.OrderedDict()
        self._sentinel = object()

    @contextmanager
    def wrapper(self, *args, **kwargs):
        k = args + (self._sentinel,) + tuple(sorted(kwargs.items()))
        if k in self._cache:
            print("hit")
            self._cache.move_to_end(k)
            res, ts = self._cache[k]
            if time.time() - ts <= self.max_age:
                self._return(res)
        try:
            yield
        except self.ReturnValue as r:
            self._cache[k] = (r.return_value, time.time())
            if len(self._cache) > self.max_size:
                self._cache.popitem(0)
            raise


@TimedLRU()
def foobar(s):
    print("in decorated function")
    return "foo %s bar" % s


a = foobar("hey")
print(a)
b = foobar("hey")
print(b)

@NicoAdrian

I see, thanks!

@albertmenglongli

This class is cool! Thanks!
