Skip to content

Instantly share code, notes, and snippets.

@jgadling
Created July 28, 2023 01:24
Show Gist options
  • Save jgadling/c23eebf4a08c8db199df2a4fd70bf555 to your computer and use it in GitHub Desktop.
Save jgadling/c23eebf4a08c8db199df2a4fd70bf555 to your computer and use it in GitHub Desktop.
Example of using strawberry FieldExtension with dependency_injector
import typing
import strawberry
import uvicorn
from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from fastapi import FastAPI
from strawberry.extensions import FieldExtension
from strawberry.fastapi import GraphQLRouter
from strawberry.field import StrawberryField
from strawberry.types import Info
class Container(containers.DeclarativeContainer):
    """Declarative container exposing three externally supplied dependencies.

    Each attribute is a ``providers.Dependency`` placeholder; ``get_context``
    passes the concrete providers in when it instantiates the container.
    """

    dep1 = providers.Dependency()
    dep2 = providers.Dependency()
    dep3 = providers.Dependency()
def dependency1():
    """Print a trace line and return the constant value for ``dep1``."""
    print("dependency1")
    value = "dep1"
    return value
def dependency2(dep1=Provide[Container.dep1]):
    """Print a trace line and return a string that embeds ``dep1``.

    ``dep1`` defaults to a ``Provide`` marker for ``Container.dep1``.
    """
    print("dependency2")
    description = f"dep2 (sub: {dep1})"
    return description
def dependency3(dep2=Provide[Container.dep2]):
    """Print a trace line mentioning ``dep2`` and return the constant ``"dep3"``.

    ``dep2`` defaults to a ``Provide`` marker for ``Container.dep2``; it is
    only used in the printed message, not in the return value.
    """
    trace = f"dependency3 (sub {dep2})"
    print(trace)
    return "dep3"
class DependencyExtension(FieldExtension):
    """Field extension that hides dependency_injector markers from Strawberry.

    Resolver parameters whose default is a ``Provide`` marker are meant to be
    filled in by dependency_injector, not by GraphQL clients.  ``apply``
    removes them from the field's argument list; the resolve hooks simply
    forward the call to the next resolver in the chain.
    """

    def __init__(self):
        # Arguments stripped from the field by ``apply``; kept for inspection.
        self.dependency_args: list[typing.Any] = []

    def apply(self, field: StrawberryField) -> None:
        """Split ``field.arguments`` into GraphQL arguments and DI arguments.

        Args:
            field: The Strawberry field this extension is attached to.  Its
                ``arguments`` list is replaced with the non-DI arguments only.
        """
        # Remove dependency_injector provider arguments from the list that
        # strawberry tries to resolve
        di_arguments = []
        keep_arguments = []
        for arg in field.arguments:
            bucket = di_arguments if isinstance(arg.default, Provide) else keep_arguments
            bucket.append(arg)
        field.arguments = keep_arguments
        self.dependency_args = di_arguments

    def resolve(
        self,
        next_: typing.Callable[..., typing.Any],
        source: typing.Any,
        info: Info,
        **kwargs,
    ) -> typing.Any:
        """Forward a synchronous resolver call unchanged.

        Without this hook the extension is async-only and cannot be attached
        to synchronous resolvers such as ``Query.get_model``.

        Args:
            next_: The next resolver (or extension) in the chain.
            source: The parent object being resolved.
            info: Strawberry resolver info.
            **kwargs: GraphQL arguments for the field.

        Returns:
            Whatever ``next_`` returns.
        """
        return next_(source, info, **kwargs)

    async def resolve_async(
        self,
        next_: typing.Callable[..., typing.Any],
        source: typing.Any,
        info: Info,
        **kwargs,
    ) -> typing.Any:
        """Forward an asynchronous resolver call unchanged.

        Args:
            next_: The next resolver (or extension) in the chain; its result
                is awaited.
            source: The parent object being resolved.
            info: Strawberry resolver info.
            **kwargs: GraphQL arguments for the field.

        Returns:
            The awaited result of ``next_``.
        """
        res = await next_(source, info, **kwargs)
        return res
@strawberry.type
class ChildModel:
    """GraphQL object type with two string fields, ``name`` and ``value``."""

    name: str
    value: str
@strawberry.type
class ParentModel:
    """GraphQL object type with a ``name`` and a DI-backed ``related`` field."""

    name: str

    @strawberry.field(extensions=[DependencyExtension()])
    @inject
    async def related(
        self,
        dep1: str = Provide[Container.dep1],
        dep2: str = Provide[Container.dep2],
    ) -> ChildModel:
        """Build a ``ChildModel`` from the ``dep1`` and ``dep2`` dependencies.

        Both parameters default to ``Provide`` markers, which
        ``DependencyExtension`` removes from the field's GraphQL arguments.
        """
        child = ChildModel(name=dep1, value=dep2)
        return child
@strawberry.type
class Query:
    """Root query type of the example schema."""

    @strawberry.field(extensions=[DependencyExtension()])
    @inject
    def get_model(
        self,
        strawberry_field: str,
        name: str = Provide[Container.dep3],
    ) -> ParentModel:
        """Return a ``ParentModel`` named after the GraphQL argument and ``dep3``.

        ``strawberry_field`` is an ordinary GraphQL argument; ``name``
        defaults to a ``Provide`` marker for ``Container.dep3``.
        """
        combined_name = f"{strawberry_field} {name}"
        return ParentModel(name=combined_name)
def get_context():
    """Create and wire a fresh ``Container`` and return it as GraphQL context.

    This function is handed to ``GraphQLRouter`` as its ``context_getter``.

    Returns:
        A dict with a single ``"container"`` key holding the wired container.
    """
    dependency_providers = {
        "dep1": providers.Callable(dependency1),
        "dep2": providers.Callable(dependency2),
        "dep3": providers.Callable(dependency3),
    }
    container = Container(**dependency_providers)
    container.init_resources()
    # Wire this module so the Provide markers in its functions get bound.
    container.wire(modules=[__name__])
    return {"container": container}
# Assemble the GraphQL schema, expose it through a FastAPI router, and mount
# that router under /graphql.
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter(
    schema,
    context_getter=get_context,
    graphiql=True,
)
app = FastAPI()
app.include_router(graphql_app, prefix="/graphql")
if __name__ == "__main__":
    # Run the app with uvicorn when this file is executed as a script.
    config = uvicorn.Config(
        "main:app",
        host="0.0.0.0",
        port=8008,
        log_level="info",
    )
    server = uvicorn.Server(config)
    server.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment