Skip to content

Instantly share code, notes, and snippets.

@wapiflapi
Created September 29, 2019 23:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wapiflapi/4e827b9b3afb5b05c3818bfed35de544 to your computer and use it in GitHub Desktop.
Save wapiflapi/4e827b9b3afb5b05c3818bfed35de544 to your computer and use it in GitHub Desktop.
# @wapiflapi gqldl early draft.
#
# The goal of this is to manage dataloaders for multiple types while
# at the same time providing an easy integration for relay-compliance.
# - There will be full utility support for the relay spec.
# - Integration with ariadne WILL be easy and documented.
# - Integration with graphene MIGHT be documented.
#
# Feel free to comment, but documentation, tests and a proper release
# are coming soon(tm).
import abc
import asyncio
import base64
import binascii
import functools

import aiodataloader

from papi.logging import logger
# NOTE(review): gc.disable() turns off automatic (cyclic) garbage collection
# for the whole process as a side effect of importing this module, so
# reference cycles are no longer reclaimed unless gc.collect() is called
# explicitly. This looks like a debugging leftover — confirm it is intended.
import gc; gc.disable()
class DataLoader(aiodataloader.DataLoader):
    """``aiodataloader.DataLoader`` that also accepts a ``context`` keyword.

    The context is stored on ``self.context`` (presumably for use by batch
    loading code — confirm with callers). Every other positional and keyword
    argument is forwarded unchanged to ``aiodataloader.DataLoader``.
    """

    def __init__(self, *args, context=None, **kwargs):
        self.context = context
        super().__init__(*args, **kwargs)
def relay_connection(resolver):
    """Wrap a paginated resolver so that it returns a Relay connection dict.

    ``resolver`` must be a coroutine function returning a triple
    ``(has_previous_page, has_next_page, items)``, where ``items`` is any
    iterable of ``(key, cursor)`` pairs and every cursor is a ``str``.

    The wrapped resolver returns::

        {
            "pageInfo": {"hasPreviousPage", "hasNextPage",
                         "startCursor", "endCursor"},
            "edges": [{"cursor": cursor, "key": key}, ...],
        }

    ``startCursor`` and ``endCursor`` are ``None`` when there are no items.

    Raises:
        TypeError: if any cursor is not a ``str``.
    """
    @functools.wraps(resolver)
    async def wrapped_resolver(obj, info, *args, **kwargs):
        hasprevpage, hasnextpage, items = await resolver(
            obj, info, *args, **kwargs)
        # Materialize once: ``items`` is iterated for validation, indexed for
        # the start/end cursors and iterated again for the edges, which a
        # generator or other one-shot iterable cannot support.
        items = list(items)
        for _key, cursor in items:
            if not isinstance(cursor, str):
                raise TypeError(f"Invalid cursor {cursor}: must be str.")
        return {
            "pageInfo": {
                "hasPreviousPage": hasprevpage,
                "hasNextPage": hasnextpage,
                "startCursor": items[0][1] if items else None,
                "endCursor": items[-1][1] if items else None,
            },
            "edges": [{
                "cursor": cursor,
                "key": key,
            } for key, cursor in items]
        }
    return wrapped_resolver
class Relay(abc.ABC):
    """Registry of per-typename node resolvers plus Relay resolution helpers.

    Subclasses implement ``get_gdl`` (which GlobalDataLoader to use for the
    current request) and ``setup_edge_node`` (how to attach a node resolver
    to an edge type). This class calls both synchronously.
    """

    @abc.abstractmethod
    def get_gdl(self, obj, info):
        """Return the GlobalDataLoader for the current request.

        This is called synchronously and its result is used directly
        (``self.get_gdl(...).to_global_id(...)``). Implementations must
        therefore not be coroutine functions.
        """

    @abc.abstractmethod
    def setup_edge_node(self, edgename, typename, resolver):
        """Register ``resolver`` as the node resolver of ``edgename``.

        This is called synchronously from ``connection()``. A coroutine
        function here would never be awaited and so would never run.
        """

    def __init__(self):
        # typename -> async resolver(root, info, obj); filled by resolve().
        self.typename_resolver_map = {}

    def resolve(self, typename):
        """Return a decorator registering a node resolver for ``typename``.

        The decorated resolver is stored and returned unchanged. It is later
        awaited as ``resolver(root, info, obj)`` by ``resolve_object``.
        """
        def decorator(resolver):
            self.typename_resolver_map[typename] = resolver
            return resolver
        return decorator

    async def resolve_object(self, root, info, obj):
        """Resolve a loaded object through the resolver for its typename.

        The typename is read from ``obj["__typename"]`` for dicts, or from the
        ``__typename`` attribute otherwise. The registered resolver is awaited
        with ``(root, info, obj)``. If its (dict-like) result has no ``"id"``,
        one is built from its ``"key"`` via ``get_gdl(...).to_global_id``.

        Raises:
            TypeError: if ``obj`` has no typename, or the resolved value has
                neither an ``"id"`` nor a ``"key"``.
            NotImplementedError: if no resolver is registered for the
                typename.
        """
        if isinstance(obj, dict):
            typename = obj.get("__typename", None)
        else:
            # getattr() with a string name is not subject to private-name
            # mangling (that only rewrites ``x.__name`` written inside a class
            # body), so this reads an attribute literally named "__typename".
            typename = getattr(obj, "__typename", None)
        if typename is None:
            raise TypeError(
                f"No __typename attribute or key found in {obj}.")
        try:
            resolver = self.typename_resolver_map[typename]
        except KeyError:
            raise NotImplementedError(
                f"No resolver registered for {typename} using relay.resolve()"
            ) from None
        # TODO: Should we still handle this being an object and not a dict ?
        resolved = await resolver(root, info, obj)
        if "id" not in resolved:
            try:
                key = resolved["key"]
            except KeyError:
                raise TypeError(
                    f"No id attribute or key found in {obj} and no key provided."
                ) from None
            resolved["id"] = self.get_gdl(root, info).to_global_id(
                typename, key,
            )
        return resolved

    async def resolve_global_id(self, root, info, id):
        """Load the object behind a global ID and resolve it."""
        gdl = self.get_gdl(root, info)
        return await self.resolve_object(
            root, info, await gdl.load_global_id(id))

    async def resolve_type_key(self, root, info, typename, key, index="id"):
        """Load ``typename`` by ``key`` on ``index`` and resolve it."""
        gdl = self.get_gdl(root, info)
        return await self.resolve_object(
            root, info, await gdl.load_type_key(typename, key, index=index))

    def connection(self, typename, index="id"):
        """Wire ``{typename}Edge`` node resolution and return the decorator.

        This registers, via ``setup_edge_node``, a resolver that loads each
        edge's node with ``resolve_type_key(typename, key=edge["key"],
        index=index)``. It then returns the module-level ``relay_connection``
        decorator, which is meant to be applied to the connection field's
        paginated resolver.
        """
        async def resolver(root, info):
            return await self.resolve_type_key(
                root, info, typename=typename, key=root["key"], index=index)
        self.setup_edge_node(f"{typename}Edge", typename, resolver)
        return relay_connection
class GlobalDataLoader(abc.ABC):
    """Registry of dataloaders keyed by ``(typename, index)``.

    Subclasses must define ``typename_loadertype_map`` as a class-level
    dict, because the classmethods below read and write it through ``cls``.
    Each subclass should define its own dict; otherwise registrations are
    shared through inheritance. Each value is a factory called as
    ``factory(context=...)`` that returns an object with an async
    ``load_many(keys)``. Entries are added with ``register_loadertype`` or
    the ``loadertype`` decorator.

    Loaded objects are stamped with ``__typename`` and a global ``id``. The
    global id is the base64 encoding of ``"<typename>:<key>"``.
    """

    @staticmethod
    def to_global_id(typename, key):
        """Encode ``(typename, key)`` as an opaque global ID string.

        Keys are not serialized here; callers must pass them as strings.

        Raises:
            TypeError: if ``typename`` or ``key`` is not a non-empty str.
        """
        # Let's not serialize/un-serialize keys: not our job.
        if not typename or not isinstance(typename, str):
            raise TypeError(
                f"Invalid typename {typename}: must be non-empty str.")
        if not key or not isinstance(key, str):
            raise TypeError(
                f"Invalid key {key}: must be non-empty str.")
        gid = f"{typename}:{key}"
        return base64.b64encode(gid.encode("utf8")).decode("utf8")

    @staticmethod
    def from_global_id(gid):
        """Decode a global ID produced by ``to_global_id``.

        Returns:
            A ``(typename, key)`` tuple of str. The split is on the first
            ``":"``, so keys may themselves contain colons.

        Raises:
            TypeError: if ``gid`` is not a str, is not valid base64/UTF-8,
                or does not decode to ``"<typename>:<key>"`` with both parts
                non-empty.
        """
        # Reject non-str input with the same error type as other malformed
        # IDs instead of letting ``.encode`` raise AttributeError.
        if not isinstance(gid, str):
            raise TypeError(f"Received invalid gid {gid}")
        try:
            decoded = base64.b64decode(gid.encode("utf8")).decode("utf8")
        except (binascii.Error, UnicodeDecodeError):
            raise TypeError(f"Received invalid gid {gid}") from None
        typename, _, key = decoded.partition(':')
        if not typename or not key:
            raise TypeError(f"Invalid global ID {decoded}")
        return typename, key

    @classmethod
    def register_loadertype(cls, typename, loadertype, index="id"):
        """Register the loader factory for ``(typename, index)``.

        ``loadertype`` can be one of two things:

        - An ``aiodataloader.DataLoader`` subclass. It is used as the factory
          directly and is instantiated with ``context=...`` by ``__init__``,
          so it must accept a ``context`` keyword.
        - Anything else. It is treated as a batch load function and wrapped
          as ``DataLoader(loadertype, context=...)``.

        Returns:
            ``loadertype`` unchanged, so this also works as a decorator
            (see ``loadertype``).
        """
        factory = loadertype
        if not (isinstance(loadertype, type)
                and issubclass(loadertype, aiodataloader.DataLoader)):
            factory = functools.partial(DataLoader, loadertype)
        cls.typename_loadertype_map[(typename, index)] = factory
        # Returning the argument keeps a decorated function/class bound to its
        # name instead of replacing it with None.
        return loadertype

    @classmethod
    def loadertype(cls, typename, index="id"):
        """Return a decorator registering a loader for ``(typename, index)``.

        Usage::

            @MyGlobalDataLoader.loadertype("User")
            async def load_users(keys): ...

        The decorated object is returned unchanged.
        """
        return functools.partial(
            cls.register_loadertype, typename, index=index)

    @classmethod
    def enforce_typed_objects(cls, typename, objects, index="id", keys=None):
        """Stamp loaded objects with ``__typename`` and their global ``id``.

        Objects are updated in place (dict keys, or attributes for other
        objects). The same ``objects`` sequence is returned.

        The global id depends on the arguments:

        - If ``keys`` is given and ``index == "id"``, each global id is
          computed from the matching key. An object may omit ``id`` or
          already carry that same global id.
        - Otherwise, the object's own ``id`` is used as is.

        Raises:
            TypeError: if ``keys`` and ``objects`` differ in length, a key is
                invalid for ``to_global_id``, an object has no ``id`` to use,
                or an existing ``id`` disagrees with the computed one.
        """
        if keys is not None and len(keys) != len(objects):
            raise TypeError(
                "keys should be None or the same length as objects.")
        # NOTE(review): a loader returning None for a missing key currently
        # fails here (TypeError, or AttributeError from setattr when ids come
        # from keys); decide whether missing objects should be allowed.
        for i, obj in enumerate(objects):
            is_dict = isinstance(obj, dict)
            if keys is not None and index == "id":
                gid = cls.to_global_id(typename, keys[i])
            elif is_dict:
                gid = obj.get("id", None)
            else:
                gid = getattr(obj, "id", None)
            if gid is None:
                raise TypeError(
                    f"Loader for {typename} with index={index} returned "
                    f"something without 'id' attribute or key."
                )
            # Explicit check rather than ``assert`` so it survives python -O.
            if is_dict:
                existing = obj.get("id", gid)
            else:
                existing = getattr(obj, "id", gid)
            if existing != gid:
                raise TypeError(
                    f"Loader for {typename} with index={index} returned an "
                    f"object with id {existing!r}, expected {gid!r}."
                )
            if is_dict:
                obj["__typename"] = typename
                obj["id"] = gid
            else:
                setattr(obj, "__typename", typename)
                setattr(obj, "id", gid)
        return objects

    @property
    @abc.abstractmethod
    def typename_loadertype_map(self):
        """``{(typename, index): loader factory}``.

        Subclasses must shadow this with a plain class-level dict (see the
        class docstring).
        """

    def __init__(self, context=None):
        """Instantiate one loader per registered ``(typename, index)``.

        ``context`` is passed to every loader factory as ``context=``.
        """
        self.typename_loader_map = {
            (typename, index): loadertype(context=context)
            for (typename, index), loadertype
            in self.typename_loadertype_map.items()
        }

    def get_type_loader(self, typename, index="id"):
        """Return this instance's loader for ``(typename, index)``.

        Raises:
            TypeError: if no loader is registered for that pair.
        """
        try:
            return self.typename_loader_map[(typename, index)]
        except KeyError:
            raise TypeError(
                f"No loader registered for {typename} with index={index}."
            ) from None

    async def load_global_id(self, gid):
        """Decode ``gid`` and load the matching object via the ``id`` index."""
        typename, key = self.from_global_id(gid)
        return await self.load_type_key(typename, key)

    async def load_global_ids(self, gids):
        """Load the objects for several global IDs, preserving order.

        All loads are started before any is awaited, so the underlying
        dataloaders can batch them. Awaiting them one at a time would give
        each loader only one key per batch.
        """
        return await asyncio.gather(
            *(self.load_global_id(gid) for gid in gids))

    async def load_type_key(self, typename, key, index="id"):
        """Load a single ``typename`` object by ``key`` on ``index``."""
        return (await self.load_type_keys(typename, [key], index=index))[0]

    async def load_type_keys(self, typename, keys, index="id"):
        """Load ``typename`` objects for ``keys`` and stamp their type/id."""
        loader = self.get_type_loader(typename, index=index)
        return self.enforce_typed_objects(
            typename, await loader.load_many(keys), index=index, keys=keys)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment