Created July 28, 2023 01:24
-
-
Save jgadling/c23eebf4a08c8db199df2a4fd70bf555 to your computer and use it in GitHub Desktop.
Example of using a Strawberry `FieldExtension` with `dependency_injector`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
import strawberry | |
import uvicorn | |
from dependency_injector import containers, providers | |
from dependency_injector.wiring import Provide, inject | |
from fastapi import FastAPI | |
from strawberry.extensions import FieldExtension | |
from strawberry.fastapi import GraphQLRouter | |
from strawberry.field import StrawberryField | |
from strawberry.types import Info | |
class Container(containers.DeclarativeContainer):
    """Dependency-injection container for this example.

    Each slot is a ``providers.Dependency`` placeholder: the container only
    declares the slots, and the concrete providers are supplied as keyword
    arguments (``dep1=...``, ``dep2=...``, ``dep3=...``) when an instance is
    created.
    """

    # Referenced via ``Provide[Container.dep1]`` by dependency2 and the
    # ``ParentModel.related`` resolver.
    dep1 = providers.Dependency()
    # Referenced via ``Provide[Container.dep2]`` by dependency3 and the
    # ``ParentModel.related`` resolver.
    dep2 = providers.Dependency()
    # Referenced via ``Provide[Container.dep3]`` by the ``Query.get_model``
    # resolver.
    dep3 = providers.Dependency()
def dependency1():
    """Leaf dependency with no sub-dependencies.

    Logs that it was called and returns the constant marker ``"dep1"``.
    """
    print("dependency1")
    marker = "dep1"
    return marker
@inject
def dependency2(dep1=Provide[Container.dep1]):
    """Mid-level dependency built on top of ``dep1``.

    Logs that it was called and returns a label embedding the injected
    ``dep1`` value, e.g. ``"dep2 (sub: dep1)"``.

    ``@inject`` is applied explicitly so the module-level name already refers
    to the injectable wrapper when a ``providers.Callable(dependency2)`` is
    built. Without it, a provider created before ``container.wire(...)`` runs
    would hold the plain function, and ``dep1`` would arrive as the raw
    ``Provide`` marker instead of the resolved value.
    """
    print("dependency2")
    return f"dep2 (sub: {dep1})"
@inject
def dependency3(dep2=Provide[Container.dep2]):
    """Top-level dependency built on top of ``dep2``.

    Logs its invocation together with the injected ``dep2`` value and returns
    the constant marker ``"dep3"``.

    ``@inject`` is applied explicitly so a ``providers.Callable(dependency3)``
    built before ``container.wire(...)`` runs still receives the resolved
    ``dep2`` rather than the raw ``Provide`` marker.
    """
    print(f"dependency3 (sub {dep2})")
    # Plain literal: the previous f"dep3" had no placeholders (flake8 F541).
    return "dep3"
class DependencyExtension(FieldExtension):
    """Hide dependency_injector parameters from a Strawberry field.

    Resolvers decorated with ``@inject`` declare each injected dependency as a
    parameter whose default is a ``Provide[...]`` marker. ``apply`` removes
    those parameters from the field's argument list so Strawberry does not try
    to resolve them as GraphQL arguments; the ``@inject`` wrapper fills them in
    when the resolver is called.

    Both ``resolve`` (sync) and ``resolve_async`` are implemented so the
    extension works on sync resolvers such as ``Query.get_model`` as well as
    async ones such as ``ParentModel.related``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Arguments removed by ``apply``; recorded for inspection only.
        self.dependency_args: list[typing.Any] = []

    def apply(self, field: StrawberryField) -> None:
        """Drop ``Provide``-defaulted arguments from ``field.arguments``.

        Args:
            field: The Strawberry field this extension is attached to.
        """
        injected = []
        exposed = []
        for arg in field.arguments:
            # A ``Provide[...]`` default marks a parameter owned by the DI
            # container, not by the GraphQL client.
            if isinstance(arg.default, Provide):
                injected.append(arg)
            else:
                exposed.append(arg)
        field.arguments = exposed
        self.dependency_args = injected

    def resolve(
        self,
        next_: typing.Callable[..., typing.Any],
        source: typing.Any,
        info: Info,
        **kwargs,
    ) -> typing.Any:
        """Pass-through for sync resolvers; injection is done by ``@inject``."""
        return next_(source, info, **kwargs)

    async def resolve_async(
        self,
        next_: typing.Callable[..., typing.Any],
        source: typing.Any,
        info: Info,
        **kwargs,
    ) -> typing.Any:
        """Pass-through for async resolvers; injection is done by ``@inject``."""
        return await next_(source, info, **kwargs)
@strawberry.type
class ChildModel:
    # GraphQL object type returned by the ``ParentModel.related`` resolver.
    # Field declaration order is also the order Strawberry sees them in.
    name: str
    value: str
@strawberry.type
class ParentModel:
    # GraphQL object type whose ``related`` field is built from injected
    # dependencies rather than from client-supplied arguments.
    name: str

    # ``@inject`` sits inside ``strawberry.field`` so the resolver Strawberry
    # registers is the injecting wrapper; DependencyExtension strips the
    # ``Provide``-defaulted parameters from the GraphQL argument list.
    @strawberry.field(extensions=[DependencyExtension()])
    @inject
    async def related(
        self,
        dep1: str = Provide[Container.dep1],
        dep2: str = Provide[Container.dep2],
    ) -> ChildModel:
        # dep1 becomes the child's name and dep2 its value.
        child = ChildModel(value=dep2, name=dep1)
        return child
@strawberry.type
class Query:
    # Root GraphQL query type.

    # ``strawberry_field`` is a real GraphQL argument; ``name`` has a
    # ``Provide`` default, so DependencyExtension hides it from the schema and
    # ``@inject`` supplies it from ``Container.dep3``.
    @strawberry.field(extensions=[DependencyExtension()])
    @inject
    def get_model(
        self, strawberry_field: str, name: str = Provide[Container.dep3]
    ) -> ParentModel:
        combined = f"{strawberry_field} {name}"
        return ParentModel(name=combined)
def get_context():
    """Build the Strawberry context dict exposing a DI container.

    Each call creates a new ``Container`` whose three ``Dependency`` slots are
    filled with ``providers.Callable`` wrappers around ``dependency1``,
    ``dependency2`` and ``dependency3``, initialises its resources, wires it
    into this module, and returns it under the ``"container"`` key.

    Returns:
        dict: ``{"container": <Container instance>}``.
    """
    # NOTE(review): the container is rebuilt and the module re-wired on every
    # call. Before caching it, check the dependency functions: each
    # ``providers.Callable(dependencyN)`` captures whatever object the
    # module-level name refers to *before* ``wire`` runs below. If wiring
    # replaces functions that are not ``@inject``-decorated (dependency2 and
    # dependency3 here), a container built before the first wire holds the
    # unpatched originals and their ``Provide`` defaults are never resolved --
    # confirm against dependency_injector's wiring behaviour.
    container = Container(
        dep1=providers.Callable(dependency1),
        dep2=providers.Callable(dependency2),
        dep3=providers.Callable(dependency3),
    )
    container.init_resources()
    # Bind the ``Provide[...]`` markers in this module to this container.
    container.wire(modules=[__name__])
    return {
        "container": container,
    }
# FastAPI application; the GraphQL endpoint is mounted under /graphql.
app = FastAPI()

# GraphQL schema built from the root Query type, served through Strawberry's
# FastAPI router with the GraphiQL UI enabled and get_context as the
# context getter.
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter(
    schema,
    context_getter=get_context,
    graphiql=True,
)
app.include_router(graphql_app, prefix="/graphql")
if __name__ == "__main__":
    # Serve the ``app`` object directly. The previous "main:app" import string
    # only worked if this file was saved as main.py, and it made uvicorn import
    # the module a second time (as "main") alongside the running "__main__".
    # host="0.0.0.0" listens on all interfaces; narrow it if that is not wanted.
    config = uvicorn.Config(app, host="0.0.0.0", port=8008, log_level="info")
    server = uvicorn.Server(config)
    server.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.