Created
June 13, 2019 03:11
-
-
Save jsbueno/32beb7e087176089b8c33b479cb22129 to your computer and use it in GitHub Desktop.
Snippet with a metaclass that lets a class define an async __init__ method in Python.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from functools import wraps | |
def async_init_wrapper(func):
    """Adapt an ``async def __init__`` into an ``__async_init__`` method.

    The returned function is stored on the class. Called as a bound method
    (``instance.__async_init__()``), it returns a coroutine function. Calling
    that with the constructor arguments and awaiting the result runs
    ``func(instance, *args, **kw)`` and resolves to ``instance`` itself.

    Args:
        func: the coroutine function that was defined as ``__init__``.

    Returns:
        A one-argument function ``wrapper(instance)`` that carries *func*'s
        metadata (name, qualname, docstring, ``__wrapped__``) via
        ``functools.wraps``.

    Raises (when the returned coroutine is awaited):
        TypeError: if *func* returns anything other than ``None``, mirroring
            the check Python applies to a regular ``__init__``.
    """
    @wraps(func)
    def wrapper_stage1(instance):
        async def wrapper_stage2(*args, **kw):
            value = await func(instance, *args, **kw)
            if value is not None:
                raise TypeError("__async_init__() should return None")
            # Resolving to the instance is what makes `await Cls(...)`
            # evaluate to the new object.
            return instance
        return wrapper_stage2
    return wrapper_stage1
class AwaitableClass(type):
    """Metaclass that lets a class declare ``async def __init__``.

    Instantiating such a class returns an awaitable instead of the instance.
    Awaiting it runs the asynchronous initializer and yields the instance::

        instance = await MyClass(*args, **kw)

    Classes without an async ``__init__`` are instantiated normally.
    """

    def __new__(mcls, name, bases, ns, **kw):
        # Move a coroutine __init__ aside. If it stayed in place, type.__call__
        # would call it synchronously and reject the coroutine it returns
        # ("__init__() should return None").
        if "__init__" in ns and inspect.iscoroutinefunction(ns["__init__"]):
            ns["__async_init__"] = async_init_wrapper(ns.pop("__init__"))
        return super().__new__(mcls, name, bases, ns, **kw)

    def __call__(cls, *args, **kw):
        if not hasattr(cls, "__async_init__"):
            return super().__call__(*args, **kw)
        if cls.__new__ is object.__new__ and cls.__init__ is object.__init__:
            # With the async __init__ moved aside, the class falls back on
            # object.__new__/object.__init__, which reject extra arguments
            # ("X() takes no arguments"). Build the bare instance directly;
            # the arguments are meant only for __async_init__ below.
            instance = object.__new__(cls)
        else:
            # NOTE(review): an inherited sync __init__ (or a custom __new__)
            # still receives the same arguments here before __async_init__
            # runs, exactly as in the original design.
            instance = super().__call__(*args, **kw)
        # Mirror type.__call__: only initialize instances of cls itself.
        if not isinstance(instance, cls):
            return instance
        return instance.__async_init__()(*args, **kw)
def test_awaitable_class():
    """Demonstrate awaiting a class whose ``__init__`` is a coroutine.

    ``Server`` uses ``AwaitableClass``, so ``await Server()`` runs its async
    ``__init__`` to completion and evaluates to the instance. The
    initialization is gathered together with another task. This shows the
    event loop keeps running other work while the instance is being set up.
    """
    import asyncio

    class Server(metaclass=AwaitableClass):
        async def __init__(self):
            self.connection = await self.connect()

        async def connect(self):
            await asyncio.sleep(1)
            return 'server initialized connection'

    async def concurrent_task():
        await asyncio.sleep(0.5)
        print("doing stuff while server is initialized")

    async def init_server():
        print("starting server initialization")
        server_instance = await Server()
        print("server ready")
        return server_instance

    async def main():
        server_instance, _ = await asyncio.gather(
            init_server(),
            concurrent_task(),
        )
        return server_instance

    # asyncio.run creates a fresh event loop and closes it when done.
    # asyncio.get_event_loop() with no running loop is deprecated.
    server_instance = asyncio.run(main())
    print(server_instance.connection)
    assert isinstance(server_instance, Server)
    assert server_instance.connection == 'server initialized connection'


if __name__ == "__main__":
    test_awaitable_class()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment