Last active
September 25, 2024 14:22
-
-
Save jgadling/1971426b0075073ea6d13d64cade1310 to your computer and use it in GitHub Desktop.
Example of using Strawberry FieldExtensions to implement FastAPI dependency injection in resolvers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This snippet works with older FastAPI versions.
# For an updated DependencyExtension that works with FastAPI 0.115.0, see this gist:
# https://gist.github.com/jgadling/bf27a924cd9c34a2a64d2dbf8a5507e6
import typing | |
import strawberry | |
import uvicorn | |
from fastapi import Depends, FastAPI | |
from fastapi.dependencies import utils as deputils | |
from fastapi.params import Depends as DependsClass | |
from strawberry.extensions import FieldExtension | |
from strawberry.fastapi import GraphQLRouter | |
from strawberry.field import StrawberryField | |
from strawberry.types import Info | |
async def dependency1() -> str:
    """Leaf FastAPI dependency: log that it ran and return a marker string."""
    label = "dependency1"
    print(label)
    marker = "dep1"
    return marker
def dependency2(dep1=Depends(dependency1)) -> str:
    """FastAPI dependency that consumes dependency1's result.

    Logs that it ran and returns a string embedding the sub-dependency value.
    """
    print("dependency2")
    wrapped = f"dep2 (sub: {dep1})"
    return wrapped
def dependency3(dep2=Depends(dependency2)) -> str:
    """FastAPI dependency used by ``Query.get_model``.

    Depends on dependency2 (which itself depends on dependency1), logs the
    value it received, and returns a constant marker string.
    """
    print(f"dependency3 (sub {dep2})")
    # Plain literal: the original f-string had no placeholders (ruff F541).
    return "dep3"
class DependencyExtension(FieldExtension):
    """Resolve FastAPI ``Depends()`` defaults on a Strawberry resolver.

    ``apply`` removes the ``Depends()`` parameters from the arguments that
    Strawberry exposes in GraphQL. It then builds a FastAPI dependant that
    covers only those parameters. ``resolve_async`` solves the dependencies
    against the current request and passes the results to the resolver as
    keyword arguments.

    Written for older FastAPI versions whose ``solve_dependencies`` returns
    a 5-tuple.
    """

    def __init__(self):
        super().__init__()
        # Not read by this class.
        self.dependency_args: list[typing.Any] = []
        # Not read by this class; resolve_async uses request.app instead.
        self.dependency_overrides_provider = None
        # Names that Strawberry resolves itself: "self" plus the GraphQL
        # arguments, which apply() fills in.
        self.strawberry_field_names = ["self"]

    @staticmethod
    def _dependencies_only(
        resolver: typing.Callable[..., typing.Any],
    ) -> typing.Callable[..., typing.Any]:
        """Return a stub whose signature holds only resolver's Depends() params.

        FastAPI's ``get_dependant`` inspects every parameter of the callable
        it is given. Handing it the full resolver makes ``self`` and the
        GraphQL arguments look like request query parameters. It also fails
        on Strawberry's ``info: Info`` argument.
        """
        import inspect
        import types

        signature = inspect.signature(resolver)
        depends_params = [
            param
            for param in signature.parameters.values()
            if isinstance(param.default, DependsClass)
        ]

        def _stub(**kwargs):
            return None

        # Give the stub the resolver's globals: FastAPI evaluates string
        # annotations against the callable's ``__globals__``.
        stub = types.FunctionType(
            _stub.__code__,
            getattr(resolver, "__globals__", _stub.__globals__),
            getattr(resolver, "__name__", _stub.__name__),
        )
        stub.__signature__ = signature.replace(parameters=depends_params)
        return stub

    def apply(self, field: StrawberryField) -> None:
        """Build the FastAPI dependant and hide Depends() args from Strawberry."""
        resolver = field.base_resolver.wrapped_func
        self.dependant = deputils.get_dependant(
            path="/", call=self._dependencies_only(resolver)
        )
        # Remove FastAPI Depends() arguments from the list that Strawberry
        # tries to resolve as GraphQL arguments.
        field.arguments = [
            item
            for item in field.arguments
            if not isinstance(item.default, DependsClass)
        ]
        # Stash the Strawberry-managed argument names.
        self.strawberry_field_names += [item.python_name for item in field.arguments]

    async def resolve_async(
        self,
        next_: typing.Callable[..., typing.Any],
        source: typing.Any,
        info: Info,
        **kwargs,
    ) -> typing.Any:
        """Solve this field's FastAPI dependencies, then call the resolver.

        A dependency cache is kept on ``request.context["dependency_cache"]``.
        This lets FastAPI reuse results for dependencies that are shared by
        several fields within one request.
        """
        request = info.context["request"]
        # The request may not have a ``context`` attribute yet; create it on
        # first use.
        try:
            if "dependency_cache" not in request.context:
                request.context["dependency_cache"] = {}
        except AttributeError:
            request.context = {"dependency_cache": {}}
        solved_result = await deputils.solve_dependencies(
            request=request,
            dependant=self.dependant,
            body={},
            dependency_overrides_provider=request.app,
            dependency_cache=request.context["dependency_cache"],
        )
        (
            solved_values,
            _,  # solver errors: ignored by this extension
            _,  # background tasks
            _,  # sub-response: same as the response we already have
            new_cache,  # sub-dependency cache
        ) = solved_result
        request.context["dependency_cache"].update(new_cache)
        # GraphQL arguments (kwargs) win over any same-named solved value.
        return await next_(source, info, **(solved_values | kwargs))
@strawberry.type
class ChildModel:
    """GraphQL type returned by ``ParentModel.related``."""

    # Populated from dependency1's result in ParentModel.related.
    name: str
    # Populated from dependency2's result in ParentModel.related.
    value: str
@strawberry.type
class ParentModel:
    """GraphQL type whose ``related`` field is built from FastAPI dependencies."""

    name: str

    @strawberry.field(extensions=[DependencyExtension()])
    async def related(
        self,
        dep1: str = Depends(dependency1),
        dep2: str = Depends(dependency2),
    ) -> ChildModel:
        """Return a child built from the dependency1 and dependency2 results.

        Both arguments are injected by DependencyExtension. They are not
        GraphQL arguments.
        """
        child = ChildModel(value=dep2, name=dep1)
        return child
@strawberry.type
class Query:
    """Root GraphQL query type."""

    @strawberry.field(extensions=[DependencyExtension()])
    def get_model(
        self, strawberry_field: str, name: str = Depends(dependency3)
    ) -> ParentModel:
        """Combine a GraphQL argument with a FastAPI dependency into a ParentModel.

        ``strawberry_field`` is a regular GraphQL argument. ``name`` is
        injected from dependency3 by DependencyExtension.
        """
        label = f"{strawberry_field} {name}"
        return ParentModel(name=label)
def get_context():
    """Build the custom GraphQL context handed to GraphQLRouter.

    A fresh dict is returned on every call.
    """
    context = dict(this_still_works={})
    return context
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter(schema, context_getter=get_context, graphiql=True)

app = FastAPI()
app.include_router(graphql_app, prefix="/graphql")

if __name__ == "__main__":
    # Pass the app object rather than the "main:app" import string, so the
    # script works no matter which file name it is saved under.
    # NOTE(review): 0.0.0.0 listens on all interfaces; fine for a demo.
    config = uvicorn.Config(app, host="0.0.0.0", port=8008, log_level="info")
    server = uvicorn.Server(config)
    server.run()
Hi @jgadling,
Thank you for this awesome extension.
However, I have a problem. Could you help me?
For example,
from strawberry.types import Info
def resolver_to_update_contract(
    self,
    customer_id: str,
    api_quota: int,
    api_plan_id: int,
    ended_on: datetime,
    info: Info,
) -> UpdateContractResponse:
    """Example resolver (body elided) with GraphQL arguments plus ``info``.

    NOTE(review): ``datetime`` and ``UpdateContractResponse`` are not
    imported in this snippet; presumably they come from the asker's code.
    """
    ...
If my resolver takes a `strawberry.types.Info` parameter, your dependency extension does not work. How can I fix it?
Thanks for the contribution.
Hi @jgadling,
Thank you for this awesome extension. I was wondering if there is a way to make it global, like this: `schema = strawberry.Schema(query=Query, extensions=[DependencyExtension()])`
Thanks
I wouldn't make it global; I'd apply it on a per-field basis.
Consider defining a decorator:
def add_dependency_extension(func):
    """Wrap a field factory so each field it creates also gets a DependencyExtension.

    Any ``extensions`` the caller passes (list, tuple, or other iterable) are
    kept. A new list is built, so the caller's sequence is never mutated.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        existing = kwargs.get("extensions") or ()
        # Build a fresh list. Appending in place would mutate a list that may
        # be shared by several fields, and replacing non-list values would
        # silently drop the caller's extensions.
        kwargs["extensions"] = [*existing, DependencyExtension()]
        return func(*args, **kwargs)

    return wrapper
and a wrapper:
@add_dependency_extension
def dependencies_field(*args, **kwargs):
    """Drop-in replacement for ``strawberry.field``; the decorator adds DependencyExtension."""
    make_field = strawberry.field
    return make_field(*args, **kwargs)
and then use it:
@strawberry.type
class Query:
    """Root query whose ``books`` field uses a resolver with FastAPI dependencies."""

    # Bare ``@type`` would be the builtin and replace the class with ``type``,
    # so the class is decorated with ``strawberry.type`` instead.
    books: list[Book] = dependencies_field(resolver=get_books)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @jgadling,
Thank you for this awesome extension. I was wondering if there is a way to make it global, like this:
schema = strawberry.Schema(query=Query, extensions=[DependencyExtension()])
Thanks