Skip to content

Instantly share code, notes, and snippets.

@rmorshea
Created June 10, 2024 07:58
Show Gist options
  • Save rmorshea/5254b4aa9d0f4749bb5d8dd9806b672b to your computer and use it in GitHub Desktop.
Save rmorshea/5254b4aa9d0f4749bb5d8dd9806b672b to your computer and use it in GitHub Desktop.
A simple, lightweight dependency injection utility for Python.
from __future__ import annotations
import inspect
import sys
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from functools import wraps
from inspect import isasyncgenfunction, iscoroutinefunction, isfunction, isgeneratorfunction
from typing import (
Annotated,
Any,
AsyncContextManager,
AsyncIterator,
Awaitable,
Callable,
ContextManager,
Iterator,
Mapping,
ParamSpec,
Sequence,
TypeAlias,
TypeVar,
cast,
get_args,
get_origin,
overload,
)
P = ParamSpec("P")
R = TypeVar("R")
SyncProvider: TypeAlias = Callable[[], ContextManager[R]]
AsyncProvider: TypeAlias = Callable[[], AsyncContextManager[R]]
Provider: TypeAlias = SyncProvider[R] | AsyncProvider[R]
@overload
def inject() -> Any: ...


@overload
def inject(func: Callable[P, R]) -> Callable[P, R]: ...


def inject(func: Callable[P, R] | None = None) -> Callable[P, R] | Any:
    """Mark a parameter as injectable, or wrap a function that has such parameters.

    With no argument, return the module's injection sentinel. It is meant to be
    used as a parameter default, e.g. ``x: Annotated[T, some_var] = inject()``.

    With a function, collect every parameter whose default is that sentinel,
    together with the context var named in its ``Annotated`` metadata. Then
    return the wrapper built by ``_make_injection_wrapper``, which fills those
    parameters in from the registered providers.
    """
    if func is not None:
        injected_params = _get_context_vars_from_callable(func)
        return _make_injection_wrapper(func, injected_params)
    return _INJECTION
def provide(annotation: type[R]) -> Callable[[Provider[R]], Provider[R]]:
    """Register a provider for the context var carried by ``annotation``.

    ``annotation`` must be ``Annotated[T, var]`` with a ``ContextVar`` among its
    metadata. The returned decorator registers the decorated provider for that
    var. The provider may be:

    * a class implementing the sync or async context-manager protocol;
    * a sync or async generator function (adapted with ``contextmanager`` /
      ``asynccontextmanager``);
    * a coroutine function or plain function, whose result is the value.

    Function providers are wrapped with ``inject`` so their own injectable
    parameters are filled in, and the wrapped function is returned. Class
    providers are returned unchanged. Their injectable ``__init__`` parameters
    are read from the corresponding context vars, which the registered provider
    context enters before constructing the instance.

    Raises:
        TypeError: if ``annotation`` is not ``Annotated`` with a ``ContextVar``,
            or a provider class is neither a sync nor an async context manager.
    """
    var = _get_context_var_from_annotation(annotation)
    if var is None:
        raise TypeError(f"Expected {annotation!r} to be annotated with a context var")

    def decorator(provider: Provider[R]) -> Provider[R]:
        wrapped_provider = _get_wrapped_function(provider)
        injected = _get_context_vars_from_callable(wrapped_provider)
        dependencies = tuple(injected.values())
        uniform_provider: Callable[[], _UniformContext[R]]

        if isinstance(wrapped_provider, type):
            # inject() only accepts functions and raises for classes, so a class
            # provider must not go through it. Instead, build the instance here,
            # reading each injectable constructor argument from its var. The
            # uniform context enters ``dependencies`` before calling this
            # factory, so those vars are set by then.
            provider_cls = cast(Callable[..., Any], provider)

            def construct() -> Any:
                return provider_cls(**{name: dep.get() for name, dep in injected.items()})

            if issubclass(wrapped_provider, ContextManager):
                uniform_provider = lambda: _SyncUniformContext(var, construct, dependencies)
            elif issubclass(wrapped_provider, AsyncContextManager):
                uniform_provider = lambda: _AsyncUniformContext(var, construct, dependencies)
            else:
                raise TypeError(f"Unsupported provider type: {provider}")
            _PROVIDERS_BY_VAR[var] = uniform_provider
            return provider

        # Function providers get their own injectable parameters filled in.
        provider = cast(Provider[R], inject(provider))
        if isasyncgenfunction(wrapped_provider):
            async_manager = asynccontextmanager(cast(Callable[[], AsyncIterator[R]], provider))
            uniform_provider = lambda: _AsyncUniformContext(var, async_manager, dependencies)
        elif iscoroutinefunction(wrapped_provider):
            async_func = cast(Callable[[], Awaitable[R]], provider)
            uniform_provider = lambda: _AsyncUniformContext(
                var, lambda: _AsyncFunctionManager(async_func), dependencies
            )
        elif isgeneratorfunction(wrapped_provider):
            sync_manager = contextmanager(cast(Callable[[], Iterator[R]], provider))
            uniform_provider = lambda: _SyncUniformContext(var, sync_manager, dependencies)
        else:
            sync_func = cast(Callable[[], R], provider)
            uniform_provider = lambda: _SyncUniformContext(
                var, lambda: _SyncFunctionManager(sync_func), dependencies
            )
        _PROVIDERS_BY_VAR[var] = uniform_provider
        return provider

    return decorator
def _make_injection_wrapper(
    func: Callable[P, R],
    dependencies: Mapping[str, ContextVar],
) -> Callable[P, R]:
    """Wrap ``func`` so its injectable parameters are filled from registered providers.

    ``dependencies`` maps parameter names to the context var whose provider in
    ``_PROVIDERS_BY_VAR`` supplies that argument. On each call, every dependency
    the caller did not pass (by keyword or by position) is produced by entering
    a fresh provider context and is passed to ``func`` as a keyword argument.
    All contexts entered for the call are exited when the call finishes. For
    generator wrappers, that happens when iteration ends or the generator is
    closed.

    Raises:
        TypeError: if ``func`` is not a plain, coroutine, generator, or async
            generator function.
    """
    if not (
        isasyncgenfunction(func)
        or isgeneratorfunction(func)
        or iscoroutinefunction(func)
        or isfunction(func)
    ):
        raise TypeError(f"Unsupported function type: {func}")

    # Positional index of every parameter that can be passed by position.
    # An argument supplied positionally must not also be injected as a keyword,
    # which would fail with "got multiple values for argument".
    positional_index = {
        name: index
        for index, (name, param) in enumerate(inspect.signature(func).parameters.items())
        if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    }

    def is_supplied(name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> bool:
        # True when the caller passed this argument by keyword or by position.
        index = positional_index.get(name)
        return name in kwargs or (index is not None and index < len(args))

    def enter_sync(args: tuple[Any, ...], kwargs: dict[str, Any], contexts: list[_UniformContext]) -> None:
        # Enter a provider for each dependency the caller did not supply. A context
        # is recorded only after it entered successfully, so only those are exited.
        for name, var in dependencies.items():
            if is_supplied(name, args, kwargs):
                continue
            context = _PROVIDERS_BY_VAR[var]()
            kwargs[name] = context.__enter__()
            contexts.append(context)

    async def enter_async(args: tuple[Any, ...], kwargs: dict[str, Any], contexts: list[_UniformContext]) -> None:
        # Async twin of enter_sync.
        for name, var in dependencies.items():
            if is_supplied(name, args, kwargs):
                continue
            context = _PROVIDERS_BY_VAR[var]()
            kwargs[name] = await context.__aenter__()
            contexts.append(context)

    wrapper: Callable[..., Any]
    if isasyncgenfunction(func):

        async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
            contexts: list[_UniformContext] = []
            try:
                await enter_async(args, kwargs, contexts)
                agen: Any = func(*args, **kwargs)
                try:
                    async for value in agen:
                        yield value
                finally:
                    # Unlike ``yield from``, ``async for`` does not close the inner
                    # generator when this one is closed early. Close it explicitly
                    # so its cleanup runs before its dependencies are torn down.
                    await agen.aclose()
            finally:
                await _async_exhaust_exits(contexts)

        wrapper = async_gen_wrapper
    elif isgeneratorfunction(func):

        def sync_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
            contexts: list[_UniformContext] = []
            try:
                enter_sync(args, kwargs, contexts)
                yield from func(*args, **kwargs)
            finally:
                _exhaust_exits(contexts)

        wrapper = sync_gen_wrapper
    elif iscoroutinefunction(func):

        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            contexts: list[_UniformContext] = []
            try:
                await enter_async(args, kwargs, contexts)
                return await func(*args, **kwargs)
            finally:
                await _async_exhaust_exits(contexts)

        wrapper = async_wrapper
    else:
        # A plain function: the guard at the top rules out everything else.
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            contexts: list[_UniformContext] = []
            try:
                enter_sync(args, kwargs, contexts)
                return func(*args, **kwargs)
            finally:
                _exhaust_exits(contexts)

        wrapper = sync_wrapper
    return cast(Callable[P, R], wraps(func)(wrapper))
def _get_wrapped_function(func: Callable[P, R]) -> Callable[P, R]:
    """Follow ``__wrapped__`` links from ``func`` down to the innermost callable.

    The walk stops at the first object whose ``__wrapped__`` attribute is
    missing or falsy, and that object is returned.
    """
    current = func
    inner = getattr(current, "__wrapped__", None)
    while inner:
        current = inner
        inner = getattr(current, "__wrapped__", None)
    return current
def _get_context_vars_from_callable(func: Callable[..., Any]) -> Mapping[str, ContextVar]:
    """Map each injectable parameter of ``func`` to the context var that supplies it.

    A parameter is injectable when its default is the ``inject()`` sentinel.
    Its annotation must be ``Annotated[T, var]`` with a ``ContextVar`` among the
    metadata. String annotations (e.g. under ``from __future__ import
    annotations``) are evaluated first.

    Raises:
        TypeError: if an injectable parameter's annotation is not ``Annotated``
            or carries no ``ContextVar``. Otherwise ``func`` would silently be
            called with the sentinel object as that argument.
    """
    # inspect.signature follows ``__wrapped__``, so string annotations belong to
    # the unwrapped object and must be evaluated in its module's namespace. An
    # outer wrapper may live in a different module. Classes have no
    # ``__globals__``, so fall back to the namespace of their defining module.
    target = inspect.unwrap(func)
    namespace = getattr(target, "__globals__", None)
    if namespace is None:
        module = sys.modules.get(getattr(target, "__module__", None) or "")
        namespace = vars(module) if module is not None else {}

    context_vars: dict[str, ContextVar] = {}
    for param in inspect.signature(func).parameters.values():
        if param.default is not _INJECTION:
            continue
        annotation = param.annotation
        if isinstance(annotation, str):
            # Annotation text comes from the program's own source code.
            annotation = eval(annotation, namespace)
        var = _get_context_var_from_annotation(annotation) if get_origin(annotation) is Annotated else None
        if var is None:
            raise TypeError(f"Expected {param.name!r} to be annotated with a context var")
        context_vars[param.name] = var
    return context_vars
def _get_context_var_from_annotation(anno: Any) -> ContextVar | None:
    """Return the first ``ContextVar`` found in the metadata of ``Annotated`` type ``anno``.

    Returns ``None`` when ``anno`` is ``Annotated`` but none of its metadata
    items is a ``ContextVar``.

    Raises:
        TypeError: if ``anno`` is not an ``Annotated[...]`` type.
    """
    if get_origin(anno) is not Annotated:
        raise TypeError(f"Expected {anno!r} to be annotated with a context var")
    # get_args yields (base_type, *metadata); only the metadata is searched.
    metadata = get_args(anno)[1:]
    return next((item for item in metadata if isinstance(item, ContextVar)), None)
class _UniformContext(ContextManager[R], AsyncContextManager[R]):
    """Shared base type for provider contexts.

    Declares both the sync and async context-manager protocols, so sync- and
    async-backed provider contexts can be held in the same collection.
    """
class _SyncUniformContext(_UniformContext[R]):
    """Uniform context around a synchronous provider of ``var``.

    Entering returns the current value of ``var`` if it is already set, and
    then does nothing else (and exiting is a no-op). Otherwise it:

    1. enters a provider context for each dependency var;
    2. enters the context built by ``make_context``;
    3. sets ``var`` to the entered value until exit.

    The async protocol methods delegate to the sync ones.
    """

    def __init__(
        self,
        var: ContextVar[R],
        make_context: Callable[[], ContextManager[R]],
        dependencies: Sequence[ContextVar],
    ):
        self.var = var
        self.make_context = make_context
        # Set only when this instance actually set ``var``; exit keys off it.
        self.token = None
        self.dependencies = dependencies
        self.dependency_contexts: list[_UniformContext] = []

    def __enter__(self) -> R:
        try:
            return self.var.get()
        except LookupError:
            pass  # Not provided yet: enter the provider below.
        try:
            for var in self.dependencies:
                dependency_context = _PROVIDERS_BY_VAR[var]()
                dependency_context.__enter__()
                self.dependency_contexts.append(dependency_context)
            self.context = context = self.make_context()
            self.token = self.var.set(context.__enter__())
        except BaseException:
            # __exit__ never runs after a failed __enter__, so release the
            # dependencies entered so far before propagating the error.
            _exhaust_exits(self.dependency_contexts)
            raise
        return self.var.get()

    def __exit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        if self.token is None:
            return
        try:
            self.var.reset(self.token)
        finally:
            try:
                self.context.__exit__(etype, eval, atrace)
            finally:
                _exhaust_exits(self.dependency_contexts)

    async def __aenter__(self) -> R:
        return self.__enter__()

    async def __aexit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        return self.__exit__(etype, eval, atrace)
class _AsyncUniformContext(_UniformContext[R]):
    """Uniform context around an asynchronous provider of ``var``.

    ``async with`` returns the current value of ``var`` if it is already set.
    Otherwise it enters a provider context for each dependency var, then enters
    the async context built by ``make_context`` and sets ``var`` to its value
    until exit.

    Plain ``with`` works only when ``var`` is already set, in which case
    nothing is entered and exit is a no-op. Entering the async provider itself
    from sync code raises ``RuntimeError``.
    """

    def __init__(
        self,
        var: ContextVar[R],
        make_context: Callable[[], AsyncContextManager[R]],
        dependencies: Sequence[ContextVar],
    ):
        self.var = var
        self.make_context = make_context
        # Set only when this instance actually set ``var``; exit keys off it.
        self.token = None
        self.dependencies = dependencies
        self.dependency_contexts: list[_UniformContext[Any]] = []

    def __enter__(self) -> R:
        try:
            return self.var.get()
        except LookupError:
            raise RuntimeError("Cannot use an async context manager in a sync context") from None

    def __exit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        # __enter__ succeeds without entering anything when ``var`` was already
        # set; that case has nothing to tear down. Only a provider entered
        # through __aenter__ (token set) needs async teardown.
        if self.token is not None:
            raise RuntimeError("Cannot use an async context manager in a sync context")

    async def __aenter__(self) -> R:
        try:
            return self.var.get()
        except LookupError:
            pass  # Not provided yet: enter the provider below.
        try:
            for var in self.dependencies:
                dependency_context = _PROVIDERS_BY_VAR[var]()
                await dependency_context.__aenter__()
                self.dependency_contexts.append(dependency_context)
            self.context = context = self.make_context()
            self.token = self.var.set(await context.__aenter__())
        except BaseException:
            # __aexit__ never runs after a failed __aenter__, so release the
            # dependencies entered so far before propagating the error.
            await _async_exhaust_exits(self.dependency_contexts)
            raise
        return self.var.get()

    async def __aexit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        if self.token is None:
            return
        try:
            self.var.reset(self.token)
        finally:
            try:
                await self.context.__aexit__(etype, eval, atrace)
            finally:
                await _async_exhaust_exits(self.dependency_contexts)
class _AsyncFunctionManager(AsyncContextManager[R]):
    """Adapt a zero-argument coroutine function to the async context-manager protocol.

    Entering awaits the function and returns its result. Exiting does nothing,
    so there is no cleanup step.
    """

    def __init__(self, func: Callable[[], Awaitable[R]]) -> None:
        self.func = func

    async def __aenter__(self) -> R:
        result = await self.func()
        return result

    async def __aexit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        return None
class _SyncFunctionManager(ContextManager[R]):
    """Adapt a zero-argument function to the context-manager protocol.

    Entering calls the function and returns its result. Exiting does nothing,
    so there is no cleanup step.
    """

    def __init__(self, func: Callable[[], R]) -> None:
        self.func = func

    def __enter__(self) -> R:
        result = self.func()
        return result

    def __exit__(self, etype: Any, eval: Any, atrace: Any) -> None:
        return None
def _exhaust_exits(ctxts: Sequence[ContextManager]) -> None:
    """Exit every context in ``ctxts``, most recently entered first.

    Callers pass contexts in the order they entered them, so they are exited in
    reverse (LIFO), as ``contextlib.ExitStack`` does. A context is then torn
    down before anything it was built on. Each ``__exit__`` receives the
    exception currently being handled (``sys.exc_info()``).

    If an exit raises, including non-``Exception`` errors such as
    ``KeyboardInterrupt``, the remaining contexts are still exited before the
    error propagates.
    """
    if not ctxts:
        return
    *rest, last = ctxts
    try:
        last.__exit__(*sys.exc_info())
    finally:
        _exhaust_exits(rest)
async def _async_exhaust_exits(ctxts: Sequence[AsyncContextManager[Any]]) -> None:
    """Await ``__aexit__`` on every context in ``ctxts``, most recently entered first.

    Callers pass contexts in the order they entered them, so they are exited in
    reverse (LIFO), as ``contextlib.AsyncExitStack`` does. Each ``__aexit__``
    receives the exception currently being handled (``sys.exc_info()``).

    If an exit raises, including ``BaseException`` subclasses such as
    ``asyncio.CancelledError``, the remaining contexts are still exited before
    the error propagates.
    """
    if not ctxts:
        return
    *rest, last = ctxts
    try:
        await last.__aexit__(*sys.exc_info())
    finally:
        await _async_exhaust_exits(rest)
_UniformProvider: TypeAlias = "Callable[[], _UniformContext[R]]"
_INJECTION = (type("INJECTION", (), {"__repr__": lambda _: "INJECTION"}))()
_PROVIDERS_BY_VAR: dict[ContextVar, _UniformProvider] = {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment