Created
January 3, 2023 16:08
-
-
Save rodion-solovev-7/7521977f17221f40248728cac41582a0 to your computer and use it in GitHub Desktop.
Упоротый способ внедрять fastapi.Depends туда, где он не нужен
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Iterator

from fastapi import Depends
from fastapi_depends import inject_dependencies
def get_my_dep1() -> Iterator[str]:
    """Example generator dependency that yields a single string.

    Because this function contains ``yield`` it is a generator function, so
    the return annotation is ``Iterator[str]`` rather than ``str``.
    """
    # Any kind of object may be yielded here (e.g. an SQLAlchemy session).
    yield 'HELLO WORLD'
@inject_dependencies()
async def example(inp_line: str = Depends(get_my_dep1)):
    """Print the value supplied for ``inp_line``.

    ``inp_line`` defaults to ``Depends(get_my_dep1)``; the
    ``inject_dependencies`` decorator is what replaces that marker with the
    value produced by ``get_my_dep1`` when the function is called.
    """
    print(inp_line)
if __name__ == '__main__':
    # asyncio is not imported anywhere before this point, so import it
    # locally; without it the asyncio.run call fails with NameError.
    import asyncio

    asyncio.run(example())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""From FastAPI 0.88.0""" | |
import asyncio | |
import functools | |
from contextlib import AsyncExitStack | |
from typing import * | |
from fastapi.dependencies.models import Dependant | |
from fastapi.dependencies.utils import ( | |
get_dependant, is_gen_callable, is_async_gen_callable, | |
solve_generator, is_coroutine_callable, | |
) | |
from pydantic.error_wrappers import ErrorWrapper | |
from starlette.background import BackgroundTasks | |
from starlette.concurrency import run_in_threadpool | |
async def solve_dependencies(
    *,
    dependant: Dependant,
    background_tasks: BackgroundTasks | None = None,
    dependency_overrides_provider: Any | None = None,
    dependency_cache: dict[tuple[Callable[..., Any], tuple[str]], Any] | None = None,
    stack: AsyncExitStack,
) -> tuple[
    dict[str, Any],
    list[ErrorWrapper],
    BackgroundTasks | None,
    dict[tuple[Callable[..., Any], tuple[str]], Any],
]:
    """Recursively resolve the sub-dependencies of ``dependant``.

    Each entry of ``dependant.dependencies`` is solved depth-first: its own
    sub-dependencies are resolved by a recursive call, then its callable is
    invoked (or a cached result is reused) with those values.

    Args:
        dependant: Dependency description whose ``dependencies`` are solved.
        background_tasks: Shared task container; one is created only when
            ``dependant`` declares a background-tasks parameter and none was
            passed in.
        dependency_overrides_provider: Optional object exposing a
            ``dependency_overrides`` mapping from original callables to
            replacements.
        dependency_cache: Results already computed, keyed by ``cache_key``.
            A falsy value (``None`` or an empty dict) is replaced by a new
            dict, so use the returned cache rather than relying on in-place
            mutation.
        stack: Exit stack handed to ``solve_generator`` for generator-based
            dependencies.

    Returns:
        A 4-tuple ``(values, errors, background_tasks, dependency_cache)``.
        ``values`` maps each named sub-dependency to its result; ``errors``
        collects the errors reported by nested calls.
    """
    resolved: dict[str, Any] = {}
    problems: list[ErrorWrapper] = []
    if not dependency_cache:
        dependency_cache = {}

    child: Dependant
    for child in dependant.dependencies:
        # The casts only narrow the types for static checkers.
        child.call = cast(Callable[..., Any], child.call)
        child.cache_key = cast(
            tuple[Callable[..., Any], tuple[str]], child.cache_key
        )
        target = child.call
        effective = child

        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # Swap in the override (if any) and rebuild the dependant from it
            # so that the override's own parameters are the ones solved.
            overrides = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            )
            target = overrides.get(child.call, child.call)
            child_path: str = child.path  # type: ignore
            effective = get_dependant(
                path=child_path,
                call=target,
                name=child.name,
                security_scopes=child.security_scopes,
            )

        sub_values, sub_errors, background_tasks, nested_cache = (
            await solve_dependencies(
                dependant=effective,
                background_tasks=background_tasks,
                dependency_overrides_provider=dependency_overrides_provider,
                dependency_cache=dependency_cache,
                stack=stack,
            )
        )
        dependency_cache.update(nested_cache)
        if sub_errors:
            # A failing branch is reported but not invoked.
            problems.extend(sub_errors)
            continue

        # The cache is keyed by the original sub-dependency, not the override.
        key = child.cache_key
        if child.use_cache and key in dependency_cache:
            outcome = dependency_cache[key]
        elif is_gen_callable(target) or is_async_gen_callable(target):
            assert isinstance(stack, AsyncExitStack)
            outcome = await solve_generator(
                call=target, stack=stack, sub_values=sub_values
            )
        elif is_coroutine_callable(target):
            outcome = await target(**sub_values)
        else:
            outcome = await run_in_threadpool(target, **sub_values)

        if child.name is not None:
            resolved[child.name] = outcome
        # Store the result only if this key has not been cached yet.
        dependency_cache.setdefault(key, outcome)

    tasks_param = dependant.background_tasks_param_name
    if tasks_param:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        resolved[tasks_param] = background_tasks

    return resolved, problems, background_tasks, dependency_cache
async def run_task_function(
    *, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
) -> Any:
    """Invoke ``dependant.call`` with ``values`` as keyword arguments.

    Kept as a standalone function (rather than an inner one) because inner
    functions are harder to profile.

    Args:
        dependant: Dependency description whose ``call`` is executed.
        values: Keyword arguments passed to the callable.
        is_coroutine: When true the callable is awaited directly; otherwise
            it is run through ``run_in_threadpool``.

    Returns:
        Whatever the callable returns.
    """
    target = dependant.call
    assert target is not None, "dependant.call must be a function"
    if not is_coroutine:
        return await run_in_threadpool(target, **values)
    return await target(**values)
def inject_dependencies(use_cache: bool = True):
    """Decorator that injects ``Depends`` dependencies into a function.

    Meant for use outside FastAPI endpoint handlers, e.g. for periodic or
    background tasks.

    Args:
        use_cache: Forwarded to ``get_dependant`` when the dependency
            description of the decorated callable is built.

    Returns:
        A decorator. The decorated callable becomes an ``async`` function
        that takes no arguments, solves the dependencies, calls the original
        callable with them and returns its result.
    """

    def wrapper(call: Callable[..., Any]):
        dependant = get_dependant(
            path=call.__name__,
            call=call,
            name=call.__name__,
            security_scopes=None,
            use_cache=use_cache,
        )
        is_coroutine = asyncio.iscoroutinefunction(dependant.call)
        # There is no application object here, so overrides are not supported.
        dependency_overrides_provider = None

        @functools.wraps(call)
        async def wrapped() -> Any:
            # NOTE: tasks added to this container are never executed by this
            # wrapper; it exists only so dependencies can request it.
            background_tasks = BackgroundTasks()
            dependency_cache: dict = {}
            # Exception raised while solving dependencies or running ``call``.
            captured: Exception | None = None
            try:
                async with AsyncExitStack() as stack:
                    try:
                        values, _errors, background_tasks, _ = (
                            await solve_dependencies(
                                dependant=dependant,
                                background_tasks=background_tasks,
                                dependency_overrides_provider=dependency_overrides_provider,
                                dependency_cache=dependency_cache,
                                stack=stack,
                            )
                        )
                        result = await run_task_function(
                            dependant=dependant,
                            values=values,
                            is_coroutine=is_coroutine,
                        )
                    except Exception as exc:
                        captured = exc
                        raise
            except Exception:
                if captured is None:
                    # Nothing failed inside the stack body, so this error
                    # comes from dependency teardown; do not swallow it.
                    raise
            if captured is not None:
                # Re-raise the original failure even when a generator
                # dependency replaced or suppressed it during teardown
                # (suppression would otherwise leave ``result`` unbound).
                raise captured
            return result

        return wrapped

    return wrapper
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment