Skip to content

Instantly share code, notes, and snippets.

@ods
Created October 15, 2020 08:56
Show Gist options
  • Save ods/2168ea823bd3075e418ca229c6283681 to your computer and use it in GitHub Desktop.
Save ods/2168ea823bd3075e418ca229c6283681 to your computer and use it in GitHub Desktop.
import asyncio
from contextlib import AsyncExitStack, ExitStack
import contextvars
import functools
import inspect
import sys
__all__ = ['exit_stack']


class _ExitStackProxy:
    """Proxy to the exit stack of the innermost `@exit_stack.wrap`-decorated
    call that is currently executing in this context.

    Attribute access (``exit_stack.enter_context(...)``,
    ``exit_stack.callback(...)``, ...) is forwarded to that stack: an
    ``ExitStack`` for plain and generator functions, an ``AsyncExitStack``
    for coroutine and async-generator functions.
    """

    def __init__(self):
        self._stack = contextvars.ContextVar('exit_stack')

    def get(self):
        """Return the exit stack active in the current context.

        Raises:
            TypeError: no function decorated with `wrap` is running in the
                current context.
        """
        stack = self._stack.get(None)
        if stack is None:
            raise TypeError(
                "exit_stack is not enabled for this context. "
                "Didn't you forget to decorate it with @exit_stack.wrap?"
            )
        return stack

    def __getattr__(self, name):
        """Forward unknown attributes to the currently active exit stack."""
        # `get` reads `self._stack`; if that attribute is missing (e.g. an
        # instance created without __init__) we would recurse forever.
        if name == '_stack':
            raise AttributeError(name)
        return getattr(self.get(), name)

    def wrap(self, func):
        """Enable `exit_stack` in decorated function (as well as generator
        function, coroutine function or async-generator function). Use it
        to avoid excessive indentation of nested `with`/`async with` blocks
        when you need a number of cleanup steps on return.

        Each call of the decorated function gets its own stack (an
        ``ExitStack``, or an ``AsyncExitStack`` for coroutine and
        async-generator functions). The stack is unwound when the call
        finishes, and context managers entered on it receive the exception
        the call raised, as nested ``with`` blocks would. The stack is
        active only while the decorated function's own code runs: it is
        deactivated on return and, for generators, at every ``yield``, so
        it never leaks into the caller.

        Usage example:

            @exit_stack.wrap
            def some_func(path, ...):
                fp = exit_stack.enter_context(
                    pathlib.Path(path).open()
                )
                ...

        Raises:
            TypeError: `func` is not a plain Python function.
        """
        if not inspect.isfunction(func):
            raise TypeError(f'{func} is not a function')
        if inspect.iscoroutinefunction(func):
            wrapper = self._wrap_coroutine_function(func)
        elif inspect.isasyncgenfunction(func):
            wrapper = self._wrap_asyncgen_function(func)
        elif inspect.isgeneratorfunction(func):
            wrapper = self._wrap_generator_function(func)
        else:
            wrapper = self._wrap_function(func)
        return functools.wraps(func)(wrapper)

    def _wrap_function(self, func):
        """Return a wrapper running plain function `func` with its own stack."""
        var = self._stack

        def wrapper(*args, **kwargs):
            stack = ExitStack()
            token = var.set(stack)
            try:
                # `with` (not `close()`) hands any exception to the entered
                # context managers, just like nested `with` blocks.
                with stack:
                    return func(*args, **kwargs)
            finally:
                # Restore the caller's stack (or none at all).
                var.reset(token)

        return wrapper

    def _wrap_coroutine_function(self, func):
        """Return a wrapper awaiting coroutine function `func` with its own
        stack."""
        var = self._stack

        async def wrapper(*args, **kwargs):
            stack = AsyncExitStack()
            token = var.set(stack)
            try:
                async with stack:
                    return await func(*args, **kwargs)
            finally:
                var.reset(token)

        return wrapper

    def _wrap_generator_function(self, func):
        """Return a wrapper driving generator function `func` with its own
        stack.

        The inner generator is stepped by hand instead of ``yield from`` so
        that its stack is active only while its code runs; `send()`,
        `throw()`, `close()` and the return value are all propagated.
        """
        var = self._stack

        def wrapper(*args, **kwargs):
            stack = ExitStack()
            with stack:
                gen = func(*args, **kwargs)
                # Pushed first, so it runs after everything the generator
                # itself puts on the stack.
                stack.callback(gen.close)
                value, error = None, None
                while True:
                    token = var.set(stack)
                    try:
                        if error is None:
                            item = gen.send(value)
                        else:
                            item = gen.throw(error)
                    except StopIteration as stop:
                        return stop.value
                    finally:
                        var.reset(token)
                    error = None
                    try:
                        value = yield item
                    except GeneratorExit:
                        # Let the generator run its own cleanup (with its
                        # stack active) before the stack is unwound.
                        token = var.set(stack)
                        try:
                            gen.close()
                        finally:
                            var.reset(token)
                        raise
                    except BaseException as exc:
                        # Forward `throw()` into the inner generator.
                        error = exc

        return wrapper

    def _wrap_asyncgen_function(self, func):
        """Return a wrapper driving async-generator function `func` with its
        own stack.

        Stepped by hand (like the sync variant) so that `asend()`,
        `athrow()` and `aclose()` reach the inner generator and its stack is
        active only while its code runs.
        """
        var = self._stack

        async def wrapper(*args, **kwargs):
            stack = AsyncExitStack()
            async with stack:
                agen = func(*args, **kwargs)
                # Pushed first, so it runs after everything the generator
                # itself puts on the stack.
                stack.push_async_callback(agen.aclose)
                value, error = None, None
                while True:
                    token = var.set(stack)
                    try:
                        if error is None:
                            item = await agen.asend(value)
                        else:
                            item = await agen.athrow(error)
                    except StopAsyncIteration:
                        return
                    finally:
                        var.reset(token)
                    error = None
                    try:
                        value = yield item
                    except GeneratorExit:
                        # Let the generator run its own cleanup (with its
                        # stack active) before the stack is unwound.
                        token = var.set(stack)
                        try:
                            await agen.aclose()
                        finally:
                            var.reset(token)
                        raise
                    except BaseException as exc:
                        # Forward `athrow()` into the inner generator.
                        error = exc

        return wrapper


# Shared proxy: call `exit_stack.enter_context(...)` etc. inside functions
# decorated with `@exit_stack.wrap`.
exit_stack = _ExitStackProxy()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment