Skip to content

Instantly share code, notes, and snippets.

@msullivan
Last active September 27, 2022 00:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save msullivan/b45e3c785c55d5d8ea4db196e1d9cb55 to your computer and use it in GitHub Desktop.
Save msullivan/b45e3c785c55d5d8ea4db196e1d9cb55 to your computer and use it in GitHub Desktop.
algorithm for eliminating argument tuples
from __future__ import annotations
from dataclasses import dataclass
import json
import time
import textwrap
from typing import Any, Callable
@dataclass(eq=False)
class Scalar:
    """A leaf type identified only by its type name (e.g. ``'str'``).

    ``eq=False`` keeps identity-based equality and hashing, so each node
    object can act as its own key in a translation ``Map``.
    """

    name: str = 'T'

    def fmt(self) -> str:
        """Render the type as EdgeQL text, which is simply its name."""
        rendered = self.name
        return rendered
@dataclass(eq=False)
class Tuple:
    """A tuple type whose members are given positionally in ``typs``.

    Uses identity equality/hashing (``eq=False``) like the other nodes.
    """

    typs: tuple[Type, ...]

    def fmt(self) -> str:
        """Render as ``tuple<A, B, ...>`` using each member's ``fmt()``."""
        members = ", ".join(member.fmt() for member in self.typs)
        return f'tuple<{members}>'
@dataclass(eq=False)
class Array:
    """An array type whose elements have type ``typ`` (any ``Type``).

    Uses identity equality/hashing (``eq=False``) like the other nodes.
    """

    typ: Type

    def fmt(self) -> str:
        """Render as ``array<ELEM>``."""
        element = self.typ.fmt()
        return f'array<{element}>'
@dataclass(eq=False)
class FArray:
    """A flat parameter array: an array whose elements are scalars.

    Produced by ``translate_type`` for scalars nested inside arrays and
    for the int64 offset arrays of nested arrays.
    """

    typ: Scalar

    def fmt(self) -> str:
        """Render as ``array<SCALAR>``."""
        element = self.typ.fmt()
        return f'array<{element}>'
# A structured (possibly nested) parameter type: scalars, tuples and arrays.
Type = Scalar | Tuple | Array
# A flat parameter type as produced by translate_type: a scalar, or an
# array of scalars.
FType = Scalar | FArray
def _array_adjust(typ: Type) -> int:
    """Return how much longer the leftmost flat parameter of an array
    element type is than the element count, plus one.

    The leftmost leaf is found by repeatedly taking the first member of
    tuples. If that leaf is a nested array, its flat parameter is an
    offsets array (which ``encode`` starts with an extra leading 0), so 2
    is returned; otherwise the parameter holds one value per element and 1
    is returned. Callers compute the element count as
    ``len(param) - adjust + 1``.
    """
    leftmost = typ
    while isinstance(leftmost, Tuple):
        leftmost = leftmost.typs[0]
    if isinstance(leftmost, Array):
        return 2
    return 1
# XXX: INVARIANT: NODES IN THE TYPE MUST BE DISTINCT!
# Type nodes hash and compare by identity (dataclass eq=False), so reusing
# one node object at two places inside a single type would make both
# places share one Map entry.
# Maps each node of a type to (index of the first flat parameter belonging
# to that node, whether the node is nested inside an array).
Map = dict[Type, tuple[int, bool]]
def translate_type(typ: Type) -> tuple[tuple[FType, ...], Map]:
    """Flatten a nested type into a sequence of flat parameter types.

    Scalars outside any array stay scalars. Scalars inside an array become
    arrays of that scalar. Every array nested inside another array gets an
    extra ``array<int64>`` offsets parameter. Tuples contribute only their
    members' parameters.

    Returns:
        A pair ``(params, mapping)``: the flat parameter types in order,
        and a Map from every node of *typ* to ``(first_param_index,
        in_array)``.
    """
    flat: list[FType] = []
    positions: Map = {}

    def visit(node: Type, in_array: bool) -> None:
        # Index of the first parameter this node contributes (or would
        # contribute, for a top-level array, whose element comes first).
        first = len(flat)
        if isinstance(node, Scalar):
            flat.append(FArray(node) if in_array else node)
        elif isinstance(node, Array):
            if in_array:
                # Offsets delimiting each instance of this nested array.
                flat.append(FArray(Scalar('int64')))
            visit(node.typ, in_array=True)
        elif isinstance(node, Tuple):
            for member in node.typs:
                visit(member, in_array=in_array)
        positions[node] = (first, in_array)

    visit(typ, in_array=False)
    return tuple(flat), positions
def encode(typ: Type, ntyps: tuple[FType, ...], map: Map, data: Any) -> tuple[Any, ...]:
    """Flatten *data* (a value of type *typ*) into flat query arguments.

    *ntyps* and *map* come from ``translate_type(typ)``. Scalar parameters
    receive a single value; array parameters receive a list. For each
    array nested in an array, its offsets parameter receives a leading 0
    on first use, followed by the running total of element counts.

    Raises AssertionError if *data* does not match the shape of *typ*.
    """
    args: list[Any] = [[] if not isinstance(t, Scalar) else 0 for t in ntyps]

    def put(node: Type, value: Any) -> None:
        slot, in_array = map[node]
        if isinstance(node, Scalar):
            if not in_array:
                args[slot] = value
                return
            args[slot].append(value)
        elif isinstance(node, Array):
            assert isinstance(value, list)
            if in_array:
                offsets = args[slot]
                if not offsets:
                    offsets.append(0)
                offsets.append(offsets[-1] + len(value))
            for item in value:
                put(node.typ, item)
        elif isinstance(node, Tuple):
            assert isinstance(value, tuple)
            assert len(node.typs) == len(value)
            for member, item in zip(node.typs, value):
                put(member, item)

    put(typ, data)
    return tuple(args)
def naive_decode(typ: Type, map: Map, data: tuple[Any, ...]) -> Any:
    """Rebuild a value of type *typ* from flat arguments, bottom-up.

    Each node first decodes its children into whole columns. Arrays nested
    in arrays are then cut into slices using their offsets parameter, and
    tuples inside arrays are zipped back into rows.
    """
    def build(node: Type) -> Any:
        slot, in_array = map[node]
        if isinstance(node, Scalar):
            return data[slot]
        if isinstance(node, Array):
            flat = build(node.typ)
            if not in_array:
                return flat
            offsets = data[slot]
            # Consecutive offset pairs delimit each nested array instance.
            return [flat[lo:hi] for lo, hi in zip(offsets, offsets[1:])]
        if isinstance(node, Tuple):
            columns = [build(member) for member in node.typs]
            if not in_array:
                return tuple(columns)
            return [tuple(row) for row in zip(*columns)]
        return None

    return build(typ)
def decode(typ: Type, map: Map, data: tuple[Any, ...]) -> Any:
    """Rebuild a value of type *typ* from flat arguments, top-down.

    *map* is the node-to-parameter mapping from ``translate_type`` and
    *data* is the tuple produced by ``encode``. Unlike ``naive_decode``,
    this walks the type with an explicit element index, using the same
    bounds as the query generated by ``make_decoder``.

    Raises:
        TypeError: if *typ* contains a node that is not a Scalar, Array
            or Tuple.
    """
    # Author's notes on performance:
    # - This is probably not faster than naive_decode in Python, since it
    #   makes many more function calls. However, it is linear time and
    #   matches the algorithm used by the EdgeQL decoder.
    # - naive_decode is probably also linear time, but it creates many
    #   more intermediate objects.
    # - Whether naive_decode is linear depends on the object storage
    #   model. It is linear with pointer-based objects, but not when
    #   nested objects are stored inline in (copied into) their parents.
    # TODO: find out which storage model Postgres uses.

    def dec(typ: Type, idx: int | None) -> Any:
        # idx is the element position within the flat parameters of the
        # enclosing array, or None when typ is not nested in an array.
        arg, _ = map[typ]
        if isinstance(typ, Scalar):
            return data[arg] if idx is None else data[arg][idx]
        if isinstance(typ, Array):
            if idx is None:
                # Array not nested in an array. Its length comes from its
                # leftmost flat parameter, discounting the leading 0 when
                # that parameter is an offsets array. An empty offsets
                # list gives a negative bound, i.e. an empty range.
                lo = 0
                hi = len(data[arg]) - _array_adjust(typ.typ) + 1
            else:
                # Nested array: bounds come from its offsets parameter.
                lo = data[arg][idx]
                hi = data[arg][idx + 1]
            return [dec(typ.typ, idx=i) for i in range(lo, hi)]
        if isinstance(typ, Tuple):
            return tuple(dec(t, idx=idx) for t in typ.typs)
        raise TypeError(f'unexpected type node: {typ!r}')

    return dec(typ, idx=None)
def make_decoder(typ: Type, map: Map, ftypes: tuple[FType, ...]) -> str:
    """Generate an EdgeQL expression that rebuilds *typ* from ``$N`` params.

    *map* and *ftypes* come from ``translate_type(typ)``; parameter ``$N``
    corresponds to ``ftypes[N]``. Scalars become casts (indexed when
    inside an array). Arrays of scalars become slices of the scalar array
    parameter. Other arrays become ``array_agg`` over a ``for`` loop on
    ``_gen_series`` (whose upper bound is inclusive). Tuples become EdgeQL
    tuple literals.

    Raises:
        TypeError: if *typ* contains a node that is not a Scalar, Array
            or Tuple.
    """
    cnt = 0

    def mk_name(x: str) -> str:
        # Fresh loop-variable names (i1, i2, ...), unique within the query.
        nonlocal cnt
        cnt += 1
        return f'{x}{cnt}'

    # Extra cast inserted between each type cast and its parameter. Set to
    # '<json>' to cast through json (presumably for JSON-encoded args).
    # BS = '<json>'
    BS = ''

    def mk(typ: Type, idx: str | None) -> str:
        # idx is the EdgeQL expression for the current element position,
        # or None when typ is not nested in an array.
        arg, in_array = map[typ]
        if isinstance(typ, Scalar):
            tname = f'array<{typ.fmt()}>' if in_array else typ.fmt()
            if idx is None:
                return f'<{tname}>{BS}${arg}'
            return f'(<{tname}>{BS}${arg})[{idx}]'
        if isinstance(typ, Array):
            a = f'(<array<int64>>${arg})'
            # If the contents are just a scalar, take values directly from
            # the scalar array parameter instead of iterating over it.
            # This is an optimization, not needed for correctness.
            if isinstance(typ.typ, Scalar):
                sub = mk(typ.typ, idx=None)
                # Inside an array: slice out this instance's elements.
                if idx is not None:
                    sub = f'({sub})[{a}[{idx}]:{a}[{idx}+1]]'
                return sub
            inner_idx = mk_name('i')
            sub = mk(typ.typ, idx=inner_idx)
            if idx is None:
                # Element count taken from the leftmost flat parameter;
                # _gen_series' upper bound is inclusive, hence "- adjust".
                adjust = _array_adjust(typ.typ)
                lo = '0'
                hi = f'len(<{ftypes[arg].fmt()}>${arg}) - {adjust}'
            else:
                lo = f'{a}[{idx}]'
                hi = f'{a}[{idx}+1]-1'
            body = textwrap.indent(sub, ' ')
            return (
                f'array_agg((for {inner_idx} in _gen_series({lo}, {hi}) union (\n'
                f'{body}\n'
                ')))'
            )
        if isinstance(typ, Tuple):
            # A trailing comma after every member keeps 1-tuples valid.
            lparts = [mk(t, idx=idx) + ',' for t in typ.typs]
            return f'({" ".join(lparts)})'
        raise TypeError(f'unexpected type node: {typ!r}')

    return mk(typ, idx=None)
######### TESTING
def test(t1, data):
    """Round-trip *data* through encode/decode for type *t1*.

    Prints the flat parameter types, the mapping, the encoded arguments,
    the decoded value, the generated EdgeQL query, and each argument as
    JSON. Asserts that decoding reproduces *data*.
    """
    flat_types, mapping = translate_type(t1)
    print(flat_types)
    print(mapping)
    args = encode(t1, flat_types, mapping, data)
    print(args)
    print()
    result = decode(t1, mapping, args)
    print(result)
    assert result == data
    print()
    query = make_decoder(t1, mapping, flat_types)
    print(f'select {query};')
    print()
    for arg in args:
        print(json.dumps(arg))
    print()
# Hand-written examples. t1 is reused inside t2 and t3; that is allowed
# because each type is translated on its own and t1 appears only once in
# each of them (see the node-distinctness invariant on Map).
t1 = Array(Tuple((Array(Scalar('str')),)))
t2 = Array(Tuple((t1, Scalar('str'))))
test_data: list[tuple[list[str]]] = [(list(word),) for word in ('a', 'bc', 'def')]
test_data_2p: list[tuple[list[str]]] = [
    (list(word),) for word in ('xyzw', 'ghi', 'jk', 'l')
]
test_data2 = [(test_data, 'foo'), (test_data_2p, 'bar')]
# simpler
t3 = Array(Tuple((t1,)))
test_data3 = [(test_data,), (test_data_2p,)]
def go():
    """Run the hand-written round-trip examples through ``test``."""
    examples = (
        (t1, test_data),
        (t2, test_data2),
        (t3, test_data3),
    )
    for example_type, example_data in examples:
        test(example_type, example_data)


# go()
#################
import hypothesis as h
import hypothesis.strategies as hs
def _fresh_str_scalar(_: None) -> Scalar:
    # Builds a new Scalar object on every draw, which keeps the nodes of a
    # generated type distinct (the Map invariant). Presumably this is why
    # builds() is used instead of a constant strategy.
    return Scalar('str')


def _extend_types(children):
    # Arrays may not directly contain arrays; tuples have 1 to 8 members.
    non_array_children = children.filter(lambda x: not isinstance(x, Array))
    arrays = hs.builds(Array, non_array_children)
    tuples = hs.lists(children, min_size=1, max_size=8).map(
        lambda members: Tuple(tuple(members))
    )
    return arrays | tuples


# Hypothesis strategy generating random nested Types built from str scalars.
typ = hs.recursive(hs.builds(_fresh_str_scalar, hs.none()), _extend_types)
def type_to_strategy(t):
    """Return a Hypothesis strategy generating values of type *t*.

    Scalars become single lowercase letters, arrays become lists, and
    tuples become tuples of the members' values.

    Raises:
        TypeError: if *t* is not a Scalar, Array or Tuple.
    """
    if isinstance(t, Scalar):
        # Alternatives tried: hs.text(alphabet=...), hs.integers().
        return hs.sampled_from('abcdefghijklmnopqrstuvwxyz')
    if isinstance(t, Array):
        return hs.lists(type_to_strategy(t.typ))
    if isinstance(t, Tuple):
        return hs.tuples(*[type_to_strategy(member) for member in t.typs])
    raise TypeError(f'unsupported type node: {t!r}')
@hs.composite
def type_and_data(draw):
    """Draw a random type from ``typ``, then a value of that type."""
    drawn_type = draw(typ)
    value = draw(type_to_strategy(drawn_type))
    return drawn_type, value
@h.given(type_and_data())
def test_encode_decode(td):
    """Property test: encoding then decoding (with either decoder) must
    return the original value."""
    t, d = td
    print()
    print(t.fmt())
    print(d)
    nts, m = translate_type(t)
    encoded = encode(t, nts, m, d)
    decoded = decode(t, m, encoded)
    assert d == decoded, (d, encoded, decoded)
    naive_decoded = naive_decode(t, m, encoded)
    # Report the naive decoder's own output when it disagrees.
    assert d == naive_decoded, (d, encoded, naive_decoded)
    print('PASS')
# Lazily created EdgeDB client shared by all callers of get_conn().
_conn = None


def get_conn():
    """Return the shared EdgeDB client, creating it on first use.

    Connects to port 5656 with ``tls_security='insecure'``.
    """
    import edgedb

    global _conn
    if _conn:
        return _conn
    _conn = edgedb.create_client(port=5656, tls_security='insecure')
    return _conn
def _test_edgeql(t, d):
    """Round-trip *d* through the generated EdgeQL decoder on a live server.

    The query runs twice. The first result is discarded, and its duration
    is reported separately (presumably to exclude one-time compilation
    cost). The second result must equal *d*. Prints both durations in
    seconds.
    """
    print()
    print(t.fmt())
    print(d)
    nts, m = translate_type(t)
    encoded = encode(t, nts, m, d)
    query = make_decoder(t, m, nts)
    db = get_conn()
    print("Q", query)
    # perf_counter is monotonic and high-resolution, so it suits
    # measuring intervals (time.time can jump with clock adjustments).
    start = time.perf_counter()
    db.query_single(f'select {query}', *encoded)
    mid = time.perf_counter()
    decoded = db.query_single(f'select {query}', *encoded)
    end = time.perf_counter()
    assert d == decoded, (d, encoded, decoded)
    # assert end - mid < 1.0 # LOL
    print(f'PASS {mid-start:.3f} {end-mid:.3f}')
@h.settings(deadline=None)
@h.given(type_and_data())
def test_edgeql(td):
    """Property test of the generated EdgeQL decoder against a live server.

    The per-example deadline is disabled, presumably because database
    round trips have unpredictable latency.
    """
    t, d = td
    _test_edgeql(t, d)
# _test_edgeql(t1, test_data)

# Run the property tests only when executed as a script, so that importing
# this module (e.g. from pytest) has no side effects. test_edgeql needs a
# live EdgeDB server (see get_conn).
if __name__ == '__main__':
    test_encode_decode()
    test_edgeql()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment