Skip to content

Instantly share code, notes, and snippets.

@jgadling
Last active September 25, 2024 14:22
Show Gist options
  • Save jgadling/1971426b0075073ea6d13d64cade1310 to your computer and use it in GitHub Desktop.
Save jgadling/1971426b0075073ea6d13d64cade1310 to your computer and use it in GitHub Desktop.
Example of using Strawberry FieldExtensions to implement FastAPI dependency injection in resolvers
# This snippet works with FastAPI versions older than 0.115.0.
# For an updated DependencyExtension that works with FastAPI 0.115.0, refer to this gist:
# https://gist.github.com/jgadling/bf27a924cd9c34a2a64d2dbf8a5507e6
import typing
import strawberry
import uvicorn
from fastapi import Depends, FastAPI
from fastapi.dependencies import utils as deputils
from fastapi.params import Depends as DependsClass
from strawberry.extensions import FieldExtension
from strawberry.fastapi import GraphQLRouter
from strawberry.field import StrawberryField
from strawberry.types import Info
async def dependency1() -> str:
    """Leaf dependency: log that it ran and hand back the marker ``"dep1"``."""
    marker = "dep1"
    print("dependency1")
    return marker
def dependency2(dep1=Depends(dependency1)) -> str:
    """Second-level dependency: log that it ran and wrap ``dependency1``'s result."""
    print("dependency2")
    return "dep2 (sub: {})".format(dep1)
def dependency3(dep2=Depends(dependency2)) -> str:
    """Third-level dependency: log ``dependency2``'s result and return ``"dep3"``."""
    print(f"dependency3 (sub {dep2})")
    return "dep3"
class _DependsOnlyCallable:
    """Signature-only stand-in for a resolver that exposes just its ``Depends`` params.

    ``deputils.get_dependant`` builds its dependency graph from
    ``inspect.signature(call)``, and ``inspect.signature`` honours an explicit
    ``__signature__`` attribute. Handing FastAPI this stand-in instead of the
    raw resolver keeps ``self``, ``info: Info`` and ordinary Strawberry
    arguments out of its analysis. FastAPI therefore does not try to model them
    as HTTP request parameters, and never fills them from the request's query
    string.
    """

    def __init__(self, resolver: typing.Callable[..., typing.Any]) -> None:
        import inspect  # function-scope import keeps this block self-contained

        signature = inspect.signature(resolver)
        depends_params = [
            param
            for param in signature.parameters.values()
            if isinstance(param.default, DependsClass)
        ]
        self.__signature__ = signature.replace(parameters=depends_params)
        # FastAPI evaluates string annotations against ``call.__globals__``;
        # mirror the resolver's module namespace so they resolve as before.
        self.__globals__ = getattr(resolver, "__globals__", {})

    def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.NoReturn:
        # solve_dependencies only invokes the sub-dependencies; the resolver
        # itself is invoked by Strawberry through ``next_``.
        raise TypeError(
            "_DependsOnlyCallable is a signature stand-in and cannot be called"
        )


class DependencyExtension(FieldExtension):
    """Strawberry field extension that injects FastAPI ``Depends(...)`` values.

    Attach it per field: ``@strawberry.field(extensions=[DependencyExtension()])``.
    Resolver parameters whose default is ``Depends(...)`` are removed from the
    GraphQL arguments. On every call they are solved with FastAPI's dependency
    machinery against the current HTTP request, then passed to the resolver.
    """

    def __init__(self) -> None:
        # FastAPI dependency graph for the resolver's Depends params; set in apply().
        self.dependant = None
        # Retained for backward compatibility; not used by this extension.
        self.dependency_args: list[typing.Any] = []
        self.dependency_overrides_provider = None
        # "self" plus the argument names Strawberry resolves itself (set in apply()).
        self.strawberry_field_names = ["self"]

    def apply(self, field: StrawberryField) -> None:
        """Build the dependency graph and hide ``Depends`` params from the schema.

        Raises:
            TypeError: if the field has no resolver function.
        """
        if field.base_resolver is None:
            raise TypeError(
                "DependencyExtension can only be applied to fields that have a resolver"
            )
        resolver = field.base_resolver.wrapped_func
        # Only the Depends(...) parameters take part in FastAPI's solving.
        self.dependant = deputils.get_dependant(
            path="/", call=_DependsOnlyCallable(resolver)
        )
        # Remove FastAPI Depends arguments from the list Strawberry resolves.
        field.arguments = [
            item
            for item in field.arguments
            if not isinstance(item.default, DependsClass)
        ]
        # Assign (rather than extend) so re-applying does not duplicate names.
        self.strawberry_field_names = [
            "self",
            *(item.python_name for item in field.arguments),
        ]

    @staticmethod
    def _request_dependency_cache(request: typing.Any) -> dict:
        """Return the per-request dependency cache, creating it on first use.

        The cache is stored on the request object. Every resolver using this
        extension during the same HTTP request therefore passes the same dict
        to ``solve_dependencies``.
        """
        try:
            context = request.context
        except AttributeError:
            context = request.context = {}
        return context.setdefault("dependency_cache", {})

    async def resolve_async(
        self,
        next_: typing.Callable[..., typing.Any],
        source: typing.Any,
        info: Info,
        **kwargs,
    ) -> typing.Any:
        """Solve this field's FastAPI dependencies, then continue the resolver chain.

        Args:
            next_: the next resolver in Strawberry's extension chain.
            source: the parent value, passed through unchanged.
            info: Strawberry ``Info``; ``info.context["request"]`` must be the
                current FastAPI request.
            **kwargs: GraphQL arguments already resolved by Strawberry.

        Returns:
            Whatever the rest of the resolver chain returns.

        Raises:
            fastapi.exceptions.RequestValidationError: if any dependency fails
                FastAPI validation (e.g. a required header is missing).
                Otherwise the resolver would receive the raw ``Depends``
                object as its argument.
        """
        request = info.context["request"]
        dependency_cache = self._request_dependency_cache(request)
        (
            solved_values,
            errors,
            _background_tasks,
            _sub_response,
            new_cache,
        ) = await deputils.solve_dependencies(
            request=request,
            dependant=self.dependant,
            body={},
            dependency_overrides_provider=request.app,
            dependency_cache=dependency_cache,
        )
        dependency_cache.update(new_cache)
        if errors:
            from fastapi.exceptions import RequestValidationError

            raise RequestValidationError(errors)
        # Strawberry-supplied arguments take precedence on any name clash.
        return await next_(source, info, **(solved_values | kwargs))
@strawberry.type
class ChildModel:
    """GraphQL object returned by ``ParentModel.related``."""

    name: str  # ParentModel.related fills this with dependency1's value
    value: str  # ParentModel.related fills this with dependency2's value
@strawberry.type
class ParentModel:
    """GraphQL object whose ``related`` field is built from injected dependencies."""

    name: str

    @strawberry.field(extensions=[DependencyExtension()])
    async def related(
        self,
        dep1: str = Depends(dependency1),
        dep2: str = Depends(dependency2),
    ) -> ChildModel:
        # dep1/dep2 are supplied by DependencyExtension, not by the GraphQL client.
        child = ChildModel(value=dep2, name=dep1)
        return child
@strawberry.type
class Query:
    """Root query type of the example schema."""

    @strawberry.field(extensions=[DependencyExtension()])
    def get_model(
        self, strawberry_field: str, name: str = Depends(dependency3)
    ) -> ParentModel:
        # strawberry_field comes from the GraphQL query; name from dependency3.
        label = "{} {}".format(strawberry_field, name)
        return ParentModel(name=label)
def get_context():
    """Return a fresh custom context dict for the GraphQL router."""
    context = dict(this_still_works={})
    return context
# FastAPI application that hosts the GraphQL endpoint.
app = FastAPI()
schema = strawberry.Schema(query=Query)
# get_context supplies the custom context for every GraphQL request.
graphql_app = GraphQLRouter(schema, context_getter=get_context, graphiql=True)
# Mount the GraphQL router under the /graphql prefix.
app.include_router(graphql_app, prefix="/graphql")
if __name__ == "__main__":
    # Hand uvicorn the app object instead of the import string "main:app":
    # the string only works when this file is named main.py, and it makes
    # uvicorn import this module a second time (once as __main__, once as main).
    # host="0.0.0.0" listens on all network interfaces.
    config = uvicorn.Config(app, host="0.0.0.0", port=8008, log_level="info")
    server = uvicorn.Server(config)
    server.run()
@aurthurm
Copy link

Hi @jgadling,

Thank you for this awesome extension. I was wondering whether there is a way to make it global, like this:
schema = strawberry.Schema(query=Query, extensions=[DependencyExtension()])

Thanks

@jun-soundsleep
Copy link

jun-soundsleep commented Dec 14, 2023

Hi @jgadling,
Thank you for this awesome extension.
But I have a problem. Could you help me?

For example,

from strawberry.types import Info

def resolver_to_update_contract(
    self,
    customer_id: str,
    api_quota: int,
    api_plan_id: int,
    ended_on: datetime,
    # The question below reports that DependencyExtension stops working
    # once the resolver takes this `info: Info` parameter.
    info: Info,
) -> UpdateContractResponse:
    # Body elided in the question.
    ...

If my resolver takes a `strawberry.types.Info` parameter, your dependency extension does not work. How can I fix it?

Thanks for the contribution.

@anon-dev-gh
Copy link

Hi @jgadling,

Thank you for this awesome extension. I was wondering whether there is a way to make it global, like this: `schema = strawberry.Schema(query=Query, extensions=[DependencyExtension()])`

Thanks

I wouldn't make it global; I'd apply it on a per-field basis instead.

Consider defining a decorator

def add_dependency_extension(func):
    """Wrap a ``strawberry.field``-style factory so every field gets a ``DependencyExtension``.

    Any ``extensions`` the caller passes (list, tuple or other iterable) are
    kept, with a fresh ``DependencyExtension()`` appended. The result is
    always a *new* list, so a list shared between several fields is never
    mutated and never accumulates duplicate extensions.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        existing = kwargs.get("extensions") or ()
        kwargs["extensions"] = [*existing, DependencyExtension()]
        return func(*args, **kwargs)
    return wrapper

and a wrapper

@add_dependency_extension
def dependencies_field(*field_args, **field_kwargs):
    """Create a ``strawberry.field`` that always carries a ``DependencyExtension``."""
    field = strawberry.field(*field_args, **field_kwargs)
    return field

and then

# Use strawberry.type: the bare builtin `type` (unless it was imported from
# strawberry) would replace the class with the `type` metaclass itself.
@strawberry.type
class Query:
    books: typing.List[Book] = dependencies_field(resolver=get_books)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment