Created
March 10, 2023 18:15
-
-
Save Sekenre/c43515f75643642ff1f4a94eb819a9de to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Inspired by: https://github.com/dabeaz/blog/blob/main/2023/three-problems.md
""" | |
import datetime as dt | |
from enum import IntEnum | |
import struct | |
from dataclasses import dataclass | |
class Major(IntEnum):
    """CBOR major types.

    The numeric value of each member is what the lexer obtains from the top
    three bits of an item's initial byte (``ib >> 5``).
    """

    POSINT = 0   # unsigned integer
    NEGINT = 1   # negative integer
    BYTES = 2    # byte string
    STR = 3      # text string
    ARRAY = 4    # array of items
    MAP = 5      # map of key/value pairs
    TAG = 6      # tagged item
    SPECIAL = 7  # simple values, floats and the break marker
@dataclass
class Token:
    """One lexed CBOR item head, as produced by ``Lexer``."""

    # Major type of the item (a ``Major`` member when produced by the lexer).
    major: int
    # Decoded argument for most major types; for SPECIAL tokens the raw
    # five-bit subtype. May be None for indefinite-length heads.
    subtype: int
    # View over the bytes belonging to the item, or None when it has none.
    payload: memoryview
@dataclass
class Tag:
    """A CBOR tagged item that was not converted to a native Python value."""

    # The tag number.
    tag: int
    # The decoded item the tag is attached to (any type).
    value: ...
@dataclass
class Simple:
    """A CBOR simple value that has no dedicated Python equivalent."""

    # The numeric simple value.
    value: int
# Token equal to what the lexer emits for the "break" code
# (SPECIAL major type, subtype 31, no payload).
BREAK_MARKER = Token(major=Major.SPECIAL, subtype=31, payload=None)
# EOF_MARKER = Token(None, None, None)
class undefined:
    """Sentinel standing in for the CBOR ``undefined`` value.

    The class object itself is used as the value; it is never instantiated
    by this module.
    """
class Lexer:
    """Tokenise a buffer of CBOR-encoded bytes.

    Iterating over a ``Lexer`` yields one ``Token`` per item head. The
    initial byte is split into its major type (top three bits) and subtype
    (low five bits) and passed to the decoder registered for that major
    type. Array, map and tag heads are emitted as single tokens; the items
    they contain follow as further tokens in the stream.
    """

    # struct format and byte width of each multi-byte length encoding.
    _LENGTH_FORMATS = {25: (">H", 2), 26: (">L", 4), 27: (">Q", 8)}

    def __init__(self, buf: bytes):
        """Wrap *buf* in a memoryview and start reading at offset 0."""
        self._view = memoryview(buf)
        self._pos = 0
        # Dispatch table: major type -> bound decoder method.
        self.decoders = {
            Major.POSINT: self.decode_subtype,
            Major.NEGINT: self.decode_subtype,
            Major.BYTES: self.decode_bytes,
            Major.STR: self.decode_bytes,
            Major.ARRAY: self.decode_subtype,
            Major.MAP: self.decode_subtype,
            Major.TAG: self.decode_subtype,
            Major.SPECIAL: self.decode_special,
        }

    def _decode(self):
        """Decode the item head at the current position and return its Token."""
        ib = self._view[self._pos]
        maj = Major(ib >> 5)
        subtype = ib & 31
        self._pos += 1
        return self.decoders[maj](maj, subtype)

    def _decode_all(self):
        """Yield tokens until the read position reaches the end of the buffer.

        An IndexError raised while decoding (a one-byte length that lies
        past the end of the buffer) ends the stream silently.
        NOTE(review): truncated 2/4/8-byte lengths surface as struct.error
        instead, and truncated payloads produce short views; this asymmetry
        is inherited unchanged.
        """
        while self._pos < len(self._view):
            try:
                yield self._decode()
            except IndexError:
                return

    def __iter__(self):
        return self._decode_all()

    def decode_subtype(self, major, subtype):
        """Return a payload-less token whose subtype is the decoded argument."""
        return Token(major, self._decode_length(subtype), None)

    def decode_bytes(self, major, subtype):
        """Return a token carrying a view over a byte/text string body.

        For an indefinite-length string (subtype 31) only a header token is
        produced: its subtype is None and its payload is an empty view. The
        chunks and the terminating break then follow as ordinary tokens,
        the same way indefinite-length arrays and maps are tokenised.
        """
        length = self._decode_length(subtype)
        start = self._pos
        if length is None:
            # Previously ``self._pos + None`` raised TypeError here. The
            # payload is an empty view rather than None so that consumers
            # calling len() on string payloads keep working.
            return Token(major, None, self._view[start:start])
        end = start + length
        self._pos = end
        return Token(major, length, self._view[start:end])

    def decode_special(self, major, subtype):
        """Return a SPECIAL token, keeping the raw subtype.

        Subtypes below 24 and subtype 31 have no payload; subtype 24 is
        followed by one payload byte and subtypes 25-27 by 2, 4 and 8
        payload bytes.

        Raises:
            ValueError: for subtypes 28-30.
        """
        p, v = self._pos, self._view
        if subtype < 24 or subtype == 31:
            return Token(major, subtype, None)
        elif subtype == 24:
            self._pos = p + 1
            return Token(major, subtype, v[p : p + 1])
        elif 25 <= subtype < 28:
            # 25 -> 2 bytes, 26 -> 4 bytes, 27 -> 8 bytes.
            length = 1 << (subtype & 7)
            self._pos = p + length
            return Token(major, subtype, v[p : p + length])
        else:
            raise ValueError("bad special value type %d" % subtype)

    def _decode_length(self, subtype):
        """Decode the unsigned argument selected by *subtype*.

        Values below 24 are the argument itself; 24-27 read 1, 2, 4 or 8
        big-endian bytes following the initial byte and advance the read
        position past them; 31 yields None.

        Raises:
            ValueError: for subtypes 28-30.
        """
        p, v = self._pos, self._view
        if subtype < 24:
            return subtype
        if subtype == 24:
            self._pos = p + 1
            # Plain indexing (not struct) so that a missing byte raises the
            # IndexError that _decode_all treats as end of input.
            return v[p]
        if subtype in self._LENGTH_FORMATS:
            fmt, width = self._LENGTH_FORMATS[subtype]
            self._pos = p + width
            return struct.unpack(fmt, v[p : p + width])[0]
        if subtype == 31:
            return None
        raise ValueError("unknown unsigned integer subtype 0x%x" % subtype)
# Lambda parser | |
def shift(inp):
    """Consume one element of the input.

    *inp* is a ``(sequence, position)`` pair. Returns
    ``(element, (sequence, position + 1))`` or False when the position is
    not below the sequence length.
    """
    text, n = inp
    if not n < len(text):
        return False
    return text[n], (text, n + 1)
def seq(*parsers):
    """Combine parsers so that they must all match, one after another.

    The resulting parser returns ``(values, remaining_input)`` where
    *values* lists each parser's value in order, or False as soon as one
    of them fails.
    """
    def parse(inp):
        values = []
        rest = inp
        for parser in parsers:
            match = parser(rest)
            if not match:
                return False
            item, rest = match
            values.append(item)
        return values, rest
    return parse
# Return the result of the first parser that succeeds.
def choice(*parsers):
    """Combine parsers as ordered alternatives.

    Each parser is tried on the same input in the order given; the first
    truthy result is returned and later parsers are not run. Returns False
    when none of them matches.
    """
    def parse(inp):
        attempts = (parser(inp) for parser in parsers)  # evaluated lazily
        return next(filter(None, attempts), False)
    return parse
def one_or_more(parser):
    """Apply *parser* repeatedly until it fails.

    The resulting parser returns ``(values, remaining_input)`` with every
    matched value, or False when *parser* did not match even once.
    """
    def parse(inp):
        values = []
        rest = inp
        while True:
            match = parser(rest)
            if not match:
                break
            item, rest = match
            values.append(item)
        if not values:
            return False
        return values, rest
    return parse
def counter(countp, parser):
    """Parse a count, then exactly that many items.

    *countp* is run first and must yield the number of items; *parser* is
    then applied that many times. The resulting parser returns
    ``(values, remaining_input)`` or False when the count or any item
    fails to parse.

    A count of zero succeeds with an empty list. The previous
    ``bool(result) and ...`` turned that case into a failure, so empty
    arrays and maps could never be parsed.
    """
    def parse(inp):
        if not (m := countp(inp)):
            return False
        count, inp = m
        if count is None:
            # No explicit count (the lexer yields None for indefinite-length
            # heads): not supported here, so fail the parse rather than let
            # range(None) raise TypeError.
            return False
        result = []
        for _ in range(count):
            if not (m := parser(inp)):
                return False
            value, inp = m
            result.append(value)
        # The tuple is always truthy, even when *result* is empty.
        return (result, inp)
    return parse
def filt(predicate):
    """Return a decorator that keeps a parser's match only if *predicate* accepts its value."""
    def wrap(parser):
        def parse(inp):
            m = parser(inp)
            if not m:
                return m
            verdict = predicate(m[0])
            if not verdict:
                return verdict
            return m
        return parse
    return wrap


def literal(value):
    """Return a decorator that keeps a parser's match only if its value equals *value*."""
    return filt(lambda v: v == value)


def char(value):
    """Parser matching a single input element equal to *value*."""
    return literal(value)(shift)


def fmap(func):
    """Return a decorator that applies *func* to the value of a successful match."""
    def wrap(parser):
        def parse(inp):
            m = parser(inp)
            if not m:
                return m
            return (func(m[0]), m[1])
        return parse
    return wrap


def either(p1, p2):
    """Parser that tries *p1* and falls back to *p2* on the same input."""
    def parse(inp):
        return p1(inp) or p2(inp)
    return parse


def right(p1, p2):
    """Match *p1* then *p2*, keeping only the value of *p2*."""
    return fmap(lambda p: p[1])(seq(p1, p2))


def left(p1, p2):
    """Match *p1* then *p2*, keeping only the value of *p1*."""
    return fmap(lambda p: p[0])(seq(p1, p2))


def repeat(count, parser):
    """Match *parser* exactly *count* times in a row."""
    return seq(*([parser] * count))


def nothing(inp):
    """Always succeed with None, consuming no input."""
    return (None, inp)


def zero(inp):
    """Always succeed with 0, consuming no input."""
    return (0, inp)


def maybe(parser):
    """Make *parser* optional; a miss yields None."""
    return either(parser, nothing)


def maybe_zero(parser):
    """Make *parser* optional; a miss yields 0."""
    return either(parser, zero)


# Allows parser to be defined out of order
def forward(pthunk):  # From Dave Beazley
    """Defer parser lookup: *pthunk* is called on every parse to obtain the parser."""
    def parse(inp):
        return pthunk()(inp)
    return parse
# Date string parsing
# Grammar accepted: YYYY-MM-DD, optionally followed by
# "T" HH:MM:SS [ "." digits ] [ "Z" | ("+" | "-") HH:MM ].
dash = char('-')
colon = char(":")
dot = char(".")
# ASCII 0-9 only: str.isdigit() alone also accepts characters such as
# superscript digits, which int() then rejects with ValueError.
digit = filt(lambda c: str.isdigit(c) and c.isascii())(shift)


def numeric(count):
    """Parser for exactly *count* digits, yielded as a single int."""
    return fmap(int)(fmap(''.join)(repeat(count, digit)))


year = numeric(4)
date_str = seq(left(year, dash), left(numeric(2), dash), numeric(2))
date = fmap(lambda p: dt.date(*p))(date_str)
zulu = fmap(lambda d: dt.timezone.utc)(char('Z'))
offset_str = seq(either(char('+'), dash), numeric(2), right(colon, numeric(2)))


def sign(s, p):
    """Return *p* unchanged for "+" and negated for "-"."""
    return p * {"+": 1, "-": -1}[s]


offset = fmap(lambda p: dt.timezone(sign(p[0], dt.timedelta(hours=p[1], minutes=p[2]))))(offset_str)
zoneinfo = either(zulu, offset)


def to_usec(d):
    """Convert fractional-second digits to microseconds.

    Only the first six digits are used; shorter fractions are padded on the
    right with zeros.
    """
    return int(''.join(d[:6]).ljust(6, "0"))


usec = fmap(to_usec)(one_or_more(digit))
# A missing fraction counts as 0 microseconds.
time_frac = maybe_zero(right(dot, usec))
# Yields [hour, minute, second, microsecond, tzinfo-or-None], which matches
# the positional parameters of dt.time.
time_str = seq(left(numeric(2), colon), left(numeric(2), colon), numeric(2), time_frac, maybe(zoneinfo))
time = fmap(lambda p: dt.time(*p))(time_str)


def _combine(parts):
    """Build a datetime from ``[date, time-or-None]``.

    The time part is optional in the grammar. When it is absent, midnight
    is used; passing None straight to ``datetime.combine`` raised TypeError.
    NOTE(review): midnight (naive) is an assumption about the intended
    result for date-only strings - confirm.
    """
    day, clock = parts
    return dt.datetime.combine(day, dt.time() if clock is None else clock)


datetime = fmap(_combine)(seq(date, maybe(right(char('T'), time))))


def datetime_parser(p):
    """Parse the string *p* and return the resulting ``datetime``.

    Trailing characters after the recognised part are ignored.

    Raises:
        ValueError: if *p* does not start with a recognisable date (the
            previous code failed with TypeError on ``False[0]``), or if the
            parsed fields are out of range for ``datetime``.
    """
    m = datetime(Input(p))
    if not m:
        raise ValueError("invalid datetime string: %r" % (p,))
    return m[0]
# CBOR specific expect_ functions
def expect_subtype(ty):
    """Parser matching one token of major type *ty*; yields its subtype."""
    of_major = filt(lambda tok: tok.major == ty)(shift)
    return fmap(lambda tok: tok.subtype)(of_major)


def expect_payload(ty):
    """Parser matching one token of major type *ty*; yields its payload.

    The token is accepted only when its subtype equals the length of its
    payload.
    """
    complete = filt(
        lambda tok: tok.major == ty and tok.subtype == len(tok.payload)
    )(shift)
    return fmap(lambda tok: tok.payload)(complete)


def expect_special(cmp_st):
    """Parser matching one SPECIAL token whose subtype satisfies *cmp_st*; yields the token."""
    return filt(
        lambda tok: tok.major == Major.SPECIAL and cmp_st(tok.subtype)
    )(shift)
# Major datatypes
posint = expect_subtype(Major.POSINT)
# NEGINT carries n and stands for -1 - n, which is the bitwise complement ~n.
negint = fmap(lambda v: ~v)(expect_subtype(Major.NEGINT))
integer = either(posint, negint)
# Copy the payload view out into an independent bytes object.
bstring = fmap(bytes)(expect_payload(Major.BYTES))
# Text strings are decoded as UTF-8 (strict error handling).
ustring = fmap(lambda p: str(p, "utf-8"))(expect_payload(Major.STR))
# Special items
# Subtypes below 20 become Simple(subtype).
simple = fmap(lambda t: Simple(t.subtype))(expect_special(lambda t: t < 20))


def just(item, tag_id):
    """Parser yielding the constant *item* for a SPECIAL token with subtype *tag_id*."""
    matching = expect_special(lambda t: t == tag_id)
    return fmap(lambda t: item)(matching)


# Subtypes 20-23 map to False, True, None and the ``undefined`` sentinel.
booleans = choice(
    just(False, 20),
    just(True, 21),
    just(None, 22),
    just(undefined, 23),
)
# Subtype 24: the simple value is stored in the one-byte payload.
sv2 = fmap(lambda t: Simple(struct.unpack("B", t.payload)[0]))(
    expect_special(lambda t: t == 24)
)


def funpack(ty, tag_id):
    """Parser unpacking the payload of a SPECIAL token with subtype *tag_id* using struct format *ty*."""
    matching = expect_special(lambda t: t == tag_id)
    return fmap(lambda t: struct.unpack(ty, t.payload)[0])(matching)


# Big-endian half, single and double precision floats.
f16 = funpack(">e", 25)
f32 = funpack(">f", 26)
f64 = funpack(">d", 27)
floats = choice(f16, f32, f64)
specials = choice(simple, booleans, sv2, floats)
# recursive containers
def atom(inp):
    """Parse any single CBOR item.

    The alternatives are looked up at call time, which lets this parser
    refer to ``array``, ``mapping`` and ``tagged`` before they are defined.
    """
    return choice(integer, bstring, ustring, specials, array, mapping, tagged)(inp)


# An ARRAY head gives the item count, followed by that many items.
array = counter(expect_subtype(Major.ARRAY), atom)
# A MAP head gives the pair count, followed by that many key/value pairs.
mapping = fmap(dict)(
    counter(expect_subtype(Major.MAP), seq(atom, atom))
)
# Tags
# A TAG head (its subtype is the tag number) followed by the tagged item.
tag = seq(expect_subtype(Major.TAG), atom)
tagobj = fmap(lambda t: Tag(*t))(tag)


def expect_tag(tag_id):
    """Parser matching a tagged item with tag number *tag_id*; yields the tagged value."""
    return fmap(lambda t: t.value)(filt(lambda t: t.tag == tag_id)(tagobj))


def _epoch_to_datetime(t):
    """Convert seconds since the epoch to a timezone-aware UTC datetime."""
    return dt.datetime.fromtimestamp(t, tz=dt.timezone.utc)


# Tag 0: date/time string; tag 1: epoch-based timestamp.
datetime_string = fmap(datetime_parser)(expect_tag(0))
timestamp = fmap(_epoch_to_datetime)(expect_tag(1))


def _interpret_tag(t):
    """Convert a parsed Tag to a datetime for tags 0 and 1; return other tags unchanged."""
    if t.tag == 0:
        return datetime_parser(t.value)
    if t.tag == 1:
        return _epoch_to_datetime(t.value)
    return t


# Parse the tag and its content once, then dispatch on the tag number.
# The earlier ``choice(datetime_string, timestamp, tagobj)`` re-parsed the
# whole tagged item for every alternative, tripling the work at each level
# of tag nesting; the values produced are the same.
tagged = fmap(_interpret_tag)(tagobj)
# A CBOR sequence: one or more items back to back.
cbor_seq = one_or_more(atom)


def Input(vs):
    """Wrap the sequence *vs* as parser input positioned at index 0."""
    return vs, 0
if __name__ == "__main__":
    # Demo: encode sample values with cbor2 (third-party, needed only here),
    # concatenate the encodings, then lex and parse them back.
    import cbor2

    things = [
        1,
        -100,
        b"h",
        "hi",
        [1, -100, [50, 20]],
        {1: -100, 2: {"a": "B"}},
        cbor2.CBORTag(1, 1),
        dt.datetime.now(tz=dt.timezone.utc),
        0.6,
        True,
        False,
        None,
        cbor2.CBORSimpleValue(1),
        cbor2.CBORSimpleValue(254),
        cbor2.undefined,
        cbor2.CBORTag(0xbeefcafe, [1, cbor2.CBORTag(1, 2)]),
    ]
    items = [cbor2.dumps(thing) for thing in things]
    payload = b"".join(items)
    tokens = list(Lexer(payload))
    result, remainder = cbor_seq(Input(tokens))
    print(result)
    print(remainder)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment