Created
October 15, 2020 08:56
-
-
Save ods/2168ea823bd3075e418ca229c6283681 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from contextlib import AsyncExitStack, ExitStack | |
import contextvars | |
import functools | |
import inspect | |
import sys | |
__all__ = ['exit_stack']


class _ExitStackProxy:
    """Proxy that forwards attribute access to the currently active exit stack.

    The active stack lives in a `contextvars.ContextVar` and is bound only
    while the body of a callable decorated with `wrap` is executing. Outside
    of such a body every forwarded attribute access raises `TypeError`.

    Note: the stack is unwound with `close()` / `aclose()`, so registered
    context managers are always exited as if no exception had occurred, even
    when the wrapped callable raises.
    """

    def __init__(self):
        self._stack = contextvars.ContextVar('exit_stack')

    def get(self):
        """Return the exit stack bound to the current context.

        Raises:
            TypeError: no callable decorated with `wrap` is currently
                executing in this context.
        """
        stack = self._stack.get(None)
        if stack is None:
            raise TypeError(
                "exit_stack is not enabled for this context. "
                "Didn't you forget to decorate it with @exit_stack.wrap?"
            )
        return stack

    def __getattr__(self, name):
        # Only reached for names not defined on the proxy itself; forward
        # them (`enter_context`, `callback`, ...) to the active stack.
        return getattr(self.get(), name)

    def _run_bound(self, stack, step, *args):
        """Call `step(*args)` with `stack` bound as the active stack."""
        token = self._stack.set(stack)
        try:
            return step(*args)
        finally:
            # Restore the previous binding (possibly "unbound") so the
            # caller never observes a stack that is not its own.
            self._stack.reset(token)

    async def _await_bound(self, stack, step, *args):
        """Await `step(*args)` with `stack` bound as the active stack."""
        token = self._stack.set(stack)
        try:
            return await step(*args)
        finally:
            self._stack.reset(token)

    def wrap(self, func):
        """Enable `exit_stack` in decorated function (as well as generator
        function, coroutine function or async-generator function). Use it
        to avoid excessive indentation of nested `with`/`async with` blocks
        when you need a number of cleanup steps on return.

        A fresh `ExitStack` (or `AsyncExitStack` for the async flavours) is
        created per call and closed when the call finishes, whether it
        returns or raises. The stack is bound only while the decorated body
        itself is running: nested decorated calls get their own stack and the
        outer one is restored afterwards, and code consuming a decorated
        generator never sees the generator's stack between items.

        Usage example:

            @exit_stack.wrap
            def some_func(path, ...):
                fp = exit_stack.enter_context(
                    pathlib.Path(path).open()
                )
                ...

        Args:
            func: a plain Python function of any of the four flavours.

        Returns:
            A wrapper of the same flavour as `func`, carrying its metadata
            (via `functools.wraps`).

        Raises:
            TypeError: `func` is not a Python function.
        """
        if not inspect.isfunction(func):
            raise TypeError(f'{func} is not a function')

        if inspect.iscoroutinefunction(func):
            async def wrapper(*args, **kwargs):
                stack = AsyncExitStack()
                token = self._stack.set(stack)
                try:
                    try:
                        return await func(*args, **kwargs)
                    finally:
                        # Unwind while the stack is still bound.
                        await stack.aclose()
                finally:
                    self._stack.reset(token)

        elif inspect.isasyncgenfunction(func):
            async def wrapper(*args, **kwargs):
                stack = AsyncExitStack()
                gen = func(*args, **kwargs)
                # Registered first, so it runs last: make sure the inner
                # generator is finalized even if iteration is abandoned.
                stack.push_async_callback(gen.aclose)
                try:
                    while True:
                        # Bind the stack only for the duration of one step,
                        # so the consumer's context is left untouched
                        # between items.
                        try:
                            item = await self._await_bound(
                                stack, gen.__anext__
                            )
                        except StopAsyncIteration:
                            return
                        # XXX Propagate `asend()`/`athrow()`?
                        yield item
                finally:
                    await self._await_bound(stack, stack.aclose)

        elif inspect.isgeneratorfunction(func):
            def wrapper(*args, **kwargs):
                stack = ExitStack()
                gen = func(*args, **kwargs)
                # Registered first, so it runs last (see above).
                stack.callback(gen.close)
                try:
                    # Hand-rolled `yield from gen` so that every resumption
                    # of the inner generator runs with the stack bound and
                    # the binding is dropped again before control returns
                    # to the consumer.
                    step, arg = gen.send, None
                    while True:
                        try:
                            item = self._run_bound(stack, step, arg)
                        except StopIteration as stop:
                            # Preserve the generator's return value.
                            return stop.value
                        try:
                            sent = yield item
                        except GeneratorExit:
                            # Let the inner generator run its own cleanup
                            # before the stack is unwound.
                            self._run_bound(stack, gen.close)
                            raise
                        except BaseException as exc:
                            step, arg = gen.throw, exc
                        else:
                            step, arg = gen.send, sent
                finally:
                    self._run_bound(stack, stack.close)

        else:
            def wrapper(*args, **kwargs):
                stack = ExitStack()
                token = self._stack.set(stack)
                try:
                    try:
                        return func(*args, **kwargs)
                    finally:
                        # Unwind while the stack is still bound.
                        stack.close()
                finally:
                    self._stack.reset(token)

        return functools.wraps(func)(wrapper)


exit_stack = _ExitStackProxy()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment