Skip to content

Instantly share code, notes, and snippets.

@jagt
Created July 3, 2023 15:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jagt/cfddc94af31b20873c9b1c0280d3818e to your computer and use it in GitHub Desktop.
Save jagt/cfddc94af31b20873c9b1c0280d3818e to your computer and use it in GitHub Desktop.
# https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html
# Interestingly, the Rust version in the article is actually shorter than this Python port.
from collections import namedtuple
from enum import Enum
from io import StringIO
class PrattErr(RuntimeError):
    """Error raised by the lexer and parser for malformed input."""

    def __init__(self, msg):
        # Forward only the message, so that ``str(err) == msg`` and
        # ``err.args == (msg,)``. Also passing ``self`` would store the
        # exception inside its own ``args`` and turn ``str(err)`` into a
        # tuple repr.
        super().__init__(msg)
class TokenType(Enum):
    """Kind of a token: ``lex`` emits Atom and Op; ``Lexer`` synthesizes Eof."""

    Atom, Op, Eof = range(3)
# Every single-character operator/punctuation symbol recognized by the lexer.
Ops = list('+-*/!?:=[]().')
class Token:
    """A lexed token: its kind ``tt`` and its source text ``s``."""

    def __init__(self, tt, s):
        self.tt, self.s = tt, s

    def __repr__(self):
        return f"<{self.tt!s} '{self.s!s}'>"
def lex(s: str) -> list[Token]:
    """Split *s* into a flat list of tokens.

    Each character listed in ``Ops`` becomes its own ``TokenType.Op`` token,
    each maximal run of alphanumeric characters becomes one
    ``TokenType.Atom`` token, and whitespace is skipped. No ``Eof`` token is
    appended to the result.

    Raises:
        PrattErr: if *s* contains any other character.
    """
    tokens = []
    ix = 0
    end = len(s)
    while ix < end:
        ch = s[ix]
        if ch in Ops:
            tokens.append(Token(TokenType.Op, ch))
            ix += 1
        elif ch.isalnum():
            begin = ix
            # Consume the whole alphanumeric run; s[begin] is known alnum,
            # so at least one character is taken.
            while ix < end and s[ix].isalnum():
                ix += 1
            tokens.append(Token(TokenType.Atom, s[begin:ix]))
        elif ch.isspace():
            ix += 1
        else:
            raise PrattErr("unexpected %s" % ch)
    return tokens
class Lexer:
    """Cursor over the tokens produced by ``lex`` for one input string.

    ``peek`` returns the current token without consuming it, falling back to
    a single ``Eof`` token once the list is exhausted; ``advance`` moves the
    cursor forward by one.
    """

    def __init__(self, s):
        self.ls = lex(s)                        # all tokens of the input
        self.ix = 0                             # index of the current token
        self.eof = Token(TokenType.Eof, '\0')   # returned by peek() past the end

    def peek(self):
        """Return the current token, or ``self.eof`` when past the end."""
        if self.ix >= len(self.ls):
            return self.eof
        return self.ls[self.ix]

    def advance(self):
        """Step past the current token; stepping beyond the end is allowed."""
        self.ix = self.ix + 1
class NodeAtom:
    """Leaf of the syntax tree, wrapping a single Atom token."""

    def __init__(self, token):
        self.token = token

    def __repr__(self):
        return f'NodeAtom: {self.token!r}'
class NodeCons:
    """Interior syntax-tree node: an operator token applied to child nodes.

    ``ls`` holds the operands in source order: one for prefix/postfix
    operators, two for binary operators and ``[`` indexing, and three for
    the ``?:`` ternary.
    """

    def __init__(self, token, ls):
        self.token = token
        self.ls = ls

    def __repr__(self):
        # Multi-line indented dump of the whole subtree.
        buf = StringIO()
        _dump_tree(self, buf, 0)
        return buf.getvalue()
def _dump_tree(node, sb, ident):
    """Write *node* and, for a NodeCons, its subtree to *sb*.

    Each node writes one line: ``ident`` spaces followed by the repr of its
    token. Children are written recursively with two extra spaces of
    indentation. (``ident`` means the indent width.)
    """
    sb.write('%s%r\n' % (' ' * ident, node.token))
    if not isinstance(node, NodeCons):
        return
    for child in node.ls:
        _dump_tree(child, sb, ident + 2)
def sexpr(node):
    """Render a syntax tree as an S-expression string.

    An atom renders as its token text. A NodeCons renders as
    ``(op child1 child2 ...)``. Any other object contributes nothing.
    """
    out = StringIO()

    def emit(n):
        if isinstance(n, NodeCons):
            out.write('(%s' % n.token.s)
            for child in n.ls:
                out.write(' ')
                emit(child)
            out.write(')')
        elif isinstance(n, NodeAtom):
            out.write(n.token.s)

    emit(node)
    return out.getvalue()
def expr(s):
    """Parse the whole string *s* as one expression and return its root node.

    Raises:
        PrattErr: on a lexing or parsing error, or when tokens remain after
            a complete expression. Examples are the stray ')' in "(1))" and
            the '(' in "f(x)", both of which were previously ignored.
    """
    lexer = Lexer(s)
    node = expr_bp(lexer, 0)
    # At min_bp 0, expr_bp stops only at Eof or at an operator it cannot
    # use (')', ']', ':', '(' ...). The latter means the input is malformed.
    trailing = lexer.peek()
    if trailing.tt != TokenType.Eof:
        raise PrattErr('unexpected trailing token: %s' % trailing)
    return node
def _expect(lexer, s):
    """Consume the next token if its text is *s*; otherwise raise PrattErr."""
    if lexer.peek().s != s:
        raise PrattErr("expect '%s' but found %s" % (s, lexer.peek()))
    lexer.advance()


def expr_bp(lexer, min_bp):
    '''Parse one expression from *lexer* (Pratt / top-down operator precedence).

    The parse has two phases:

    * the prefix phase reads an atom, a parenthesized expression, or a prefix
      operator and its operand;
    * a loop then folds postfix and infix operators into the result.

    An operator whose left binding power is below *min_bp* is not consumed.
    The loop stops and leaves it for an enclosing call. Operands to the
    right are parsed recursively with the operator's right binding power.

    Returns a NodeAtom or NodeCons.

    Raises:
        PrattErr: on an unexpected token, a missing ')' / ']' / ':', or an
            unsupported prefix operator.
    '''
    token = lexer.peek()
    lexer.advance()  # the prefix token is always consumed, whatever it is
    if token.tt == TokenType.Atom:
        lhs = NodeAtom(token)
    elif token.tt == TokenType.Op and token.s == '(':
        # Parentheses reset precedence, and no node is built for them.
        lhs = expr_bp(lexer, 0)
        _expect(lexer, ')')
    elif token.tt == TokenType.Op:
        _, r_bp = prefix_binding_power(token)
        lhs = NodeCons(token, [expr_bp(lexer, r_bp)])
    else:
        raise PrattErr('expect Atom get: %s' % token)

    while True:
        op = lexer.peek()
        if op.tt == TokenType.Eof:
            break
        if op.tt != TokenType.Op:
            # Two operands in a row (e.g. "a b"): report the offending token,
            # not the prefix token that was already consumed.
            raise PrattErr('expect Op get: %s' % op)

        postfix = postfix_binding_power(op)
        if postfix:
            l_bp, _ = postfix
            if l_bp < min_bp:
                break
            lexer.advance()
            if op.s == '[':
                # Indexing: the subscript is delimited by ']', so use bp 0.
                index = expr_bp(lexer, 0)
                _expect(lexer, ']')
                lhs = NodeCons(op, [lhs, index])
            else:
                lhs = NodeCons(op, [lhs])
            continue

        infix = infix_binding_power(op)
        if infix:
            l_bp, r_bp = infix
            if l_bp < min_bp:
                break
            lexer.advance()
            if op.s == '?':
                # Ternary: the middle operand is delimited by ':', so use bp 0.
                mhs = expr_bp(lexer, 0)
                _expect(lexer, ':')
                # The else-branch is also parsed at bp 0, so a lower-precedence
                # '=' after ':' is absorbed into it:
                #   a ? b : c = d  ->  (? a b (= c d))
                # NOTE(review): the linked article appears to pass r_bp here,
                # which would give (= (? a b c) d). Confirm which is intended.
                rhs = expr_bp(lexer, 0)
                lhs = NodeCons(op, [lhs, mhs, rhs])
            else:
                rhs = expr_bp(lexer, r_bp)
                lhs = NodeCons(op, [lhs, rhs])
            continue

        # The operator has no postfix or infix role (')', ']', ':', ...):
        # stop and let an enclosing call deal with it.
        break
    return lhs
def prefix_binding_power(op):
    """Return ``(None, r_bp)`` for a prefix operator; the left slot is unused.

    Only unary '+' and '-' are prefix operators.

    Raises:
        PrattErr: for any other operator.
    """
    if op.s not in ['+', '-']:
        raise PrattErr('unexpected prefix op: %s' % op)
    return (None, 9)
def postfix_binding_power(op):
    """Return ``(l_bp, None)`` for a postfix operator ('!' or '[' indexing), else None.

    Both postfix operators share l_bp 11. A postfix operator has no right
    binding power, so it never becomes the min_bp that a later operator is
    compared against, and sharing the value is harmless. 11 is above the
    prefix power 9, so ``-9!`` parses as ``(- (! 9))``.
    """
    return (11, None) if op.s in ('!', '[') else None
def infix_binding_power(op):
    """Return ``(l_bp, r_bp)`` for an infix operator, or None if *op* is not one.

    Left-associative operators use ``(x, x + 1)``; right-associative ones use
    ``(x + 1, x)``. '?' is the head of the ``a ? b : c`` ternary.
    """
    table = (
        # left associative
        (('+', '-'), (5, 6)),
        (('*', '/'), (7, 8)),
        # right associative
        (('=',), (2, 1)),
        (('?',), (4, 3)),
        (('.',), (14, 13)),
    )
    for symbols, powers in table:
        if op.s in symbols:
            return powers
    return None
def test(s, expect):
    """Parse *s* and print a diagnostic line if its S-expression differs from *expect*.

    Prints nothing when the result matches.
    """
    got = sexpr(expr(s))
    if got == expect:
        return
    print(f"{s!s} | expect: '{expect!s}' get '{got!s}'")
if __name__ == '__main__':
    # (input, expected S-expression) pairs; test() prints only mismatches.
    for src, want in [
        ('1', '1'),
        ('a + b * c', '(+ a (* b c))'),
        ('a + b * c * d + e', '(+ (+ a (* (* b c) d)) e)'),
        ('f . g . h', '(. f (. g h))'),
        ('1 + 2 + f . g . h * 3 * 4', '(+ (+ 1 2) (* (* (. f (. g h)) 3) 4))'),
        ('--1 * 2', '(* (- (- 1)) 2)'),
        ('--f . g', '(- (- (. f g)))'),
        ('-9!', '(- (! 9))'),
        ('f . g !', '(! (. f g))'),
        ('(((((1)))))', '1'),
        ('(a + b) * c', '(* (+ a b) c)'),
        ('x[0][1]', '([ ([ x 0) 1)'),
        ('a ? b : c ? d : e', '(? a b (? c d e))'),
        ('a = b + 23', '(= a (+ b 23))'),
    ]:
        test(src, want)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment