Created June 10, 2024 07:58
Save rmorshea/5254b4aa9d0f4749bb5d8dd9806b672b to your computer and use it in GitHub Desktop.
A simple, lightweight dependency injection utility for Python.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import inspect | |
import sys | |
from contextlib import asynccontextmanager, contextmanager | |
from contextvars import ContextVar | |
from functools import wraps | |
from inspect import isasyncgenfunction, iscoroutinefunction, isfunction, isgeneratorfunction | |
from typing import ( | |
Annotated, | |
Any, | |
AsyncContextManager, | |
AsyncIterator, | |
Awaitable, | |
Callable, | |
ContextManager, | |
Iterator, | |
Mapping, | |
ParamSpec, | |
Sequence, | |
TypeAlias, | |
TypeVar, | |
cast, | |
get_args, | |
get_origin, | |
overload, | |
) | |
P = ParamSpec("P") | |
R = TypeVar("R") | |
SyncProvider: TypeAlias = Callable[[], ContextManager[R]] | |
AsyncProvider: TypeAlias = Callable[[], AsyncContextManager[R]] | |
Provider: TypeAlias = SyncProvider[R] | AsyncProvider[R] | |
@overload
def inject() -> Any: ...


@overload
def inject(func: Callable[P, R]) -> Callable[P, R]: ...


def inject(func: Callable[P, R] | None = None) -> Callable[P, R] | Any:
    """Mark a parameter for injection, or wrap a function so it receives injections.

    With no argument, return the sentinel used as a parameter default, e.g.
    ``conn: Annotated[Conn, conn_var] = inject()``.

    With a function, return a wrapper that fills each such parameter (unless
    the caller passes it explicitly by keyword) from the provider registered
    for its context var.
    """
    if func is not None:
        return _make_injection_wrapper(func, _get_context_vars_from_callable(func))
    return _INJECTION
def provide(annotation: type[R]) -> Callable[[Provider[R]], Provider[R]]:
    """Return a decorator registering a provider for the context var in *annotation*.

    *annotation* must be ``Annotated[T, ..., some_context_var, ...]``. The
    decorated callable is itself passed through :func:`inject`, so it may have
    injected parameters of its own. It may be:

    * a context-manager class (sync or async),
    * an async generator function or a generator function (wrapped with
      ``asynccontextmanager`` / ``contextmanager``),
    * a coroutine function or a plain function (its return value is the
      provided value; there is no cleanup step).

    The decorator returns the injected provider.

    Raises:
        TypeError: if *annotation* carries no context var, or a class provider
            is not a (sync or async) context manager.
    """
    var = _get_context_var_from_annotation(annotation)
    if var is None:
        raise TypeError(f"Expected {annotation!r} to be annotated with a context var")

    def decorator(provider: Provider[R]) -> Provider[R]:
        provider = cast("Provider[R]", inject(provider))
        target = _get_wrapped_function(provider)
        dependencies = tuple(_get_context_vars_from_callable(target).values())

        # Pick a zero-argument context-manager factory and whether it is async.
        make_context: Callable[[], Any]
        if isinstance(target, type):
            if issubclass(target, ContextManager):
                make_context, is_async = provider, False
            elif issubclass(target, AsyncContextManager):
                make_context, is_async = provider, True
            else:
                raise TypeError(f"Unsupported provider type: {provider}")
        elif isasyncgenfunction(target):
            make_context = asynccontextmanager(cast("Callable[[], AsyncIterator[R]]", provider))
            is_async = True
        elif iscoroutinefunction(target):
            make_context, is_async = (lambda: _AsyncFunctionManager(provider)), True
        elif isgeneratorfunction(target):
            make_context = contextmanager(cast("Callable[[], Iterator[R]]", provider))
            is_async = False
        else:
            make_context, is_async = (lambda: _SyncFunctionManager(provider)), False

        uniform_context_cls = _AsyncUniformContext if is_async else _SyncUniformContext
        _PROVIDERS_BY_VAR[var] = lambda: uniform_context_cls(var, make_context, dependencies)
        return provider

    return decorator
def _make_injection_wrapper(
    func: Callable[P, R],
    dependencies: Mapping[str, ContextVar],
) -> Callable[P, R]:
    """Wrap *func* so its injected parameters are filled in on every call.

    For each ``(name, var)`` in *dependencies* that the caller did not pass
    explicitly as a keyword argument, the context produced by the provider
    registered for ``var`` is entered and its value is passed as ``name``.
    Those contexts stay open for the duration of the call (for generators,
    until the generator finishes or is closed). They are then exited in
    reverse order of entry, like nested ``with`` statements.

    Supported: async generator functions, generator functions, coroutine
    functions, plain functions, bound methods and classes (classes and
    methods that are not coroutine/generator functions are called
    synchronously).

    Raises:
        TypeError: if *func* is none of the supported kinds.
    """
    # Local import so the module-level import block stays unchanged.
    from contextlib import AsyncExitStack, ExitStack

    def enter_sync(stack: ExitStack, kwargs: dict[str, Any]) -> None:
        # Fill missing injected kwargs; the stack unwinds them LIFO on exit.
        for name, var in dependencies.items():
            if name not in kwargs:
                kwargs[name] = stack.enter_context(_PROVIDERS_BY_VAR[var]())

    async def enter_async(stack: AsyncExitStack, kwargs: dict[str, Any]) -> None:
        for name, var in dependencies.items():
            if name not in kwargs:
                kwargs[name] = await stack.enter_async_context(_PROVIDERS_BY_VAR[var]())

    wrapper: Callable[..., Any]
    if isasyncgenfunction(func):

        async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
            async with AsyncExitStack() as stack:
                await enter_async(stack, kwargs)
                async for value in func(*args, **kwargs):
                    yield value

        wrapper = async_gen_wrapper
    elif isgeneratorfunction(func):

        def sync_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
            with ExitStack() as stack:
                enter_sync(stack, kwargs)
                # Propagate the wrapped generator's return value (StopIteration.value).
                return (yield from func(*args, **kwargs))

        wrapper = sync_gen_wrapper
    elif iscoroutinefunction(func):

        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            async with AsyncExitStack() as stack:
                await enter_async(stack, kwargs)
                return await func(*args, **kwargs)

        wrapper = async_wrapper
    elif isfunction(func) or inspect.ismethod(func) or isinstance(func, type):
        # Classes are accepted so that context-manager classes can be providers.

        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            with ExitStack() as stack:
                enter_sync(stack, kwargs)
                return func(*args, **kwargs)

        wrapper = sync_wrapper
    else:
        raise TypeError(f"Unsupported function type: {func}")

    # For classes, don't copy the whole class namespace into wrapper.__dict__.
    updated = () if isinstance(func, type) else ("__dict__",)
    return cast("Callable[P, R]", wraps(func, updated=updated)(wrapper))
def _get_wrapped_function(func: Callable[P, R]) -> Callable[P, R]:
    """Return the innermost callable behind a chain of ``__wrapped__`` attributes.

    Delegates to :func:`inspect.unwrap`, which follows the same convention as
    ``functools.wraps``. It raises ``ValueError`` when the chain contains a
    cycle, instead of looping forever.
    """
    return inspect.unwrap(func)
def _get_context_vars_from_callable(func: Callable[..., Any]) -> Mapping[str, ContextVar]:
    """Map the names of *func*'s injected parameters to their context vars.

    A parameter is injected when its default is the ``inject()`` sentinel.
    Its annotation must be ``Annotated[T, ..., some_context_var, ...]``.
    String annotations (e.g. under ``from __future__ import annotations``)
    are evaluated in the global namespace of the underlying, unwrapped
    callable.

    Raises:
        TypeError: if an injected parameter's annotation is not ``Annotated``
            or its metadata contains no ``ContextVar``.
    """
    context_vars: dict[str, ContextVar] = {}
    globalns: dict[str, Any] | None = None
    for param in inspect.signature(func).parameters.values():
        if param.default is not _INJECTION:
            continue
        anno = param.annotation
        if isinstance(anno, str):
            if globalns is None:
                globalns = _get_annotation_globals(func)
            # Annotation strings come from source code, not external input.
            anno = eval(anno, globalns)
        if get_origin(anno) is not Annotated:
            raise TypeError(f"Expected {param.name!r} to be annotated with a context var")
        var = _get_context_var_from_annotation(anno)
        if var is None:
            # Fail loudly instead of letting the sentinel through as the argument.
            raise TypeError(f"Expected {param.name!r} to be annotated with a context var")
        context_vars[param.name] = var
    return context_vars


def _get_annotation_globals(func: Callable[..., Any]) -> dict[str, Any]:
    """Return the namespace in which *func*'s string annotations should be evaluated."""
    # inspect.signature follows __wrapped__, so resolve names where the
    # annotations were written, not where a wrapping decorator is defined.
    target = inspect.unwrap(func)
    globalns = getattr(target, "__globals__", None)
    if globalns is None:
        # Classes (and other non-function callables) have no __globals__;
        # fall back to the namespace of their defining module.
        module = sys.modules.get(getattr(target, "__module__", None) or "")
        globalns = vars(module) if module is not None else {}
    return globalns
def _get_context_var_from_annotation(anno: Any) -> ContextVar | None:
    """Return the first ``ContextVar`` in the metadata of an ``Annotated`` type.

    Returns ``None`` if the metadata contains no context var.

    Raises:
        TypeError: if *anno* is not an ``Annotated[...]`` form.
    """
    if get_origin(anno) is Annotated:
        metadata = get_args(anno)[1:]
        return next((item for item in metadata if isinstance(item, ContextVar)), None)
    raise TypeError(f"Expected {anno!r} to be annotated with a context var")
class _UniformContext(ContextManager[R], AsyncContextManager[R]):
    """Common base for provider contexts usable with both ``with`` and ``async with``.

    The injection wrappers only rely on this combined interface; the sync and
    async subclasses each implement all four protocol methods.
    """
class _SyncUniformContext(_UniformContext[R]):
    """Provide the value of ``var`` from a synchronous context manager.

    If ``var`` already has a value in the current context, entering returns
    it and exiting is a no-op. Otherwise:

    * the provider's own dependencies are entered first;
    * then ``make_context()`` is entered, and its value is stored in ``var``
      until this context exits;
    * on exit, ``var`` is reset, the provider context is exited, and then
      the dependencies are exited.

    The async protocol methods delegate to the sync ones. Each instance is
    meant for one enter/exit cycle.
    """

    def __init__(
        self,
        var: ContextVar[R],
        make_context: Callable[[], ContextManager[R]],
        dependencies: Sequence[ContextVar],
    ):
        self.var = var
        self.make_context = make_context
        # Set only when this instance assigned `var` itself.
        self.token = None
        self.dependencies = dependencies
        self.dependency_contexts: list[_UniformContext] = []

    def __enter__(self) -> R:
        try:
            return self.var.get()
        except LookupError:
            pass  # No value yet: this instance has to acquire one.
        try:
            for var in self.dependencies:
                dependency_context = _PROVIDERS_BY_VAR[var]()
                dependency_context.__enter__()
                self.dependency_contexts.append(dependency_context)
            self.context = context = self.make_context()
            value = context.__enter__()
        except BaseException:
            # __exit__ is not called when __enter__ fails, so release the
            # dependencies acquired so far before propagating the error.
            contexts, self.dependency_contexts = self.dependency_contexts, []
            _exhaust_exits(contexts)
            raise
        self.token = self.var.set(value)
        return value

    def __exit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        if self.token is None:
            return  # The value was borrowed from an outer scope; nothing to release.
        token, self.token = self.token, None
        try:
            self.var.reset(token)
        finally:
            try:
                self.context.__exit__(etype, eval, atrace)
            finally:
                contexts, self.dependency_contexts = self.dependency_contexts, []
                _exhaust_exits(contexts)

    async def __aenter__(self) -> R:
        return self.__enter__()

    async def __aexit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        return self.__exit__(etype, eval, atrace)
class _AsyncUniformContext(_UniformContext[R]):
    """Provide the value of ``var`` from an asynchronous context manager.

    If ``var`` already has a value in the current context, entering (sync or
    async) returns it and exiting is a no-op. Otherwise a value can only be
    acquired through ``async with``:

    * the provider's dependencies are entered first;
    * then ``make_context()`` is entered, and its value is stored in ``var``
      until this context exits.

    Acquiring a value through a sync ``with`` raises ``RuntimeError``. Each
    instance is meant for one enter/exit cycle.
    """

    def __init__(
        self,
        var: ContextVar[R],
        make_context: Callable[[], AsyncContextManager[R]],
        dependencies: Sequence[ContextVar],
    ):
        self.var = var
        self.make_context = make_context
        # Set only when this instance assigned `var` itself.
        self.token = None
        self.dependencies = dependencies
        self.dependency_contexts: list[_UniformContext[Any]] = []

    def __enter__(self) -> R:
        try:
            return self.var.get()
        except LookupError:
            raise RuntimeError("Cannot use an async context manager in a sync context")

    def __exit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        # A successful sync __enter__ only ever borrows an already-set value,
        # so there is nothing to release. Raising unconditionally here made
        # every sync function that injects an async-provided var fail on return.
        if self.token is not None:
            raise RuntimeError("Cannot use an async context manager in a sync context")

    async def __aenter__(self) -> R:
        try:
            return self.var.get()
        except LookupError:
            pass  # No value yet: this instance has to acquire one.
        try:
            for var in self.dependencies:
                dependency_context = _PROVIDERS_BY_VAR[var]()
                await dependency_context.__aenter__()
                self.dependency_contexts.append(dependency_context)
            self.context = context = self.make_context()
            value = await context.__aenter__()
        except BaseException:
            # __aexit__ is not called when __aenter__ fails, so release the
            # dependencies acquired so far before propagating the error.
            contexts, self.dependency_contexts = self.dependency_contexts, []
            await _async_exhaust_exits(contexts)
            raise
        self.token = self.var.set(value)
        return value

    async def __aexit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        if self.token is None:
            return  # The value was borrowed from an outer scope; nothing to release.
        token, self.token = self.token, None
        try:
            self.var.reset(token)
        finally:
            try:
                await self.context.__aexit__(etype, eval, atrace)
            finally:
                contexts, self.dependency_contexts = self.dependency_contexts, []
                await _async_exhaust_exits(contexts)
class _AsyncFunctionManager(AsyncContextManager[R]):
    """Adapt a zero-argument coroutine function into an async context manager.

    Entering awaits the function and returns its result; exiting does nothing.
    """

    def __init__(self, func: Callable[[], Awaitable[R]]) -> None:
        self.func = func

    async def __aexit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        return None

    async def __aenter__(self) -> R:
        result = await self.func()
        return result
class _SyncFunctionManager(ContextManager[R]):
    """Adapt a zero-argument function into a context manager.

    Entering calls the function and returns its result; exiting does nothing.
    """

    def __init__(self, func: Callable[[], R]) -> None:
        self.func = func

    def __exit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        return None

    def __enter__(self) -> R:
        result = self.func()
        return result
def _exhaust_exits(ctxts: Sequence[ContextManager]) -> None:
    """Exit every context manager in *ctxts*, most recently entered first.

    Callers in this module append contexts in the order they enter them, so
    the sequence is unwound from the end (LIFO), mirroring nested ``with``
    statements. Each ``__exit__`` receives the exception currently being
    handled (``sys.exc_info()``), if any.

    If an ``__exit__`` raises, the remaining contexts are still exited before
    that error propagates. ``__exit__`` return values are ignored, so no
    exception is suppressed here.
    """
    if not ctxts:
        return
    *rest, last = ctxts
    try:
        last.__exit__(*sys.exc_info())
    except Exception:
        _exhaust_exits(rest)
        raise
    else:
        _exhaust_exits(rest)
async def _async_exhaust_exits(ctxts: Sequence[AsyncContextManager[Any]]) -> None:
    """Await ``__aexit__`` on every context in *ctxts*, most recently entered first.

    Callers in this module append contexts in the order they enter them, so
    the sequence is unwound from the end (LIFO), mirroring nested
    ``async with`` statements. Each ``__aexit__`` receives the exception
    currently being handled (``sys.exc_info()``), if any.

    If an ``__aexit__`` raises, the remaining contexts are still exited
    before that error propagates. Return values are ignored, so no exception
    is suppressed here.
    """
    if not ctxts:
        return
    *rest, last = ctxts
    try:
        await last.__aexit__(*sys.exc_info())
    except Exception:
        await _async_exhaust_exits(rest)
        raise
    else:
        await _async_exhaust_exits(rest)
_UniformProvider: TypeAlias = "Callable[[], _UniformContext[R]]" | |
_INJECTION = (type("INJECTION", (), {"__repr__": lambda _: "INJECTION"}))() | |
_PROVIDERS_BY_VAR: dict[ContextVar, _UniformProvider] = {} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.