Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save simonw/496b24fdad44f6f8b7237fe394a0ced7 to your computer and use it in GitHub Desktop.
Save simonw/496b24fdad44f6f8b7237fe394a0ced7 to your computer and use it in GitHub Desktop.
Dependency injection for asyncio concurrency
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

Here's the class version:

class AsyncDI(type):
    """Metaclass that makes every ``async def`` method dependency-injected.

    Each coroutine function defined directly in the class body is wrapped by
    ``make_method`` and recorded in one registry shared by all wrapped methods
    of the class. A wrapped method's parameter named after another async
    method is then filled in with that method's awaited result. All other
    class-body attributes are passed through unchanged.
    """

    def __new__(mcs, name, bases, attrs):
        registry = {}
        # Start from a copy of the class body so attribute order is unchanged;
        # coroutine entries are swapped for their DI wrappers below.
        namespace = dict(attrs)
        for attr_name, attr_value in attrs.items():
            if not inspect.iscoroutinefunction(attr_value):
                continue
            wrapped = make_method(attr_value, registry)
            namespace[attr_name] = wrapped
            registry[attr_name] = wrapped
        return super().__new__(mcs, name, bases, namespace)


def make_method(method, registry):
    """Wrap coroutine function *method* so its dependencies are injected.

    When the wrapper is called, each parameter of *method* (other than
    ``self``) that the caller did not pass as a keyword argument is looked up
    by name in *registry*. The matching entries are called with the same
    ``self`` and run concurrently with ``asyncio.gather``. Their results are
    passed to *method* as keyword arguments.

    A parameter that has a default value is only injected when *registry*
    has an entry for it; otherwise the default applies. ``*args``,
    ``**kwargs`` and positional-only parameters are never injected.

    Args:
        method: the ``async def`` function to wrap.
        registry: dict of name -> wrapped coroutine function. It is read at
            call time, so entries added after wrapping are still seen.

    Returns:
        An ``async`` wrapper with signature ``(self, **kwargs)``.

    Raises:
        AssertionError: if a parameter without a default has no registry entry.
    """
    # The signature never changes, so inspect it once rather than per call.
    parameters = inspect.signature(method).parameters
    injectable = [
        name
        for name, param in parameters.items()
        if name != "self"
        and param.kind
        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    ]
    optional = {
        name
        for name in injectable
        if parameters[name].default is not inspect.Parameter.empty
    }

    @wraps(method)
    async def inner(self, **kwargs):
        # Any parameters not provided by kwargs are resolved from registry;
        # defaulted ones only if the registry can actually supply them.
        to_resolve = [
            p
            for p in injectable
            if p not in kwargs and (p in registry or p not in optional)
        ]
        missing = [p for p in to_resolve if p not in registry]
        assert (
            not missing
        ), "The following DI parameters could not be found in the registry: {}".format(
            missing
        )

        # Look every dependency up before creating any coroutine, so a failed
        # lookup cannot leave already-created coroutines un-awaited.
        factories = [registry[name] for name in to_resolve]
        awaitables = [factory(self) for factory in factories]

        # Every registry entry is expected to be an ``async def`` wrapper.
        assert all(asyncio.iscoroutine(p) for p in awaitables)
        awaitable_results = await asyncio.gather(*awaitables)

        results = dict(kwargs)
        # Key by the requested name, not the coroutine's __name__: the two
        # differ when one function is registered under more than one name.
        results.update(zip(to_resolve, awaitable_results))

        return await method(self, **results)

    return inner

Example usage:

class Foo(metaclass=AsyncDI):
    """Minimal example: ``async_blah`` depends on ``other`` by parameter name."""

    async def other(self):
        """Return the value that gets injected as ``other``."""
        value = 5
        return value

    async def async_blah(self, other):
        """Return one more than the injected ``other``."""
        total = 1 + other
        return total

# Top-level ``await`` only works in an async-aware REPL (e.g. ``python -m asyncio``
# or IPython); in a plain script use ``asyncio.run(f.async_blah())`` instead.
f = Foo()
await f.async_blah()
# Outputs 6: other() returns 5, which is injected into async_blah as 1 + 5
@simonw
Copy link
Author

simonw commented Nov 16, 2021

New version, which stores the registry on each class as `_registry` and adds an `AsyncBase.resolve()` method:

import asyncio
import inspect
from functools import wraps

class AsyncMeta(type):
    """Metaclass that registers ``async def`` methods for dependency injection.

    Every coroutine function in the class body, except the one bound to the
    attribute name ``resolve``, is wrapped with ``make_method``. It is stored
    in the new class's ``_registry`` dict under its attribute name.

    The registry starts as a copy of the base classes' registries. Methods
    inherited from a parent therefore remain injectable, and a subclass can
    override a dependency by redefining a method of the same name.
    """

    def __new__(cls, name, bases, attrs):
        _registry = {}
        # Wrapped methods look dependencies up via ``self._registry``, which
        # resolves to the most-derived class; seed it with inherited entries.
        # Update from the last base to the first so the first listed base
        # takes precedence, as with attribute lookup.
        for base in reversed(bases):
            _registry.update(getattr(base, "_registry", {}))
        new_attrs = {"_registry": _registry}
        for key, value in attrs.items():
            # ``resolve`` is the resolver that wrapped methods call, so it must
            # stay unwrapped; check the attribute name, not ``__name__``.
            if inspect.iscoroutinefunction(value) and key != "resolve":
                new_attrs[key] = make_method(value)
                _registry[key] = new_attrs[key]
            else:
                new_attrs[key] = value
        return super().__new__(cls, name, bases, new_attrs)


def make_method(method):
    """Wrap coroutine function *method* so its dependencies are injected.

    When the wrapper is called, every parameter of *method* (other than
    ``self``) not supplied as a keyword argument is resolved by
    ``self.resolve``, using the names registered in ``self._registry``. The
    results are passed to *method* as keyword arguments.

    A parameter with a default value is only injected when
    ``self._registry`` has an entry for it; otherwise its default applies.
    ``*args``, ``**kwargs`` and positional-only parameters are never injected.

    Args:
        method: the ``async def`` function to wrap.

    Returns:
        An ``async`` wrapper with signature ``(self, **kwargs)``.

    Raises:
        AssertionError: if a parameter without a default has no registry entry.
    """
    # The signature never changes, so inspect it once rather than per call.
    parameters = inspect.signature(method).parameters
    injectable = [
        name
        for name, param in parameters.items()
        if name != "self"
        and param.kind
        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    ]
    optional = {
        name
        for name in injectable
        if parameters[name].default is not inspect.Parameter.empty
    }

    @wraps(method)
    async def inner(self, **kwargs):
        registry = self._registry
        # Any parameters not provided by kwargs are resolved from registry;
        # defaulted ones only if the registry can actually supply them.
        to_resolve = [
            p
            for p in injectable
            if p not in kwargs and (p in registry or p not in optional)
        ]
        missing = [p for p in to_resolve if p not in registry]
        assert (
            not missing
        ), "The following DI parameters could not be found in the registry: {}".format(
            missing
        )
        results = dict(kwargs)
        results.update(await self.resolve(to_resolve))
        return await method(self, **results)

    return inner


class AsyncBase(metaclass=AsyncMeta):
    """Base class whose subclasses get dependency-injected ``async`` methods.

    ``AsyncMeta`` wraps and registers each subclass's coroutine methods.
    ``resolve`` is the unwrapped hook those wrapped methods call to compute
    the arguments the caller did not supply.
    """

    async def resolve(self, names):
        """Run the registered methods for *names* concurrently.

        Args:
            names: iterable of registry keys to resolve.

        Returns:
            dict mapping each name in *names* to its method's awaited result.

        Raises:
            KeyError: if a name has no entry in ``self._registry``.
        """
        # *names* is used twice below (lookup and result keys); materialise it
        # so a one-shot iterator is not exhausted by the first pass.
        names = list(names)
        # Look every method up before creating any coroutine, so an unknown
        # name cannot leave already-created coroutines never awaited.
        factories = [self._registry[name] for name in names]
        awaitables = [factory(self) for factory in factories]
        # Assert that all missing params are awaitable
        assert all(asyncio.iscoroutine(p) for p in awaitables)
        awaitable_results = await asyncio.gather(*awaitables)
        # Key by the requested name rather than the coroutine's __name__; they
        # differ when one function is registered under more than one name.
        return dict(zip(names, awaitable_results))

Use like this:

class Foo(AsyncBase):
    """Example dependency chain: ``other`` needs ``boff`` and ``graa``, and
    ``graa`` needs ``boff`` too.

    Resolved values are not cached between dependencies, so ``boff`` runs
    once for ``other`` and once more for ``graa``. That is why "boff" is
    printed twice when ``other`` is awaited.
    """

    async def graa(self, boff):
        """Print a marker and return 5; ``boff`` is injected but unused."""
        print("graa")
        return 5

    async def boff(self):
        """Print a marker and return 8; has no dependencies."""
        print("boff")
        return 8

    async def other(self, boff, graa):
        """Print a marker and return ``5 + boff + graa``."""
        print("other")
        total = 5 + boff + graa
        return total
>>> await Foo().other()
boff
boff
graa
other
18
>>> await Foo().resolve(["boff", "graa"])
boff
boff
graa
{'boff': 8, 'graa': 5}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment