Skip to content

Instantly share code, notes, and snippets.

@mdellavo
Last active January 16, 2024 15:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mdellavo/443d206b1800887b0b64cad92f734472 to your computer and use it in GitHub Desktop.
Save mdellavo/443d206b1800887b0b64cad92f734472 to your computer and use it in GitHub Desktop.
WIP: Document mapper for FoundationDB
import dataclasses
import datetime
import itertools
import struct
from typing import Any, Optional, Type, TypedDict, TypeVar, get_args, get_origin

import fdb
from fdb.impl import Transaction
from fdb.subspace_impl import Subspace
import varint
# Pin the FoundationDB client API version (710 = 7.1); the fdb bindings expect
# this call once, before any other fdb API is used.
fdb.api_version(710)
# Maps a Python type to the Field subclass (codec) used for values of that type.
TYPE_REGISTRY: dict[type, type] = {}


class Field:
    """Base codec: stores ``bytes`` values verbatim.

    Subclasses declare the Python type they handle in ``TYPE`` and are added to
    ``TYPE_REGISTRY`` automatically when the subclass is defined.

    :param key: optional storage key to use instead of the dataclass
        attribute name.
    """

    TYPE = bytes

    def __init_subclass__(cls, **kwargs):
        # Cooperate with other __init_subclass__ implementations in the MRO.
        super().__init_subclass__(**kwargs)
        # Register only subclasses that declare their own TYPE. A subclass that
        # merely inherits TYPE would otherwise replace its parent's registration
        # (e.g. a StringField subclass taking over ``str``).
        if "TYPE" in cls.__dict__:
            TYPE_REGISTRY[cls.TYPE] = cls

    def __init__(self, key: Optional[str] = None):
        self.key = key

    @classmethod
    def encode(cls, value: bytes) -> bytes:
        """Return *value* unchanged."""
        return value

    @classmethod
    def decode(cls, value: bytes) -> bytes:
        """Return *value* unchanged."""
        return value


# Field itself is the codec for raw bytes (__init_subclass__ only sees subclasses).
TYPE_REGISTRY[bytes] = Field
FieldType = TypeVar("FieldType", bound=Field)
class BoolField(Field):
    """Codec for ``bool`` values, stored as a single ``0x00`` / ``0x01`` byte."""

    TYPE = bool

    @classmethod
    def encode(cls, value: bool) -> bytes:
        # The truthiness of the *argument* decides the byte.
        return b"\x01" if value else b"\x00"

    @classmethod
    def decode(cls, value: bytes) -> bool:
        """Return True when the first stored byte is non-zero."""
        return bool(value[0])
class IntegerField(Field):
    """Codec for non-negative ``int`` values, stored via the ``varint`` package."""

    TYPE = int

    @classmethod
    def encode(cls, value: int) -> bytes:
        """Encode *value* as a varint.

        :raises ValueError: if *value* is negative.
        """
        # varint.encode shifts the number right until it reaches zero, which
        # never happens for a negative number (the call would hang), so reject
        # negatives up front.
        if value < 0:
            raise ValueError(f"IntegerField cannot encode negative value {value}")
        return varint.encode(value)

    @classmethod
    def decode(cls, value: bytes) -> int:
        """Decode a varint produced by :meth:`encode`."""
        return varint.decode_bytes(value)
class FloatField(Field):
    """Codec for ``float`` values, stored as an 8-byte IEEE 754 double."""

    TYPE = float
    # Explicit little-endian layout. The native "d" format follows the host's
    # byte order, so data written on one architecture could be misread on
    # another. "<d" is byte-identical to native "d" on little-endian hosts, so
    # data already written there still decodes unchanged.
    _DOUBLE = struct.Struct("<d")

    @classmethod
    def encode(cls, value: float) -> bytes:
        return cls._DOUBLE.pack(value)

    @classmethod
    def decode(cls, value: bytes) -> float:
        """Decode exactly 8 bytes written by :meth:`encode`."""
        return cls._DOUBLE.unpack(value)[0]
class StringField(Field):
    """Codec for ``str`` values, stored as UTF-8 encoded bytes."""

    TYPE = str

    @classmethod
    def encode(cls, value: str) -> bytes:
        """Return the UTF-8 encoding of *value*."""
        encoded = value.encode(encoding="utf-8")
        return encoded

    @classmethod
    def decode(cls, value: bytes) -> str:
        """Return the text decoded from UTF-8 *value*."""
        text = value.decode(encoding="utf-8")
        return text
class DateTimeField(StringField):
    """Codec for ``datetime.datetime`` values, stored as UTF-8 ISO-8601 text.

    The UTC offset of timezone-aware values is part of the ``isoformat`` text,
    so tzinfo survives the round trip.
    """

    TYPE = datetime.datetime

    @classmethod
    def encode(cls, value: datetime.datetime) -> bytes:
        iso_text = value.isoformat()
        return StringField.encode(iso_text)

    @classmethod
    def decode(cls, value: bytes) -> datetime.datetime:
        iso_text = StringField.decode(value)
        return datetime.datetime.fromisoformat(iso_text)
try:
    import bson
except ImportError:
    # bson is optional; without it ObjectId values have no codec.
    pass
else:
    class BsonField(Field):
        """Codec for ``bson.ObjectId`` values, stored as ``ObjectId.binary``."""

        TYPE = bson.ObjectId

        @classmethod
        def encode(cls, value: bson.ObjectId) -> bytes:
            raw = value.binary
            return raw

        @classmethod
        def decode(cls, value: bytes) -> bson.ObjectId:
            object_id = bson.ObjectId(value)
            return object_id
class Key(tuple):
    """A tuple that converts to and from FoundationDB tuple-layer bytes."""

    def pack(self) -> bytes:
        """Encode this key with the FoundationDB tuple layer."""
        return fdb.tuple.pack(self)

    @classmethod
    def unpack(cls, data: bytes) -> "Key":
        """Decode tuple-layer *data* (as produced by :meth:`pack`) into a Key."""
        # fdb.tuple.unpack takes the packed bytes; an already-built tuple cannot
        # be decoded, so this is an alternate constructor rather than a method
        # on an existing Key.
        return cls(fdb.tuple.unpack(data))
def fdb_field(field: Field, *args, **kwargs) -> dataclasses.Field:
    """Declare a dataclass field stored with the given fdb codec.

    Thin wrapper over ``dataclasses.field``. The *field* codec instance is
    recorded under the ``"fdb"`` metadata key; all other arguments are
    forwarded unchanged.
    """
    # Build a fresh mapping rather than mutating the caller's: their metadata
    # may be shared between fields or be read-only (e.g. MappingProxyType).
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["fdb"] = field
    return dataclasses.field(*args, metadata=metadata, **kwargs)
def get_fdb_field(field: dataclasses.Field) -> Optional[Field]:
    """Return the fdb codec for a dataclass field.

    A codec attached via ``fdb_field`` (metadata key ``"fdb"``) wins. Otherwise
    a fresh instance of the codec registered for the field's annotated type is
    returned, or ``None`` when that type has no codec.
    """
    explicit = field.metadata.get("fdb")
    if explicit is not None:
        return explicit
    # Fields carrying unrelated metadata must still fall back to the registry.
    codec_cls = TYPE_REGISTRY.get(field.type)
    return codec_cls() if codec_cls is not None else None
# Sentinel path components for an empty dict / empty list. to_tuples emits them
# so empty containers survive a round trip, and from_tuples turns them back into
# {} / []. Both are negative, so they cannot collide with list indices, which
# enumerate() numbers from 0.
EMPTY_OBJECT, EMPTY_ARRAY = -2, -1
def to_tuples(item, encode=None):
    """Flatten *item* into rows of ``(*path, leaf)`` tuples.

    ``path`` is the chain of dict keys, list indices, or dataclass storage keys
    leading to a scalar, and ``leaf`` is that scalar encoded to bytes. An empty
    dict or list becomes a single ``(EMPTY_OBJECT, None)`` or
    ``(EMPTY_ARRAY, None)`` row so that it is not lost.

    :param item: dataclass instance, dict, list/tuple, or a scalar whose type
        is registered in ``TYPE_REGISTRY``.
    :param encode: encoder for a scalar *item*. When omitted, the encoder
        registered for ``type(item)`` is used. Ignored for containers.
    :raises ValueError: if a scalar's type has no registered encoder.
    """
    if isinstance(item, dict):
        if not item:
            return [(EMPTY_OBJECT, None)]
        return [(k,) + sub for k, v in item.items() for sub in to_tuples(v)]
    if isinstance(item, (list, tuple)):
        if not item:
            return [(EMPTY_ARRAY, None)]
        return [(i,) + sub for i, v in enumerate(item) for sub in to_tuples(v)]
    if dataclasses.is_dataclass(item):
        rows = []
        for field in dataclasses.fields(item):
            # Read attributes directly. dataclasses.asdict() would turn nested
            # dataclasses into plain dicts keyed by attribute *name*, ignoring
            # any custom storage keys that hydrate_document expects.
            value = getattr(item, field.name)
            f_field = get_fdb_field(field)
            key = f_field.key if f_field and f_field.key else field.name
            leaf_encode = f_field.encode if f_field is not None else None
            rows.extend((key,) + sub for sub in to_tuples(value, leaf_encode))
        return rows
    if encode is None:
        codec = TYPE_REGISTRY.get(type(item))
        if codec is None:
            raise ValueError(f"could not serialize {item}")
        encode = codec.encode
    return [(encode(item),)]
def from_tuples(tuples: list[tuple]):
    """Rebuild a nested dict/list structure from ``(*path, leaf)`` rows.

    This is the inverse of ``to_tuples``. Rows must be sorted by path, as a
    FoundationDB range read returns them, so that rows sharing a first path
    element are adjacent for ``itertools.groupby``. A level whose first key is
    ``0`` becomes a list; any other level becomes a dict. Leaves are returned
    still encoded. An empty input yields ``{}``.
    """
    if not tuples:
        return {}
    first = tuples[0]
    if len(first) == 1:
        return first[0]
    # Empty-container sentinel row. Straight from to_tuples its leaf is None;
    # a leaf written to FoundationDB can only be bytes, so accept b"" as well.
    if len(first) == 2 and first[1] in (None, b""):
        if first[0] == EMPTY_OBJECT:
            return {}
        if first[0] == EMPTY_ARRAY:
            return []
    groups = [list(g) for _, g in itertools.groupby(tuples, key=lambda t: t[0])]
    if first[0] == 0:
        return [from_tuples([t[1:] for t in g]) for g in groups]
    return {g[0][0]: from_tuples([t[1:] for t in g]) for g in groups}
def hydrate_document(storage: dict, doc_class: Any) -> Any:
    """Build a *doc_class* instance from the structure produced by from_tuples.

    *storage* is a dict keyed by each field's storage key (the ``fdb_field``
    key, or the attribute name). Its values are still-encoded leaves, lists,
    dicts, or nested storage dicts for dataclass-typed fields.

    Each value is decoded according to the field's annotated type:

    * dataclass types are hydrated recursively;
    * ``list[T]``, ``dict[K, V]`` and ``TypedDict`` annotations have their
      elements decoded with the codec registered for ``T``, ``V``, or the
      TypedDict key's type. Elements whose type has no codec are returned as-is;
    * anything else is decoded with the field's fdb codec.

    Fields excluded from ``__init__`` are skipped. Fields missing from
    *storage* that declare a default are left to that default.

    :raises ValueError: if *doc_class* is not a dataclass.
    :raises KeyError: if a field without a default is missing from *storage*.
    """
    if not dataclasses.is_dataclass(doc_class):
        raise ValueError("doc is not a dataclass")
    doc = {}
    for field in dataclasses.fields(doc_class):
        if not field.init:
            # Such fields cannot be passed to the constructor.
            continue
        f_field = get_fdb_field(field)
        key = f_field.key if f_field and f_field.key else field.name
        if key not in storage and _has_default(field):
            # e.g. a field added to the class after the document was stored.
            continue
        doc[field.name] = _decode_value(storage[key], field.type, f_field)
    return doc_class(**doc)


def _has_default(field: dataclasses.Field) -> bool:
    """Return True if *field* declares a default value or default factory."""
    return (
        field.default is not dataclasses.MISSING
        or field.default_factory is not dataclasses.MISSING
    )


def _is_typeddict(type_hint: Any) -> bool:
    """Return True if *type_hint* is a TypedDict class (works before 3.10)."""
    return (
        isinstance(type_hint, type)
        and issubclass(type_hint, dict)
        and hasattr(type_hint, "__total__")
    )


def _decode_value(raw: Any, type_hint: Any, f_field: Optional["Field"] = None) -> Any:
    """Decode one stored value according to its annotated type *type_hint*."""
    if dataclasses.is_dataclass(type_hint):
        return hydrate_document(raw, type_hint)
    origin, args = get_origin(type_hint), get_args(type_hint)
    if origin is list and isinstance(raw, list):
        item_hint = args[0] if args else None
        return [_decode_value(item, item_hint) for item in raw]
    if origin is dict and isinstance(raw, dict):
        value_hint = args[1] if len(args) == 2 else None
        return {k: _decode_value(v, value_hint) for k, v in raw.items()}
    if isinstance(raw, dict) and origin is None and _is_typeddict(type_hint):
        hints = getattr(type_hint, "__annotations__", {})
        return {k: _decode_value(v, hints.get(k)) for k, v in raw.items()}
    if f_field is not None:
        return f_field.decode(raw)
    codec = TYPE_REGISTRY.get(type_hint)
    return codec.decode(raw) if codec is not None else raw
@fdb.transactional
def store_document(tr: Transaction, space: Subspace, key: tuple, doc: Any):
    """Write dataclass instance *doc* under *key* in *space*.

    Any document previously stored under *key* is replaced. Each row
    ``(*path, leaf)`` from ``to_tuples(doc)`` becomes one key/value pair:
    ``space.pack(key + path)`` -> ``leaf``.

    :raises ValueError: if *doc* is not a dataclass.
    """
    if not dataclasses.is_dataclass(doc):
        raise ValueError("doc is not a dataclass")
    # Clear old rows first. Otherwise a list that shrank, or a container that
    # became empty, would leave stale rows for load_document to read back.
    doc_range = space.range(key)
    tr.clear_range(doc_range.start, doc_range.stop)
    for *path, value in to_tuples(doc):
        # Empty-container sentinel rows carry a None leaf, but FoundationDB
        # values must be bytes, so store an empty byte string for them.
        tr[space.pack(key + tuple(path))] = b"" if value is None else value
@fdb.transactional
def load_document(tr: Transaction, space: Subspace, key: tuple, doc_class: Type) -> Any:
    """Read the document stored under *key* in *space* as a *doc_class* instance.

    :raises KeyError: if nothing is stored under *key*.
    """
    # Strip the whole document key, not just its first element, so that
    # multi-element keys such as ("tenant", "doc-1") work too.
    prefix_len = len(key)
    tuples = [space.unpack(k)[prefix_len:] + (v,) for k, v in tr[space.range(key)]]
    if not tuples:
        raise KeyError(key)
    return hydrate_document(from_tuples(tuples), doc_class)
if __name__ == "__main__":
    # Demo: store one document in the default FoundationDB cluster (fdb.open()),
    # load it back, and print the original next to the reloaded copy.
    import pprint

    @dataclasses.dataclass
    class SubDoc:
        foo: int
        bar: str
        baz: float

    class SomeDict(TypedDict):
        aaa: int
        bbb: str
        ccc: float

    @dataclasses.dataclass
    class Example:
        # The short storage keys, not the attribute names, appear in the
        # stored key paths.
        raw_field: bytes = fdb_field(Field("r"))
        bool_field: bool = fdb_field(BoolField("b"))
        int_field: int = fdb_field(IntegerField("i"))
        float_field: float = fdb_field(FloatField("f"))
        str_field: str = fdb_field(StringField("t"))
        datetime_field: datetime.datetime = fdb_field(DateTimeField("dt"))
        doc_field: SubDoc = fdb_field(Field("sd"))
        # FIXME: list/dict fields were reported to load back as raw bytes;
        # verify the round trip below.
        list_field: list[str] = fdb_field(Field("l"))
        dict_field: SomeDict = fdb_field(Field("d"))

    example_doc = Example(
        raw_field=b"\xde\xad\xbe\xef",
        bool_field=True,
        int_field=42,
        float_field=3.14,
        str_field="hello world",
        datetime_field=datetime.datetime.now(tz=datetime.timezone.utc),
        doc_field=SubDoc(1, "hello", 3.14),
        list_field=["a", "b", "c"],
        dict_field={"aaa": 1, "bbb": "xxx", "ccc": 2.718},
    )

    db = fdb.open()
    doc_space = Subspace(("D",))
    doc_key = ("abc123",)
    store_document(db, doc_space, doc_key, example_doc)
    test_doc = load_document(db, doc_space, doc_key, Example)

    for shown in (example_doc, test_doc):
        pprint.pprint(shown)
@mdellavo
Copy link
Author

mdellavo commented Jan 15, 2024

Previous output, showing the example key structure:

 ❯ python document.py
[(b'\x02D\x00\x02test\x00\x02r\x00', b'\xde\xad\xbe\xef'),
 (b'\x02D\x00\x02test\x00\x02b\x00', b'\x01'),
 (b'\x02D\x00\x02test\x00\x02i\x00', b'*'),
 (b'\x02D\x00\x02test\x00\x02f\x00', b'\x1f\x85\xebQ\xb8\x1e\t@'),
 (b'\x02D\x00\x02test\x00\x02t\x00', b'hello world'),
 (b'\x02D\x00\x02test\x00\x02dt\x00', b'2024-01-15T16:05:13.273889+00:00'),
 (b'\x02D\x00\x02test\x00\x02sd\x00\x14', b'\x01'),
 (b'\x02D\x00\x02test\x00\x02sd\x00\x15\x01', b'hello'),
 (b'\x02D\x00\x02test\x00\x02sd\x00\x15\x02', b'\x1f\x85\xebQ\xb8\x1e\t@'),
 (b'\x02D\x00\x02test\x00\x02l\x00\x14', b'a'),
 (b'\x02D\x00\x02test\x00\x02l\x00\x15\x01', b'b'),
 (b'\x02D\x00\x02test\x00\x02l\x00\x15\x02', b'c'),
 (b'\x02D\x00\x02test\x00\x02d\x00\x02foo\x00', b'bar')]
[(('test', 'r'), b'\xde\xad\xbe\xef'),
 (('test', 'b'), b'\x01'),
 (('test', 'i'), b'*'),
 (('test', 'f'), b'\x1f\x85\xebQ\xb8\x1e\t@'),
 (('test', 't'), b'hello world'),
 (('test', 'dt'), b'2024-01-15T16:05:13.273889+00:00'),
 (('test', 'sd', 0), b'\x01'),
 (('test', 'sd', 1), b'hello'),
 (('test', 'sd', 2), b'\x1f\x85\xebQ\xb8\x1e\t@'),
 (('test', 'l', 0), b'a'),
 (('test', 'l', 1), b'b'),
 (('test', 'l', 2), b'c'),
 (('test', 'd', 'foo'), b'bar')]

@mdellavo
Copy link
Author

Current output (the list and dict fields are broken: they load back as raw bytes):

❯ python document.py
Example(raw_field=b'\xde\xad\xbe\xef',
        bool_field=True,
        int_field=42,
        float_field=3.14,
        str_field='hello world',
        datetime_field=datetime.datetime(2024, 1, 16, 4, 51, 32, 265293, tzinfo=datetime.timezone.utc),
        doc_field=SubDoc(foo=1, bar='hello', baz=3.14),
        list_field=['a', 'b', 'c'],
        dict_field={'aaa': 1, 'bbb': 'xxx', 'ccc': 2.718})
Example(raw_field=b'\xde\xad\xbe\xef',
        bool_field=True,
        int_field=42,
        float_field=3.14,
        str_field='hello world',
        datetime_field=datetime.datetime(2024, 1, 16, 4, 51, 32, 265293, tzinfo=datetime.timezone.utc),
        doc_field=SubDoc(foo=1, bar='hello', baz=3.14),
        list_field=[b'a', b'b', b'c'],
        dict_field={'aaa': b'\x01',
                    'bbb': b'xxx',
                    'ccc': b'X9\xb4\xc8v\xbe\x05@'})

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment