Skip to content

Instantly share code, notes, and snippets.

@ThirVondukr
Last active October 6, 2021 23:36
Show Gist options
  • Save ThirVondukr/27668d96d492b456b6c68b977efb17d8 to your computer and use it in GitHub Desktop.
Save ThirVondukr/27668d96d492b456b6c68b977efb17d8 to your computer and use it in GitHub Desktop.
Strawberry Input types generation
import dataclasses
import typing
from types import SimpleNamespace
from typing import Optional, List
import strawberry
class Op(SimpleNamespace):
    """Namespace of the filter operation names.

    Every attribute is a plain string equal to its own attribute name. The
    class is only used as a holder of constants; it is never instantiated in
    this file.
    """

    # Comparisons against a single value.
    eq, neq = "eq", "neq"
    lt, lte = "lt", "lte"
    gt, gte = "gt", "gte"

    # Checks against the contents of the filtered value.
    contains, not_contains = "contains", "not_contains"

    # Membership in a list of candidates (``in`` itself is a reserved word,
    # hence the trailing underscore).
    in_, not_in = "in_", "not_in"
# Operation groups. create_filter() picks the field type of an operation from
# the group the operation belongs to.
_OP_COMPARISONS = {Op.eq, Op.neq, Op.lt, Op.lte, Op.gt, Op.gte}
_SAME_TYPE_OP = _OP_COMPARISONS  # alias: the very same set object
_INCLUSION_OP = {Op.in_, Op.not_in}
_CONTAINS_OP = {Op.contains, Op.not_contains}

# Operations offered for each supported type. Parameterised generics are
# looked up by their origin (e.g. ``list`` for ``list[int]``).
FILTER_MAP: dict[type, set[str]] = {
    bool: {Op.eq, Op.neq},
    int: _OP_COMPARISONS | _INCLUSION_OP,
    str: _OP_COMPARISONS | _INCLUSION_OP | _CONTAINS_OP,
    # Containers get their own copy of the set, independent of _CONTAINS_OP.
    set: set(_CONTAINS_OP),
    list: set(_CONTAINS_OP),
}
def create_filter_name(type_):
    """Return the class name used for the filter generated from ``type_``.

    The name is the capitalised ``__name__`` of every generic argument (in
    order), followed by the capitalised ``__name__`` of ``type_`` itself and
    the suffix ``"Filter"``; for example ``str`` gives ``"StrFilter"`` and
    ``list[int]`` gives ``"IntListFilter"``.

    Note that ``str.capitalize`` lower-cases everything after the first
    character, so a type named ``MyType`` contributes ``"Mytype"``.
    """
    parts = [arg.__name__.capitalize() for arg in typing.get_args(type_)]
    parts.append(type_.__name__.capitalize())
    parts.append("Filter")
    return "".join(parts)
def create_filter(type_: type):
    """Build a Strawberry input type with the filter operations for ``type_``.

    The set of operations is read from ``FILTER_MAP``, keyed by the generic
    origin of ``type_`` (``list`` for ``list[int]``) or by ``type_`` itself
    when it is not parameterised. Each operation becomes a field that
    defaults to ``None``:

    * comparison operations accept ``Optional[type_]``;
    * inclusion operations accept ``Optional[List[type_]]``;
    * contains operations accept the optional container type, parameterised
      with its generic argument when exactly one argument is present.

    Fields are emitted in sorted operation-name order so that the generated
    class (field order, positional ``__init__`` order) is identical on every
    run.

    Every call builds a brand-new class, even for a type that was already
    processed. NOTE(review): two distinct classes with the same name in one
    schema are likely to be rejected by Strawberry, so callers should
    probably reuse the returned class -- confirm.

    Args:
        type_: A plain type such as ``str`` or a parameterised generic such
            as ``list[int]``.

    Returns:
        The dataclass named by ``create_filter_name(type_)``, passed through
        ``strawberry.input``.

    Raises:
        KeyError: If ``FILTER_MAP`` has no entry for the type.
    """
    container_type = typing.get_origin(type_) or type_
    operations = FILTER_MAP[container_type]

    fields = []
    # FILTER_MAP holds sets of strings, whose iteration order changes from one
    # interpreter run to the next (string hash randomisation). Sorting pins
    # the field order of the generated dataclass.
    for op in sorted(operations):
        if op in _SAME_TYPE_OP:
            field_type = Optional[type_]
        elif op in _INCLUSION_OP:
            field_type = Optional[List[type_]]
        elif op in _CONTAINS_OP:
            generic_args = typing.get_args(type_)
            if len(generic_args) == 1:
                field_type = Optional[container_type[generic_args[0]]]
            else:
                field_type = Optional[container_type]
        else:
            # Operations outside the known groups produce no field.
            continue
        fields.append((op, field_type, dataclasses.field(default=None)))

    filter_ = dataclasses.make_dataclass(
        create_filter_name(type_),
        fields=fields,
    )
    return strawberry.input(filter_)
# Ready-made filters for the scalar types, created once when the module is
# imported (in the order str, int, bool).
StrFilter, IntFilter, BoolFilter = (
    create_filter(scalar) for scalar in (str, int, bool)
)
@strawberry.type
class Root:
    """Example Strawberry type showing generated filters used as arguments."""

    # NOTE(review): ``filters`` is not imported in the visible source; it is
    # presumably the module that defines create_filter and the prebuilt
    # filters -- confirm the import before running this.
    @strawberry.field
    def test(
        self,
        number: filters.IntFilter,
        boolean: filters.BoolFilter,
        string: filters.StrFilter,
        # The list filters are built inline, while the annotations are
        # evaluated.
        int_list_filter: filters.create_filter(list[int]),
        str_list_filter: filters.create_filter(list[str]),
    ) -> int:
        """Accept one filter of each kind and return a constant.

        None of the arguments is inspected; the field always returns 42.
        """
        return 42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment