Skip to content

Instantly share code, notes, and snippets.

@mairas
Created September 4, 2023 15:23
Show Gist options
  • Save mairas/ec8ec704758210cc20b5e5782713c0d7 to your computer and use it in GitHub Desktop.
Save mairas/ec8ec704758210cc20b5e5782713c0d7 to your computer and use it in GitHub Desktop.
How to access parent field arguments in Strawberry GraphQL
# Let's say you have a user with id "XXX" and you want to get all the
# groups that user is a member of. The query below should do the trick.
#
# However, when implementing the query and the types with Strawberry,
# it becomes essential that the groups relay connection has access to the
# user id. This is a hassle in Strawberry because the information is only
# available in the abstract syntax tree that you have to traverse
# manually.
#
# The code in this gist shows how to do that. It's not pretty, but it works.
# Call `parent_field(info, 2, "id")` in the groups implementation to get the user id.
#
# {
#   user(
#     id: "XXX"
#   ) {
#     groups {
#       edges {
#         node {
#           id
#         }
#       }
#     }
#   }
# }
def _flatten_paths(path: graphql.pyutils.path.Path, paths=None) -> list[graphql.pyutils.path.Path]:
"""
Flatten a reursive GraphQL path into a list of paths.
"""
if paths is None:
paths = []
paths.append(path)
if path.prev is not None:
_flatten_paths(path.prev, paths)
return paths
def _extract_field_args(
    field: graphql.language.ast.FieldNode,
    variables: "dict[str, Any] | None" = None,
) -> dict[str, Any]:
    """
    Extract the arguments of a field node as a ``{argument name: value}`` dict.

    - Scalar literals (int, float, string, boolean, enum) yield the raw
      ``.value`` stored on the AST node. Int and float literals therefore stay
      strings exactly as written (``id: 5`` gives ``"5"``).
    - Variable references (``id: $userId``) are looked up in ``variables``.
      A variable that is not present there is left out of the result.
    - ``null``, list and input-object values are converted with
      ``graphql.value_from_ast_untyped``. That function also parses any
      numbers nested inside them.

    :param field: the field node whose arguments are read.
    :param variables: the operation's variable values (e.g.
        ``info.variable_values``). Optional, so one-argument calls still work.
    """
    gql_ast = graphql.language.ast
    scalar_nodes = (
        gql_ast.IntValueNode,
        gql_ast.FloatValueNode,
        gql_ast.StringValueNode,
        gql_ast.BooleanValueNode,
        gql_ast.EnumValueNode,
    )
    args: dict[str, Any] = {}
    for argument in field.arguments or ():
        name = argument.name.value
        value_node = argument.value
        if isinstance(value_node, gql_ast.VariableNode):
            # VariableNode has no `.value`: the actual value lives in the
            # request's variables, keyed by the variable name.
            var_name = value_node.name.value
            if variables is not None and var_name in variables:
                args[name] = variables[var_name]
        elif isinstance(value_node, scalar_nodes):
            args[name] = value_node.value
        else:
            # NullValueNode / ListValueNode / ObjectValueNode have no `.value`.
            args[name] = graphql.value_from_ast_untyped(value_node, variables)
    return args


def _fields_with_response_key(
    selections: tuple[graphql.language.ast.SelectionNode, ...], key: "str | int"
) -> list[graphql.language.ast.FieldNode]:
    """
    Return the field nodes in ``selections`` whose response key equals ``key``.

    The response key is the alias when one is given, otherwise the field name.
    It is the key that GraphQL execution records in ``info.path``.

    Inline fragments (``... on Type { ... }``) are searched recursively,
    because their fields sit at the same level of the response. Named
    fragment spreads are not followed.
    """
    matches: list[graphql.language.ast.FieldNode] = []
    for selection in selections:
        if isinstance(selection, graphql.language.ast.InlineFragmentNode):
            matches.extend(_fields_with_response_key(selection.selection_set.selections, key))
        elif isinstance(selection, graphql.language.ast.FieldNode):
            if (selection.alias or selection.name).value == key:
                matches.append(selection)
    return matches


def _descendant_fields(
    selections: tuple[graphql.language.ast.SelectionNode, ...], paths: list[graphql.pyutils.path.Path]
) -> list[graphql.language.ast.FieldNode]:
    """
    Match ``paths`` (ordered root first) against the AST, one level per segment.

    Returns the matched field nodes in the same root-first order. Matching
    stops in any of these cases:
    - every segment has been matched;
    - a matched field has no selection set;
    - a segment matches no field (e.g. an integer list index).
    """
    if not paths:
        # Every segment is matched. Without this guard, a field that still has
        # a selection set would index into an empty list (IndexError).
        return []
    best: list[graphql.language.ast.FieldNode] = []
    for field in _fields_with_response_key(selections, paths[0].key):
        sset = field.selection_set
        if sset is None:
            chain = [field]
        else:
            chain = [field, *_descendant_fields(sset.selections, paths[1:])]
        # One response key may occur several times (field merging). Keep the
        # candidate that follows the path furthest; ties keep the first one.
        if len(chain) > len(best):
            best = chain
    return best


def ancestor_args(info: strawberry.types.Info) -> list[dict[str, Any]]:
    """
    Return the arguments of the fields along the path to the current field.

    ``info.path`` is matched against the operation's AST, starting at the root
    field. For each matched field node, its arguments are extracted, with
    variable references resolved from ``info.variable_values``.

    The list runs from the deepest matched field up to the root field. When
    the whole path is matched, index 0 is the current field itself, index 1
    is its parent field, and so on.

    NOTE: list indices in the path (below list-typed fields) never match a
    field. Matching stops there, so only the fields above the first list index
    are returned. Named fragment spreads are not followed.
    """
    root_first = _flatten_paths(info.path)[::-1]
    matched = _descendant_fields(info.operation.selection_set.selections, root_first)
    variables = info.variable_values
    return [_extract_field_args(node, variables) for node in reversed(matched)]
def parent_field(info: strawberry.types.Info, ancestor_level: int, field_name: str) -> Any:
    """
    Read one argument value from a field along the current path.

    ``ancestor_level`` indexes the list returned by ``ancestor_args(info)``.
    ``field_name`` names the argument to read from that field's argument dict.

    Raises IndexError if ``ancestor_level`` is outside that list. Raises
    KeyError if the selected field has no argument named ``field_name``.
    """
    level_args = ancestor_args(info)[ancestor_level]
    return level_args[field_name]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment