Skip to content

Instantly share code, notes, and snippets.

@lukesmurray
Last active October 8, 2021 21:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lukesmurray/4d0f156a7f97dc46c0a98f684794b999 to your computer and use it in GitHub Desktop.
Save lukesmurray/4d0f156a7f97dc46c0a98f684794b999 to your computer and use it in GitHub Desktop.
strawberry_sqlalchemy updates
from __future__ import annotations
import typing as t
import strawberry
from sqlmodel import Field, Relationship, SQLModel
from .filter_generators import create_array_relationship_type, create_query_root
class AddressModel(SQLModel, table=True):
    # SQLModel table model mapped to the "addresses" table.
    __tablename__ = "addresses"
    # Integer primary key (also indexed). It defaults to None, presumably so
    # the database can assign it on insert — confirm against the DB setup.
    id: t.Optional[int] = Field(
        default=None, primary_key=True, index=True, nullable=False
    )
    street: str
    state: str
    country: str
    # Column named `zip`; it shadows the builtin only as a class attribute.
    zip: str
    # One-to-many side of the relationship whose other side is
    # UserModel.address (back_populates="users" there). The string forward
    # reference is needed because UserModel is defined after this class.
    users: t.List["UserModel"] = Relationship(back_populates="address")
class UserModel(SQLModel, table=True):
    # SQLModel table model mapped to the "users" table.
    __tablename__ = "users"
    # Integer primary key (also indexed); defaults to None before insert.
    id: t.Optional[int] = Field(
        default=None, primary_key=True, index=True, nullable=False
    )
    age: int
    # NOTE(review): stored as a plain optional string. Nothing in this file
    # hashes it, so confirm hashing happens before a value reaches this model.
    password: t.Optional[str]
    # Foreign key column referencing addresses.id.
    address_id: t.Optional[int] = Field(default=None, foreign_key="addresses.id")
    # Many-to-one side of AddressModel.users.
    address: t.Optional[AddressModel] = Relationship(back_populates="users")
# GraphQL output type generated from UserModel; only the listed pydantic
# fields are copied onto the strawberry type.
# NOTE(review): "password" is listed, so it is exposed as an output field of
# `User` and, through the generated filter/order-by/select-column inputs, as a
# filter and sort key. Consider dropping it from `fields`; this is a breaking
# schema change for clients.
# NOTE(review): "address" is listed too, but the generated SDL in this gist
# shows no `address` field on `User`. Presumably this is because the
# `Address` strawberry type is not yet defined when this decorator runs.
# Confirm.
@strawberry.experimental.pydantic.type(
    UserModel, fields=["id", "age", "password", "address_id", "address"]
)
class User:
    pass
# GraphQL output type generated from AddressModel. The scalar columns come from
# the pydantic model. `users` is declared here as a resolver field whose
# arguments (where/limit/offset/orderBy/distinctOn) are produced by
# create_array_relationship_type.
@strawberry.experimental.pydantic.type(
    AddressModel, fields=["id", "street", "state", "country", "zip"]
)
class Address:
    # The annotation must be the field's GraphQL type: a list of `User`.
    # It previously named the resolver factory's result (a function), i.e.
    # `t.List[<function>]`, which is not a type. Strawberry versions that
    # prefer the annotation over the resolver's return type would reject or
    # mis-type the field. Evaluating it also called
    # create_array_relationship_type(User) a second time.
    users: t.List[User] = strawberry.field(
        resolver=create_array_relationship_type(User)
    )
# Root query type with one list field per type passed in; in the generated SDL
# these appear as `allUsers` and `allAddress` on `queryRoot`.
Query = create_query_root([User, Address])
schema = strawberry.Schema(query=Query)
# Generated GraphQL SDL snapshot kept in this gist. It was presumably printed
# from the strawberry schema built in the Python module; regenerate it rather
# than editing by hand. Python snake_case field and argument names appear
# camelCased here (e.g. not_in -> notIn, address_id -> addressId). Enum values
# are not converted.
schema {
  query: queryRoot
}

type Address {
  street: String!
  state: String!
  country: String!
  zip: String!
  # Resolver-backed relationship field; declared as a non-null list.
  users(where: UserFilter = null, limit: Int = null, offset: Int = null, orderBy: UserOrderBy = null, distinctOn: [UsersSelectColumn!] = null): [User!]!
  id: Int
}

# Filter input: scalar columns take scalar comparison inputs, relationships
# take the related type's filter, and and_/or_ compose filters.
input AddressFilter {
  street: StrFilter = null
  state: StrFilter = null
  country: StrFilter = null
  zip: StrFilter = null
  users: UserFilter = null
  id: IntFilter = null
  and_: [AddressFilter!] = null
  or_: [AddressFilter!] = null
}

# NOTE(review): `users` is a list relationship, not a column, yet it is
# offered as a sort key here and as a value of AddressSelectColumn below.
input AddressOrderBy {
  street: OrderByEnum = null
  state: OrderByEnum = null
  country: OrderByEnum = null
  zip: OrderByEnum = null
  users: OrderByEnum = null
  id: OrderByEnum = null
}

enum AddressSelectColumn {
  street
  state
  country
  zip
  users
  id
}

# Scalar comparison inputs. NOTE(review): their field order comes from
# iterating a Python set of operator names, so it can differ between runs.
input IntFilter {
  notIn: [Int!] = null
  lt: Int = null
  eq: Int = null
  neq: Int = null
  gte: Int = null
  in_: [Int!] = null
  lte: Int = null
  gt: Int = null
  isNull: Boolean = null
}

enum OrderByEnum {
  asc
  asc_nulls_first
  asc_nulls_last
  desc
  desc_nulls_first
  desc_nulls_last
}

input StrFilter {
  notContains: String = null
  lte: String = null
  notIn: [String!] = null
  contains: String = null
  lt: String = null
  eq: String = null
  neq: String = null
  gte: String = null
  in_: [String!] = null
  gt: String = null
  isNull: Boolean = null
}

# NOTE(review): `password` is exposed as an output field and, via UserFilter,
# UserOrderBy and UsersSelectColumn, as a filter and sort key.
type User {
  age: Int!
  id: Int
  password: String
  addressId: Int
}

input UserFilter {
  age: IntFilter = null
  id: IntFilter = null
  password: StrFilter = null
  addressId: IntFilter = null
  and_: [UserFilter!] = null
  or_: [UserFilter!] = null
}

input UserOrderBy {
  age: OrderByEnum = null
  id: OrderByEnum = null
  password: OrderByEnum = null
  addressId: OrderByEnum = null
}

# Values are the raw Python field names (address_id), unlike the camelCased
# input fields above.
enum UsersSelectColumn {
  age
  id
  password
  address_id
}

# Root query: one list field per type passed to create_query_root.
type queryRoot {
  allUsers(where: UserFilter = null, limit: Int = null, offset: Int = null, orderBy: UserOrderBy = null, distinctOn: [UsersSelectColumn!] = null): [User!]!
  allAddress(where: AddressFilter = null, limit: Int = null, offset: Int = null, orderBy: AddressOrderBy = null, distinctOn: [AddressSelectColumn!] = null): [Address!]!
}
import collections
import dataclasses
import enum
import typing as t
from enum import Enum
from types import SimpleNamespace
import strawberry
from strawberry.type import StrawberryContainer
class BoolOps(SimpleNamespace):
    # Names of the comparison operators a generated filter input can carry.
    # Each value is used as a field name on the generated filter dataclass;
    # the schema then shows it camelCased (e.g. not_in -> notIn).
    eq = "eq"
    neq = "neq"
    lt = "lt"
    lte = "lte"
    gt = "gt"
    gte = "gte"
    contains = "contains"
    not_contains = "not_contains"
    # Trailing underscore avoids the `in` keyword.
    in_ = "in_"
    not_in = "not_in"
    # NOTE(review): the attribute is `is_null_` but its value is "is_null".
    is_null_ = "is_null"
# Equality/ordering operators; each takes one value of the field's own type.
_BOOL_OP_COMPARISONS = {
    BoolOps.eq,
    BoolOps.neq,
    BoolOps.lt,
    BoolOps.lte,
    BoolOps.gt,
    BoolOps.gte,
}
# Alias of the set above (the same object): operators whose argument has the
# field's own type.
_SAME_TYPE_BOOL_OP = _BOOL_OP_COMPARISONS
# Operators whose argument is a list of values of the field's type.
_INCLUSION_BOOL_OP = {BoolOps.in_, BoolOps.not_in}
# Operators whose argument has the container's own type (e.g. a substring for
# str, or a list for list fields).
_CONTAINS_BOOL_OP = {BoolOps.contains, BoolOps.not_contains}
# Operators offered per base type, keyed by the runtime class (or the typing
# origin for generics, e.g. list for t.List[int]). is_null is not listed; it
# is appended to every scalar filter separately. The builtin-generic
# annotation `dict[type, set[str]]` requires Python 3.9+.
_SCALAR_BOOL_OP_MAP: dict[type, set[str]] = {
    bool: {BoolOps.eq, BoolOps.neq},
    int: {*_BOOL_OP_COMPARISONS, *_INCLUSION_BOOL_OP},
    float: {*_BOOL_OP_COMPARISONS, *_INCLUSION_BOOL_OP},
    str: {*_BOOL_OP_COMPARISONS, *_INCLUSION_BOOL_OP, *_CONTAINS_BOOL_OP},
    set: {*_CONTAINS_BOOL_OP},
    list: {*_CONTAINS_BOOL_OP},
}
@strawberry.enum
class OrderByEnum(Enum):
    # Sort direction used for every column of the generated *OrderBy inputs.
    # The *_nulls_* members presumably control where NULLs sort; the query
    # resolver that would apply them is not implemented yet.
    asc = "asc"
    asc_nulls_first = "asc_nulls_first"
    asc_nulls_last = "asc_nulls_last"
    desc = "desc"
    desc_nulls_first = "desc_nulls_first"
    desc_nulls_last = "desc_nulls_last"
# Base Python types treated as scalar columns (given scalar comparison
# filters) rather than relationships; consulted by is_primitive.
PRIMITIVES = {int, str, bool, float}
def is_sequence_container(type_):
    """Check whether ``type_`` is a parametrized list or set type.

    For example ``t.List[int]``, ``t.Set[str]``, ``list[int]`` and
    ``set[str]`` are sequence containers. ``int``, ``t.Optional[int]``,
    ``t.Dict[str, int]`` and the unparametrized ``t.List`` are not.
    """
    # t.get_origin returns the runtime class (list/set) for both the typing
    # aliases and the builtin generics, never t.List/t.Set themselves.
    # Comparing against t.List/t.Set made this always return False.
    # Requiring type arguments keeps callers from unwrapping a bare `t.List`
    # that has no element type.
    return t.get_origin(type_) in (list, set) and len(t.get_args(type_)) > 0
def is_optional(type_):
    """Check whether ``type_`` is a union that admits ``None``.

    ``t.Optional[int]`` and ``t.Union[None, int]`` are optional, and so is the
    PEP 604 spelling ``int | None`` on Python 3.10+. ``int`` and
    ``t.Union[int, str]`` are not.
    """
    import types

    # `X | None` has origin types.UnionType (3.10+), not t.Union; on older
    # interpreters the fallback just repeats t.Union.
    union_origins = (t.Union, getattr(types, "UnionType", t.Union))
    return t.get_origin(type_) in union_origins and type(None) in t.get_args(type_)
def unwrap_sequence_container(type_):
    """Return the element type of a list/set type.

    ``t.List[int]`` yields ``int``. Anything that ``is_sequence_container``
    does not recognise is handed back unchanged.

    Raises:
        ValueError: if the container is parametrized with more than one type.
    """
    if not is_sequence_container(type_):
        return type_
    element_types = t.get_args(type_)
    if len(element_types) > 1:
        raise ValueError(
            "Unable to unwrap fields which may contain multiple types "
            + f"of scalars: {type_}"
        )
    return element_types[0]
def unwrap_optional(type_):
    """Return the non-optional part of an optional type.

    ``t.Optional[int]`` and ``t.Union[None, int]`` both yield ``int``. Types
    that ``is_optional`` does not recognise are returned unchanged.

    Raises:
        ValueError: if the union has more than one non-None member, e.g.
            ``t.Union[int, str, None]``.
    """
    if not is_optional(type_):
        return type_
    members = t.get_args(type_)
    if len(members) > 2:
        raise ValueError(
            "Unable to unwrap fields which may contain multiple types "
            + f"of scalars: {type_}"
        )
    # get_args reports None as NoneType, so compare against type(None). The
    # old `is not None` test never filtered anything and only picked the right
    # member when None happened to come last in the union.
    return next(member for member in members if member is not type(None))
def is_primitive(type_):
    """Return True if ``type_`` is a primitive scalar, possibly wrapped.

    The real question is whether SQL stores the field as a column or as a
    relationship:

    * ``int`` is a column,
    * ``list[int]`` could be a column (we assume it is),
    * ``list[Address]`` is a relationship.

    Optional and list/set wrappers are peeled off until a bare type remains,
    which is then looked up in ``PRIMITIVES``.
    """
    # TODO: this is hacky; it approximates "column vs relationship" by
    # checking for primitive Python types.
    from collections.abc import Hashable

    while is_optional(type_) or is_sequence_container(type_):
        type_ = unwrap_optional(type_)
        type_ = unwrap_sequence_container(type_)
    # Unhashable annotation objects cannot be looked up in the PRIMITIVES set
    # and are never primitive. The ABC must come from collections.abc: the
    # `collections.Hashable` alias was removed in Python 3.10.
    return isinstance(type_, Hashable) and type_ in PRIMITIVES
def create_comparison_expression_name(type_):
    """Build the name of the filter input generated for ``type_``.

    Generic arguments are placed before the base name, so ``list[int]``
    becomes ``IntListFilter`` and ``int`` becomes ``IntFilter``. Every part is
    passed through ``str.capitalize``.
    """
    parts = [arg.__name__.capitalize() for arg in t.get_args(type_)]
    parts.append(type_.__name__.capitalize())
    parts.append("Filter")
    return "".join(parts)
def create_order_by_expression_name(type_):
    """Build the name of the order-by input generated for ``type_``.

    Follows the same scheme as the filter names, with an ``OrderBy`` suffix:
    ``int`` gives ``IntOrderBy`` and ``list[int]`` gives ``IntListOrderBy``.
    """
    prefix = "".join(arg.__name__.capitalize() for arg in t.get_args(type_))
    base = type_.__name__.capitalize()
    return f"{prefix}{base}OrderBy"
def create_all_type_query_name(type_):
    """Build the root-query field name that lists every ``type_`` object.

    The capitalized type name gets a naive plural ``s`` unless it already ends
    in one, e.g. ``User`` -> ``all_Users`` and ``Address`` -> ``all_Address``.
    """
    singular = type_.__name__.capitalize()
    plural = singular if singular.endswith("s") else singular + "s"
    return "all_" + plural
def create_select_column_enum_name(type_):
    """Build the name of the select-column enum generated for ``type_``.

    Uses the same naive pluralisation as the root query names, e.g.
    ``User`` -> ``UsersSelectColumn`` and ``Address`` -> ``AddressSelectColumn``.
    """
    singular = type_.__name__.capitalize()
    plural = singular if singular.endswith("s") else singular + "s"
    return plural + "SelectColumn"
def create_scalar_comparison_expression(type_: type):
    """Build and register the filter input type for a scalar field.

    ``type_`` may be optional; the optional wrapper is removed first. The
    operators offered come from ``_SCALAR_BOOL_OP_MAP``, keyed by the base type
    (or its typing origin for generics such as ``t.List[int]``):

    * comparison ops take a value of the field's type,
    * inclusion ops (``in_``/``not_in``) take a list of such values,
    * contains ops take a value of the container type itself,

    and every filter also gets an ``is_null`` flag. All fields are optional and
    default to ``None``.

    Returns:
        The generated dataclass, decorated with ``strawberry.input``.

    Raises:
        KeyError: if the base type has no entry in ``_SCALAR_BOOL_OP_MAP``.
    """
    type_ = unwrap_optional(type_)
    operations = _SCALAR_BOOL_OP_MAP[t.get_origin(type_) or type_]
    fields = []
    # `operations` is a set of strings, and its iteration order depends on
    # string hashing, which is randomized per interpreter run by default.
    # Sorting keeps the generated input fields (and printed schema) stable.
    for op in sorted(operations):
        if op in _SAME_TYPE_BOOL_OP:
            fields.append((op, t.Optional[type_], dataclasses.field(default=None)))
        elif op in _INCLUSION_BOOL_OP:
            fields.append(
                (op, t.Optional[t.List[type_]], dataclasses.field(default=None))
            )
        elif op in _CONTAINS_BOOL_OP:
            generic_args = t.get_args(type_)
            container_type = t.get_origin(type_) or type_
            # Keep the element type for parametrized containers (list[int]);
            # plain types such as str are used as-is.
            if len(generic_args) == 1:
                resulting_type = t.Optional[container_type[generic_args[0]]]
            else:
                resulting_type = t.Optional[container_type]
            fields.append((op, resulting_type, dataclasses.field(default=None)))
    # TODO: would be nice to only add is_null if the field is optional.
    # but we would need to change the `expression_name` since we register
    # the expression as a global and whichever class we define last will
    # override prior classes
    fields.append((BoolOps.is_null_, t.Optional[bool], dataclasses.field(default=None)))
    expression_name = create_comparison_expression_name(type_)
    # The class is stored under its name in this module's globals (and its
    # __module__ points here), presumably so strawberry can resolve it by
    # name. A later call producing the same name replaces the entry.
    globals()[expression_name] = dataclasses.make_dataclass(
        expression_name,
        fields=fields,
        namespace={"__module__": __name__},
    )
    return strawberry.input(globals()[expression_name])
def create_non_scalar_comparison_expression(type_: type):
    """Build and register the filter input type for an object type.

    Fields that ``is_primitive`` accepts get a scalar comparison input. Any
    other field is treated as a relationship and gets the filter input of its
    underlying object type, generated recursively. ``and_``/``or_`` take lists
    of this same filter so conditions can be combined.

    Returns:
        The generated dataclass, decorated with ``strawberry.input``.
    """
    type_hints = t.get_type_hints(type_)
    fields = []
    expression_name = create_comparison_expression_name(type_)
    for field_name, field_type in type_hints.items():
        if is_primitive(field_type):
            nested_filter = create_scalar_comparison_expression(field_type)
        else:
            # Nested object. Whether the field is a single item, optional or
            # a list, we filter on the underlying type's fields.
            # TODO: create a way to query the array itself.
            # TODO: not sure if StrawberryContainer is always the right choice
            nested_filter = create_non_scalar_comparison_expression(
                _unwrap_relationship_type(field_type)
            )
        fields.append(
            (field_name, t.Optional[nested_filter], dataclasses.field(default=None))
        )
    # The element type is the filter's own name as a string (a forward
    # reference), since the class does not exist until make_dataclass runs.
    fields.append(
        ("and_", t.Optional[t.List[expression_name]], dataclasses.field(default=None)),
    )
    fields.append(
        ("or_", t.Optional[t.List[expression_name]], dataclasses.field(default=None)),
    )
    globals()[expression_name] = dataclasses.make_dataclass(
        expression_name,
        fields=fields,
        namespace={"__module__": __name__},
    )
    return strawberry.input(globals()[expression_name])


def _unwrap_relationship_type(field_type):
    """Peel every optional/list wrapper off a relationship field type."""
    # Unwrap repeatedly: an optional list is two layers, and unwrapping only
    # the outermost one handed a container object to get_type_hints.
    while True:
        if isinstance(field_type, StrawberryContainer):
            field_type = field_type.of_type
        elif is_optional(field_type) or is_sequence_container(field_type):
            field_type = unwrap_sequence_container(unwrap_optional(field_type))
        else:
            return field_type
def create_non_scalar_order_by_expression(type_: type):
    """Build and register the order-by input type for an object type.

    Each column field (as decided by ``is_primitive``, the same test the
    filter generator uses) gets an optional ``OrderByEnum`` direction.
    Relationship fields are skipped. A relationship such as a list of users
    is not a column and has no single sort direction, yet it was previously
    exposed as one (e.g. ``AddressOrderBy.users``).

    Returns:
        The generated dataclass, decorated with ``strawberry.input``.
    """
    expression_name = create_order_by_expression_name(type_)
    fields = [
        (field_name, t.Optional[OrderByEnum], dataclasses.field(default=None))
        for field_name, field_type in t.get_type_hints(type_).items()
        if is_primitive(field_type)
    ]
    globals()[expression_name] = dataclasses.make_dataclass(
        expression_name,
        fields=fields,
        namespace={"__module__": __name__},
    )
    return strawberry.input(globals()[expression_name])
def create_non_scalar_select_columns_enum(type_: type):
    """Build and register the select-column enum for an object type.

    The enum has one member per column field of ``type_`` (as decided by
    ``is_primitive``), with the Python field name as both member name and
    value; it backs the ``distinctOn`` argument. Relationship fields are
    skipped because they are not columns a query could be made distinct on.

    Returns:
        The generated Enum, decorated with ``strawberry.enum``.
    """
    enum_name = create_select_column_enum_name(type_)
    type_hints = t.get_type_hints(type_)
    columns = {
        field_name: field_name
        for field_name, field_type in type_hints.items()
        if is_primitive(field_type)
    }
    globals()[enum_name] = enum.Enum(enum_name, columns)
    return strawberry.enum(globals()[enum_name])
def create_array_relationship_type(type_: type):
    """Return a resolver that lists ``type_`` objects.

    Despite its name, this returns a resolver function, not a type. The
    resolver's arguments expose the generated filter (``where``), order-by
    (``orderBy``) and select-column (``distinctOn``) inputs for ``type_``.
    Those inputs are built while the resolver's signature is evaluated, i.e.
    on every call to this factory (assuming this module does not use
    ``from __future__ import annotations``).
    """

    def all_type_query_implementation(
        self,
        info,
        where: t.Optional[create_non_scalar_comparison_expression(type_)] = None,
        limit: t.Optional[int] = None,
        offset: t.Optional[int] = None,
        orderBy: t.Optional[create_non_scalar_order_by_expression(type_)] = None,
        distinctOn: t.Optional[
            t.List[create_non_scalar_select_columns_enum(type_)]
        ] = None,
    ) -> t.List[type_]:
        # TODO: actually implement the query; these are placeholder rows.
        if type_.__name__ == "User":
            return [type_(age=10, password="foo")]
        if type_.__name__ == "Address":
            return [type_(street="harman", state="ny", country="usa", zip="11237")]
        # The field is declared as a non-null list, so any other type must
        # get an empty list rather than an implicit None, which GraphQL would
        # report as a null error.
        return []

    return all_type_query_implementation
def create_all_type_query_field(type_: type):
    """Describe the root-query field that lists every ``type_`` object.

    Returns a ``(name, annotation, dataclasses.Field)`` triple for
    ``dataclasses.make_dataclass``. The field's default is a strawberry field
    that wraps the list resolver for ``type_``.
    """
    field_name = create_all_type_query_name(type_)
    resolver = create_array_relationship_type(type_)
    field_spec = dataclasses.field(default=strawberry.field(resolver))
    return field_name, t.List[type_], field_spec
def create_query_root(types: t.List[type]):
    """Build the strawberry root query type for ``types``.

    One list field is generated per type via ``create_all_type_query_field``.
    The dataclass is stored in this module's globals as ``query_root`` and
    then decorated with ``strawberry.type``.
    """
    root_fields = list(map(create_all_type_query_field, types))
    root_name = "query_root"
    root_cls = dataclasses.make_dataclass(
        root_name,
        fields=root_fields,
        namespace={"__module__": __name__},
    )
    globals()[root_name] = root_cls
    return strawberry.type(globals()[root_name])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment