Created
September 4, 2023 15:23
-
-
Save mairas/ec8ec704758210cc20b5e5782713c0d7 to your computer and use it in GitHub Desktop.
How to access parent field arguments in Strawberry GraphQL
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Let's say you have a user with id "XXX" and you want to get all the
# groups that user is a member of. The query below should do the trick.
#
# However, when implementing the query and its types with Strawberry,
# the resolver for the `groups` relay connection needs access to the
# user id. This is a hassle in Strawberry because that id is only
# available in the query's abstract syntax tree, which you have to
# traverse manually.
#
# The code in this gist shows how to do that. It's not pretty, but it works.
# Call `parent_field(info, 2, "id")` in the groups implementation to get the user id
# (`id` here is the argument passed to `user`, not a resolved field value).
#
# {
#   user(
#     id: "XXX"
#   ) {
#     groups {
#       edges {
#         node {
#           id
#         }
#       }
#     }
#   }
# }
def _flatten_paths(path: graphql.pyutils.path.Path, paths=None) -> list[graphql.pyutils.path.Path]:
    """
    Unroll a linked GraphQL ``Path`` into a flat list.

    ``Path`` entries point backwards through ``prev``. This walks that chain
    and collects every entry, starting with *path* itself and ending with the
    entry whose ``prev`` is ``None``.

    If *paths* is given, the entries are appended to that list and the same
    list object is returned. Otherwise a new list is created.
    """
    collected = [] if paths is None else paths
    node = path
    while True:
        collected.append(node)
        node = node.prev
        if node is None:
            return collected
def _extract_field_args(
    field: graphql.language.ast.FieldNode,
    variables: "dict[str, Any] | None" = None,
) -> dict[str, Any]:
    """
    Extract the arguments of a field node as Python values.

    Each argument value is converted with
    ``graphql.utilities.value_from_ast_untyped``:
    - ``first: 10`` yields the int ``10`` rather than the source text ``"10"``.
    - List and input-object literals become ``list``/``dict``.
    - ``null`` becomes ``None``.

    :param field: The field node whose arguments are read.
    :param variables: The operation's variable values, used to resolve
        arguments written as ``$variable`` references. An argument whose
        variable has no value is omitted from the result, i.e. treated as
        not provided.
    :return: A mapping from argument name to value.
    """
    args: dict[str, Any] = {}
    # ``arguments`` is normally a (possibly empty) sequence; the ``or ()``
    # guards against hand-built nodes that leave it as None.
    for argument in field.arguments or ():
        value = graphql.utilities.value_from_ast_untyped(argument.value, variables)
        if value is graphql.pyutils.Undefined:
            # Unresolvable variable reference: treat the argument as absent.
            continue
        args[argument.name.value] = value
    return args


def _matching_fields(
    selections: tuple[graphql.language.ast.SelectionNode, ...],
    response_key: str,
) -> list[graphql.language.ast.FieldNode]:
    """
    Return every field node in *selections* whose response key is *response_key*.

    The response key is the alias when one is given and the field name
    otherwise. This is the value a ``Path`` entry's ``key`` holds. Inline
    fragments (``... on User { ... }``) are searched recursively.

    Named fragment spreads (``...UserFields``) are not followed, because the
    fragment definitions are not available to this helper.
    """
    ast = graphql.language.ast
    matches = []
    for selection in selections:
        if isinstance(selection, ast.FieldNode):
            if (selection.alias or selection.name).value == response_key:
                matches.append(selection)
        elif isinstance(selection, ast.InlineFragmentNode):
            matches.extend(_matching_fields(selection.selection_set.selections, response_key))
    return matches


def _descendant_fields(
    selections: tuple[graphql.language.ast.SelectionNode, ...], paths: list[graphql.pyutils.path.Path]
) -> list[graphql.language.ast.FieldNode]:
    """
    Traverse the AST to find the field node for each entry of *paths*.

    :param selections: The selections to start from, normally the
        operation's top-level selections.
    :param paths: Path entries ordered from the root towards the current
        field. Entries whose key is an int are list indices. They have no
        node of their own in the AST and are skipped.
    :return: The matched field nodes, root first. Traversal stops as soon
        as a key cannot be matched (including when a leaf field is reached
        with entries left over). Whatever was found up to that point is
        returned.
    """
    found = []
    current = tuple(selections)
    for path in paths:
        if isinstance(path.key, int):
            continue
        matches = _matching_fields(current, path.key)
        if not matches:
            break
        # GraphQL validation requires same-keyed fields that can apply to the
        # same object to have identical arguments, so the first match is used
        # for argument lookup.
        # NOTE(review): fields under mutually exclusive type conditions may
        # differ and are not distinguished here.
        found.append(matches[0])
        # Same-keyed fields have their sub-selections merged at execution
        # time, so the next level is searched across all of them.
        current = tuple(
            child
            for match in matches
            if match.selection_set is not None
            for child in match.selection_set.selections
        )
    return found


def ancestor_args(info: strawberry.types.Info) -> list[dict[str, Any]]:
    """
    Return the arguments of the current field and of each of its ancestors.

    Each dictionary holds the arguments of one field on the path from the
    root of the operation to the field being resolved. The list is ordered
    from the closest field to the furthest:
    - index 0 is the current field itself;
    - index 1 is its parent field;
    - and so on up to the root field.
    List indices in the response path do not count as levels.

    Arguments written as ``$variables`` are resolved from
    ``info.variable_values``.

    :raises LookupError: If a field on the path cannot be located in the
        operation's selections, for example because it is reached through a
        named fragment spread. Returning a shorter list would silently shift
        every level index.
    """
    root_first = _flatten_paths(info.path)[::-1]
    field_nodes = _descendant_fields(info.operation.selection_set.selections, root_first)
    expected = sum(1 for path in root_first if not isinstance(path.key, int))
    if len(field_nodes) != expected:
        raise LookupError(
            f"could only locate {len(field_nodes)} of {expected} fields "
            f"along path {info.path.as_list()} in the operation"
        )
    variables = info.variable_values
    return [_extract_field_args(field, variables) for field in reversed(field_nodes)]
def parent_field(info: strawberry.types.Info, ancestor_level: int, field_name: str) -> Any:
    """
    Return one argument value from a field on the path to the current field.

    Despite the name, this reads an *argument* (e.g. the ``id`` in
    ``user(id: "XXX")``), not a resolved field value.

    :param info: The resolver's Strawberry ``Info``.
    :param ancestor_level: Position in the list returned by ``ancestor_args``,
        which is ordered from the closest field to the furthest.
    :param field_name: Name of the argument to read.
    :return: The argument's value.
    :raises IndexError: If *ancestor_level* is out of range.
    :raises KeyError: If that field was not given a *field_name* argument.
    """
    per_level = ancestor_args(info)
    level_args = per_level[ancestor_level]
    return level_args[field_name]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment