Skip to content

Instantly share code, notes, and snippets.

@Sekenre
Created March 10, 2023 18:15
Show Gist options
  • Save Sekenre/c43515f75643642ff1f4a94eb819a9de to your computer and use it in GitHub Desktop.
Save Sekenre/c43515f75643642ff1f4a94eb819a9de to your computer and use it in GitHub Desktop.
"""
Inspired by: https://github.com/dabeaz/blog/blob/main/2023/three-problems.md
"""
import datetime as dt
from enum import IntEnum
import struct
from dataclasses import dataclass
class Major(IntEnum):
    """CBOR major type, taken from the top three bits of an item's initial byte."""

    POSINT = 0   # unsigned integer
    NEGINT = 1   # negative integer (decoded below as -1 - n)
    BYTES = 2    # byte string
    STR = 3      # text string (decoded below as UTF-8)
    ARRAY = 4    # array header: item count
    MAP = 5      # map header: key/value pair count
    TAG = 6      # tag header: tag number, followed by the tagged item
    SPECIAL = 7  # simple values, floats and the break code
@dataclass
class Token:
    """One lexical unit produced by ``Lexer``.

    ``major`` is the CBOR major type.  ``subtype`` is the decoded argument
    (a length, count or integer value, or ``None`` for indefinite-length
    items); for SPECIAL tokens it is the raw 5-bit additional information.
    ``payload`` is a slice of the input for strings and special values that
    carry data, and ``None`` for header-only tokens.
    """

    major: int
    # String annotations: `X | None` is not evaluated at runtime, so this stays
    # compatible with interpreters older than 3.10.
    subtype: "int | None"
    payload: "memoryview | None"
@dataclass
class Tag:
    """A CBOR tagged item (major type 6) with no dedicated Python mapping."""

    tag: int       # tag number
    value: object  # the decoded item enclosed by the tag (any type)
@dataclass
class Simple:
    """A CBOR simple value that has no dedicated Python mapping.

    Produced for simple values 0..19 and for one-byte extended simple values.
    """

    value: int  # the simple value number
# The token Lexer emits for the CBOR "break" code (major type 7, additional
# information 31).  Token is a dataclass, so lexed break tokens compare equal
# to this constant with ==.
BREAK_MARKER = Token(Major.SPECIAL, subtype=31, payload=None)
class undefined:
    """Marker for the CBOR ``undefined`` simple value, distinct from ``None``.

    The ``booleans`` parser yields this class object itself, not an instance.
    """
class Lexer:
    """Split a CBOR-encoded buffer into a flat stream of ``Token`` objects.

    Iterating a ``Lexer`` yields one token per data-item header.  Array, map
    and tag headers carry only their count/number; the enclosed items follow
    as separate tokens and are reassembled by the parser combinators.  String
    and special-value payloads are zero-copy ``memoryview`` slices of *buf*.

    Decoding is best-effort: an item truncated by the end of *buf* ends the
    stream without a token being yielded for it.  Reserved additional
    information values (28-30) raise ValueError.
    """

    def __init__(self, buf: bytes):
        self._view = memoryview(buf)
        self._pos = 0
        # Decoder for each major type (top 3 bits of an item's initial byte).
        self.decoders = {
            Major.POSINT: self.decode_subtype,
            Major.NEGINT: self.decode_subtype,
            Major.BYTES: self.decode_bytes,
            Major.STR: self.decode_bytes,
            Major.ARRAY: self.decode_subtype,
            Major.MAP: self.decode_subtype,
            Major.TAG: self.decode_subtype,
            Major.SPECIAL: self.decode_special,
        }

    def _take(self, count):
        """Consume the next *count* bytes and return them as a memoryview.

        Raises IndexError if fewer than *count* bytes remain.  Every read goes
        through here, so every kind of truncation stops ``_decode_all`` the
        same way (instead of some reads raising struct.error and others
        silently returning a short slice).
        """
        start = self._pos
        end = start + count
        if end > len(self._view):
            raise IndexError("truncated CBOR item at offset %d" % start)
        self._pos = end
        return self._view[start:end]

    def _decode(self):
        """Decode the item header (plus payload, if any) at the current position."""
        ib = self._view[self._pos]
        maj = Major(ib >> 5)
        subtype = ib & 31  # additional information: the low 5 bits
        self._pos += 1
        return self.decoders[maj](maj, subtype)

    def _decode_all(self):
        """Yield tokens until the buffer is exhausted or an item is truncated."""
        while self._pos < len(self._view):
            # Keep the try around decoding only, not around the yield.
            try:
                token = self._decode()
            except IndexError:
                return
            yield token

    def __iter__(self):
        return self._decode_all()

    def decode_subtype(self, major, subtype):
        """Header-only item: the decoded argument becomes the token's subtype."""
        return Token(major, self._decode_length(subtype), None)

    def decode_bytes(self, major, subtype):
        """Byte/text string: the token carries the length and the raw payload."""
        length = self._decode_length(subtype)
        if length is None:
            # Indefinite-length string (additional information 31): the header
            # has no payload of its own.  Any following chunks and the break
            # code are lexed as ordinary tokens.
            return Token(major, None, None)
        return Token(major, length, self._take(length))

    def decode_special(self, major, subtype):
        """Major type 7: simple values, floats and the break code."""
        if subtype < 24 or subtype == 31:
            return Token(major, subtype, None)
        if 24 <= subtype <= 27:
            # 24: one-byte simple value; 25/26/27: 2/4/8-byte float payload.
            return Token(major, subtype, self._take(1 << (subtype - 24)))
        raise ValueError("bad special value type %d" % subtype)

    def _decode_length(self, subtype):
        """Decode the argument selected by *subtype* (the additional information).

        Returns the integer argument, or ``None`` for 31 (indefinite length).
        Raises ValueError for the reserved values 28-30.
        """
        if subtype < 24:
            return subtype
        if 24 <= subtype <= 27:
            # 24..27 -> 1, 2, 4 or 8 following bytes, big-endian unsigned.
            return int.from_bytes(self._take(1 << (subtype - 24)), "big")
        if subtype == 31:
            return None
        raise ValueError("unknown unsigned integer subtype 0x%x" % subtype)
# Lambda parser combinators.  A parser takes an input ``(sequence, position)``
# pair and returns ``(value, next_input)`` on success or a falsy value on failure.
def shift(inp):
    """Consume one item from *inp*.

    Returns ``(item, (sequence, position + 1))``, or ``False`` when the
    position has reached the end of the sequence.
    """
    items, pos = inp
    if pos >= len(items):
        return False
    return items[pos], (items, pos + 1)
def seq(*parsers):
    """Run *parsers* one after another, threading the input through them.

    The combined parser returns ``(values, rest)`` with one value per parser,
    or ``False`` as soon as any of them fails.
    """
    def parse(inp):
        values = []
        for parser in parsers:
            outcome = parser(inp)
            if not outcome:
                return False
            value, inp = outcome
            values.append(value)
        return values, inp
    return parse
def choice(*parsers):
    """Ordered choice: return the result of the first parser that succeeds.

    Parsers are tried lazily from left to right, and none are called after the
    first success.  Returns ``False`` if all fail (or none are given).
    """
    def parse(inp):
        attempts = (p(inp) for p in parsers)
        return next((m for m in attempts if m), False)
    return parse
def one_or_more(parser):
    """Apply *parser* repeatedly, collecting values until it fails.

    Returns ``(values, rest)`` when at least one application succeeded, else
    ``False``.  If a success leaves the input unchanged, repetition stops
    there; otherwise a non-consuming parser (e.g. ``maybe(...)``) would match
    forever.
    """
    def parse(inp):
        result = []
        while m := parser(inp):
            value, rest = m
            result.append(value)
            if rest == inp:
                # No progress was made: stop instead of looping forever.
                break
            inp = rest
        return bool(result) and (result, inp)
    return parse
def counter(countp, parser):
    """Parse a count with *countp*, then exactly that many items with *parser*.

    Returns ``(values, rest)``.  ``values`` is empty when the count is 0, so
    empty CBOR arrays and maps parse successfully.  Returns ``False`` in
    three cases:

    - *countp* fails;
    - the count is ``None`` (an indefinite length, which is not supported);
    - any of the counted items fails to parse.
    """
    def parse(inp):
        if not (m := countp(inp)):
            return False
        count, inp = m
        if count is None:
            return False
        values = []
        for _ in range(count):
            if not (m := parser(inp)):
                return False
            value, inp = m
            values.append(value)
        # Do not gate on `values` being non-empty: a count of 0 is valid.
        return values, inp
    return parse
def filt(predicate):
    """Restrict a parser to results whose value satisfies *predicate*.

    Used as ``filt(predicate)(parser)``.  There are three outcomes:

    - *parser* fails: its own falsy result is returned;
    - *predicate* rejects the value: the falsy predicate result is returned;
    - otherwise: the parser's result is returned unchanged.
    """
    def restrict(parser):
        def parse(inp):
            m = parser(inp)
            if not m:
                return m
            verdict = predicate(m[0])
            return m if verdict else verdict
        return parse
    return restrict


def literal(value):
    """Restrict a parser to values equal to *value*: ``literal(value)(parser)``."""
    return filt(lambda v: v == value)


def char(value):
    """Match one input item equal to *value*."""
    return literal(value)(shift)


def fmap(func):
    """Transform a successful parse value with *func*: ``fmap(func)(parser)``."""
    def transform(parser):
        def parse(inp):
            m = parser(inp)
            if not m:
                return m
            return func(m[0]), m[1]
        return parse
    return transform


def either(p1, p2):
    """Try *p1*; if it fails, try *p2* on the same input."""
    def parse(inp):
        return p1(inp) or p2(inp)
    return parse


def right(p1, p2):
    """Run *p1* then *p2*, keeping only *p2*'s value."""
    return fmap(lambda pair: pair[1])(seq(p1, p2))


def left(p1, p2):
    """Run *p1* then *p2*, keeping only *p1*'s value."""
    return fmap(lambda pair: pair[0])(seq(p1, p2))


def repeat(count, parser):
    """Run *parser* exactly *count* times in sequence."""
    return seq(*[parser for _ in range(count)])


def nothing(inp):
    """Always succeed without consuming input, yielding ``None``."""
    return None, inp


def zero(inp):
    """Always succeed without consuming input, yielding ``0``."""
    return 0, inp


def maybe(parser):
    """Optional *parser*: its result, or ``None`` without consuming input."""
    return either(parser, nothing)


def maybe_zero(parser):
    """Optional *parser*: its result, or ``0`` without consuming input."""
    return either(parser, zero)


def forward(pthunk):
    """Defer parser lookup to parse time so parsers can be defined out of order.

    *pthunk* is called on every parse to obtain the real parser.
    (Technique from Dave Beazley.)
    """
    def parse(inp):
        return pthunk()(inp)
    return parse
# Date string parsing: RFC 3339-style "YYYY-MM-DD[Thh:mm:ss[.frac][Z|+hh:mm]]".
dash = char('-')
colon = char(":")
dot = char(".")
digit = filt(str.isdigit)(shift)
# Exactly `count` digits, converted to an int.
numeric = lambda count: fmap(int)(fmap(''.join)(repeat(count, digit)))
year = numeric(4)
date_str = seq(left(year, dash), left(numeric(2), dash), numeric(2))
date = fmap(lambda p: dt.date(*p))(date_str)
# RFC 3339 permits the "Z" designator and the "T" separator in lower case too.
zulu = fmap(lambda d: dt.timezone.utc)(either(char('Z'), char('z')))
offset_str = seq(either(char('+'), dash), numeric(2), right(colon, numeric(2)))
sign = lambda s, p: p * {"+": 1, "-": -1}[s]
offset = fmap(lambda p: dt.timezone(sign(p[0], dt.timedelta(hours=p[1], minutes=p[2]))))(offset_str)
zoneinfo = either(zulu, offset)
# Fractional seconds: keep at most 6 digits (microseconds), right-padded to 6.
to_usec = lambda d: int(''.join(d[:6]).ljust(6, "0"))
usec = fmap(to_usec)(one_or_more(digit))
time_frac = maybe_zero(right(dot, usec))
time_str = seq(left(numeric(2), colon), left(numeric(2), colon), numeric(2), time_frac, maybe(zoneinfo))
time = fmap(lambda p: dt.time(*p))(time_str)
# The time part is optional; when absent, use midnight (datetime.combine()
# raises TypeError when given None).
datetime = fmap(lambda p: dt.datetime.combine(p[0], dt.time() if p[1] is None else p[1]))(
    seq(date, maybe(right(either(char('T'), char('t')), time)))
)


def datetime_parser(text):
    """Parse a date/time string (the value of CBOR tag 0) into a ``datetime``.

    Raises ValueError if *text* does not match the grammar above or has
    characters left over after a valid date/time.
    """
    parsed = datetime(Input(text))
    if not parsed:
        raise ValueError("invalid date/time string: %r" % (text,))
    value, (_, end) = parsed
    if end != len(text):
        raise ValueError("trailing characters in date/time string: %r" % (text,))
    return value
# CBOR specific expect_ functions: each one matches exactly one Token via `shift`.

# Token of major type `ty`; yields its subtype (the decoded length/count/value).
expect_subtype = lambda ty: fmap(lambda tok: tok.subtype)(
    filt(lambda tok: tok.major == ty)(shift)
)
# String token of major type `ty` whose payload is present and exactly as long
# as the declared length; yields the payload.  A token with payload None (e.g.
# an indefinite-length string header, if the lexer emits one) is rejected
# instead of raising TypeError from len(None).
expect_payload = lambda ty: fmap(lambda tok: tok.payload)(
    filt(
        lambda tok: tok.major == ty
        and tok.payload is not None
        and tok.subtype == len(tok.payload)
    )(shift)
)
# SPECIAL (major type 7) token whose subtype satisfies `cmp_st`; yields the token.
expect_special = lambda cmp_st: filt(
    lambda tok: tok.major == Major.SPECIAL and cmp_st(tok.subtype)
)(shift)
# Major datatypes
# Integer tokens must carry an argument.  Additional information 31 (which the
# lexer decodes as subtype None) is not well-formed for major types 0 and 1.
# Without this check it would decode as None, or raise TypeError on -None.
_int_argument = lambda ty: filt(lambda v: v is not None)(expect_subtype(ty))
posint = _int_argument(Major.POSINT)
# Negative integers are encoded as n, meaning the value -1 - n.
negint = fmap(lambda v: -v - 1)(_int_argument(Major.NEGINT))
integer = either(posint, negint)
bstring = fmap(lambda p: p.tobytes())(expect_payload(Major.BYTES))
ustring = fmap(lambda p: p.tobytes().decode("utf-8"))(expect_payload(Major.STR))
# Special items (major type 7)

# Simple values 0..19, carried directly in the additional information.
simple = fmap(lambda tok: Simple(tok.subtype))(expect_special(lambda st: st < 20))


def just(item, tag_id):
    """Match the special token whose subtype equals *tag_id*; yield *item*."""
    return fmap(lambda _tok: item)(expect_special(lambda st: st == tag_id))


booleans = choice(
    just(False, 20),
    just(True, 21),
    just(None, 22),
    just(undefined, 23),
)

# Subtype 24: a simple value stored in the one following byte.
sv2 = fmap(lambda tok: Simple(struct.unpack("B", tok.payload)[0]))(
    expect_special(lambda st: st == 24)
)


def funpack(ty, tag_id):
    """Match the special token with subtype *tag_id*; decode its payload as *ty*."""
    return fmap(lambda tok: struct.unpack(ty, tok.payload)[0])(
        expect_special(lambda st: st == tag_id)
    )


# Subtypes 25/26/27: big-endian half, single and double precision floats.
f16 = funpack(">e", 25)
f32 = funpack(">f", 26)
f64 = funpack(">d", 27)
floats = choice(f16, f32, f64)
specials = choice(simple, booleans, sv2, floats)
# Recursive containers.  `atom` is wrapped in `forward` because `array`,
# `mapping` and `tagged` are defined after it and themselves refer to `atom`.
atom = forward(
    lambda: choice(integer, bstring, ustring, specials, array, mapping, tagged)
)
array = counter(expect_subtype(Major.ARRAY), atom)
_key_value = seq(atom, atom)  # one map entry: [key, value]
mapping = fmap(dict)(counter(expect_subtype(Major.MAP), _key_value))
# Tags
tag = seq(expect_subtype(Major.TAG), atom)
tagobj = fmap(lambda t: Tag(*t))(tag)
expect_tag = lambda tag_id: fmap(lambda t: t.value)(filt(lambda t: t.tag == tag_id)(tagobj))
# Tag 0: date/time string; tag 1: seconds since the epoch, decoded as UTC.
datetime_string = fmap(datetime_parser)(expect_tag(0))
timestamp = fmap(lambda t: dt.datetime.fromtimestamp(t, tz=dt.timezone.utc))(
    expect_tag(1)
)

# Converters for tag numbers that have a dedicated Python representation.  Any
# other tag number is returned as a generic Tag.
_TAG_CONVERTERS = {
    0: datetime_parser,
    1: lambda t: dt.datetime.fromtimestamp(t, tz=dt.timezone.utc),
}


def _convert_tag(obj):
    """Apply the converter registered for ``obj.tag``, or return *obj* as is."""
    convert = _TAG_CONVERTERS.get(obj.tag)
    return obj if convert is None else convert(obj.value)


# Parse the tag header and its enclosed item once, then dispatch on the tag
# number.  Trying datetime_string, timestamp and tagobj in turn would re-parse
# the enclosed item up to three times per nesting level.  That cost grows
# exponentially with the depth of nested tags.
tagged = fmap(_convert_tag)(tagobj)
# A sequence of one or more complete top-level data items, back to back.
cbor_seq = one_or_more(parser=atom)
def Input(vs):
    """Wrap the sequence *vs* as parser input: a ``(sequence, position)`` pair at 0."""
    return vs, 0
if __name__ == "__main__":
    # Demo: encode sample values with the third-party cbor2 package, then lex
    # and parse them back with the combinators above.
    import cbor2

    things = [
        1,
        -100,
        b"h",
        "hi",
        [1, -100, [50, 20]],
        {1: -100, 2: {"a": "B"}},
        [],
        {},
        cbor2.CBORTag(1, 1),
        dt.datetime.now(tz=dt.timezone.utc),
        0.6,
        True,
        False,
        None,
        cbor2.CBORSimpleValue(1),
        cbor2.CBORSimpleValue(254),
        cbor2.undefined,
        cbor2.CBORTag(0xbeefcafe, [1, cbor2.CBORTag(1, 2)]),
    ]
    payload = b"".join(cbor2.dumps(thing) for thing in things)
    tokens = list(Lexer(payload))
    parsed = cbor_seq(Input(tokens))
    if not parsed:
        # cbor_seq returns False when not even the first item parses.
        raise SystemExit("failed to parse any CBOR item")
    result, (_, pos) = parsed
    print(result)
    # Tokens the parser did not consume (an empty list when everything decoded).
    print(tokens[pos:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment