Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save simonw/496b24fdad44f6f8b7237fe394a0ced7 to your computer and use it in GitHub Desktop.
Save simonw/496b24fdad44f6f8b7237fe394a0ced7 to your computer and use it in GitHub Desktop.
Dependency injection for asyncio concurrency
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "voluntary-album",
"metadata": {},
"source": [
"# async-di: Dependency injection for asyncio concurrency\n",
"\n",
"Exploring the concept of using pytest-style dependency injection to schedule asyncio concurrent tasks.\n",
"\n",
"Key idea: if a function takes named arguments, those are matched against other functions which are then awaited in parallel before the results are passed to that original function.\n",
"\n",
"I'll try it with a `@di` decorator first."
]
},
{
"cell_type": "code",
"execution_count": 112,
"id": "concerned-crisis",
"metadata": {},
"outputs": [],
"source": [
"import inspect\n",
"import asyncio\n",
"from functools import wraps\n",
"\n",
"\n",
"registry = {}\n",
"\n",
"\n",
"def di(fn):\n",
" fn.run = make_runner(fn)\n",
" registry[fn.__name__] = fn\n",
" return fn"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "excess-acrobat",
"metadata": {},
"outputs": [],
"source": [
"def make_runner(fn):\n",
" @wraps(fn)\n",
" async def inner(**kwargs):\n",
" parameters = inspect.signature(fn).parameters.keys()\n",
" # Any parameters not provided by kwargs are resolved from registry\n",
" to_resolve = [p for p in parameters if p not in kwargs]\n",
" missing = [p for p in to_resolve if p not in registry]\n",
" assert not missing, 'The following DI parameters could not be found in the registry: {}'.format(missing)\n",
" \n",
" awaitables = [registry[name].run() for name in to_resolve]\n",
" \n",
" # Assert that all missing params are awaitable\n",
" assert all(asyncio.iscoroutine(p) for p in awaitables)\n",
" results = {}\n",
" results.update(kwargs)\n",
" awaitable_results = await asyncio.gather(*awaitables)\n",
" results.update((p[0].__name__, p[1]) for p in zip(awaitables, awaitable_results))\n",
" \n",
" return await fn(**results)\n",
" \n",
" return inner"
]
},
{
"cell_type": "code",
"execution_count": 114,
"id": "lovely-generic",
"metadata": {},
"outputs": [],
"source": [
"@di\n",
"async def wait_two_seconds():\n",
" print(\"about to sleep 2\")\n",
" await asyncio.sleep(2)\n",
" print(\"Done 2\")\n",
" return 2"
]
},
{
"cell_type": "code",
"execution_count": 115,
"id": "separate-julian",
"metadata": {},
"outputs": [],
"source": [
"@di\n",
"async def wait_three_seconds():\n",
" print(\"about to sleep 3\")\n",
" await asyncio.sleep(3)\n",
" print(\"Done 3\")\n",
" return 3"
]
},
{
"cell_type": "code",
"execution_count": 116,
"id": "after-happening",
"metadata": {},
"outputs": [],
"source": [
"@di\n",
"async def wait_two_point_five_seconds():\n",
" print(\"about to sleep 2.5\")\n",
" await asyncio.sleep(2.5)\n",
" print(\"Done 2.5\")\n",
" return 2.5"
]
},
{
"cell_type": "code",
"execution_count": 117,
"id": "separate-national",
"metadata": {},
"outputs": [],
"source": [
"@di\n",
"async def do_both(wait_two_seconds, wait_three_seconds, wait_two_point_five_seconds, num):\n",
" print(wait_two_seconds + wait_three_seconds + num)"
]
},
{
"cell_type": "code",
"execution_count": 119,
"id": "patent-detector",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"about to sleep 2\n",
"about to sleep 3\n",
"about to sleep 2.5\n",
"Done 2\n",
"Done 2.5\n",
"Done 3\n",
"9\n"
]
}
],
"source": [
"await do_both.run(num=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "passive-brooks",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

Here's the class-based version, which uses a metaclass to wrap every `async def` method:

class AsyncDI(type):
    """Metaclass that wires every ``async def`` in a class body into DI.

    Each coroutine function is wrapped by ``make_method`` and recorded in a
    registry private to this one class, keyed by its attribute name, so the
    wrapped methods can request one another by parameter name. All other
    class attributes pass through unchanged.
    """

    def __new__(cls, name, bases, attrs):
        registry = {}

        def convert(attr_name, attr_value):
            # Non-coroutine attributes take no part in injection.
            if not inspect.iscoroutinefunction(attr_value):
                return attr_value
            wrapped = make_method(attr_value, registry)
            registry[attr_name] = wrapped
            return wrapped

        rewritten = {
            attr_name: convert(attr_name, attr_value)
            for attr_name, attr_value in attrs.items()
        }
        return super().__new__(cls, name, bases, rewritten)


def make_method(method, registry):
    """Wrap ``method`` so its missing arguments are injected from ``registry``.

    ``registry`` maps parameter names to other wrapped methods. When the
    returned coroutine function is called as ``inner(self, **kwargs)``, every
    parameter of ``method`` after the first (the instance) that is not in
    ``kwargs`` is looked up in ``registry``. Those entries are called with
    ``self`` and awaited concurrently, and their results are passed to
    ``method`` under the requested parameter names.

    Parameters that have a default value and no registry entry keep their
    default; ``*args``/``**kwargs`` catch-all parameters are never injected.

    Raises:
        TypeError: if a required parameter is neither supplied by the caller
            nor present in ``registry``.
    """
    # Local imports: the snippet this function lives in imports neither.
    import asyncio
    from functools import wraps

    # The signature is fixed, so inspect it once instead of on every call.
    # The first parameter receives the instance, whatever it is named.
    candidates = [
        param
        for param in list(inspect.signature(method).parameters.values())[1:]
        if param.kind
        not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]

    @wraps(method)
    async def inner(self, **kwargs):
        # Any parameters not provided by kwargs are resolved from registry
        to_resolve = []
        missing = []
        for param in candidates:
            if param.name in kwargs:
                continue  # caller-supplied values always win
            if param.name in registry:
                to_resolve.append(param.name)
            elif param.default is inspect.Parameter.empty:
                missing.append(param.name)
        if missing:
            # Raise rather than assert so the check survives ``python -O``.
            raise TypeError(
                "The following DI parameters could not be found in the "
                "registry: {}".format(missing)
            )

        # Run all dependencies concurrently; gather keeps input order.
        resolved = await asyncio.gather(
            *(registry[dep_name](self) for dep_name in to_resolve)
        )
        results = dict(kwargs)
        # Key results by the requested name. A coroutine's __name__ is the
        # wrapped function's own name, which need not match its registry key
        # (e.g. ``alias = some_method`` in a class body).
        results.update(zip(to_resolve, resolved))
        return await method(self, **results)

    return inner

Example usage:

class Foo(metaclass=AsyncDI):
    async def other(self):
        return 5
    async def async_blah(self, other):
        return 1 + other

f = Foo()
await f.async_blah()
# Outputs 6
@simonw
Copy link
Author

simonw commented Nov 16, 2021

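New version, which stores the registry on each class as `_registry` and resolves dependencies through a shared `AsyncBase.resolve()` method: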

import inspect

class AsyncMeta(type):
    """Metaclass that turns ``async def`` methods into injectable methods.

    Every coroutine function in the class body (other than one named
    ``resolve``) is wrapped with ``make_method`` and recorded in the class
    attribute ``_registry`` under its attribute name.

    Registrations from base classes are inherited, so a subclass method can
    depend on a method defined on a parent class. This class's own methods
    take precedence, then nearer ancestors over farther ones.
    """

    def __new__(cls, name, bases, attrs):
        # Decorate any items that are 'async def' methods
        _registry = {}
        new_attrs = {"_registry": _registry}
        for key, value in attrs.items():
            # ``resolve`` is the resolution entry point and must stay unwrapped.
            if inspect.iscoroutinefunction(value) and not value.__name__ == "resolve":
                new_attrs[key] = _registry[key] = make_method(value)
            else:
                new_attrs[key] = value
        new_cls = super().__new__(cls, name, bases, new_attrs)
        # Inherit registrations by walking the MRO nearest-first. setdefault
        # keeps this class's own entries and lets nearer ancestors win.
        for ancestor in new_cls.__mro__[1:]:
            inherited = vars(ancestor).get("_registry")
            if isinstance(inherited, dict):
                for dep_name, dep in inherited.items():
                    _registry.setdefault(dep_name, dep)
        return new_cls


def make_method(method):
    """Wrap ``method`` so missing arguments are injected via ``self.resolve``.

    When the returned coroutine function is called as ``inner(self, **kwargs)``,
    every parameter of ``method`` after the first (the instance) that is not in
    ``kwargs`` is looked up by name in ``self._registry``. The names found are
    passed to ``self.resolve()``, and its results are handed to ``method``
    alongside ``kwargs``.

    Parameters that have a default value and no registry entry keep their
    default; ``*args``/``**kwargs`` catch-all parameters are never injected.

    Raises:
        TypeError: if a required parameter is neither supplied by the caller
            nor present in ``self._registry``.
    """
    # Local import: the surrounding snippet only imports ``inspect``.
    from functools import wraps

    # The signature is fixed, so inspect it once instead of on every call.
    # The first parameter receives the instance, whatever it is named.
    candidates = [
        param
        for param in list(inspect.signature(method).parameters.values())[1:]
        if param.kind
        not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]

    @wraps(method)
    async def inner(self, **kwargs):
        # Any parameters not provided by kwargs are resolved from registry
        registry = self._registry
        to_resolve = []
        missing = []
        for param in candidates:
            if param.name in kwargs:
                continue  # caller-supplied values always win
            if param.name in registry:
                to_resolve.append(param.name)
            elif param.default is inspect.Parameter.empty:
                missing.append(param.name)
        if missing:
            # Raise rather than assert so the check survives ``python -O``.
            raise TypeError(
                "The following DI parameters could not be found in the "
                "registry: {}".format(missing)
            )
        results = dict(kwargs)
        results.update(await self.resolve(to_resolve))
        return await method(self, **results)

    return inner


class AsyncBase(metaclass=AsyncMeta):
    """Base class whose ``async def`` methods get their arguments injected.

    The ``AsyncMeta`` metaclass wraps each coroutine method (except
    ``resolve``) and records it in the class-level ``_registry``. ``resolve``
    runs registered methods by name.
    """

    async def resolve(self, names):
        """Run the registered methods listed in ``names`` concurrently.

        Returns a dict mapping each requested name to its result. Raises
        ``KeyError`` for a name that is not in ``self._registry``.

        Nothing is cached: each call re-runs the requested methods, and a
        dependency shared by several of them runs once per dependent.
        """
        import asyncio  # the surrounding snippet only imports ``inspect``

        # Materialize once: ``names`` is iterated twice below and may be a
        # one-shot iterator.
        names = list(names)
        # Calling a registry entry only creates its coroutine; gather() then
        # runs them concurrently and returns results in input order.
        awaitables = [self._registry[name](self) for name in names]
        awaitable_results = await asyncio.gather(*awaitables)
        # Key by the requested name rather than the coroutine's __name__,
        # which need not match the registry key for aliased methods.
        return dict(zip(names, awaitable_results))

Use like this:

class Foo(AsyncBase):
    """Demo: ``other`` needs ``boff`` and ``graa``, and ``graa`` needs ``boff``."""

    async def graa(self, boff):
        """Depend on ``boff`` (unused) and return a constant."""
        print("graa")
        return 5

    async def boff(self):
        """Leaf dependency with no parameters of its own."""
        print("boff")
        return 8

    async def other(self, boff, graa):
        """Combine both injected values with a constant."""
        print("other")
        total = 5 + boff + graa
        return total
>>> await Foo().other()
boff
boff
graa
other
18
>>> await Foo().resolve(["boff", "graa"])
boff
boff
graa
{'boff': 8, 'graa': 5}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment