Skip to content

Instantly share code, notes, and snippets.

@edofic
Created December 9, 2019 16:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save edofic/d0252cf9a5d6973c2590a6d35eb64437 to your computer and use it in GitHub Desktop.
Save edofic/d0252cf9a5d6973c2590a6d35eb64437 to your computer and use it in GitHub Desktop.
Generator-based recursive descent parsers
import doctest
from collections import namedtuple
import sys
def parse_string(target):
    """Build a parser that matches the literal prefix ``target``.

    The returned parser is a generator function: given an input string it
    yields a single ``(target, remainder)`` pair when the input starts with
    ``target`` and yields nothing otherwise.

    >>> list(parse_string('foo')('foobar'))
    [('foo', 'bar')]
    >>> list(parse_string('foo')('bar'))
    []
    """
    prefix_len = len(target)

    def match(text):
        # No match: finish the generator without producing a result.
        if not text.startswith(target):
            return
        yield (target, text[prefix_len:])

    return match
def parse_end(s):
    """Parser that succeeds only at the end of input.

    Yields ``((), "")`` when ``s`` is the empty string and nothing
    otherwise.

    >>> list(parse_end(''))
    [((), '')]
    >>> list(parse_end(' '))
    []
    """
    if s != "":
        return
    yield ((), "")
def parse_one_of(*parsers):
    """Build a parser that tries each alternative in ``parsers``.

    The returned parser yields every result of every alternative, in the
    order the alternatives were given; it does not stop at the first
    alternative that succeeds.

    >>> list(parse_one_of(parse_string('foo'), parse_string('bar'))('foobar'))
    [('foo', 'bar')]
    >>> list(parse_one_of(parse_string('foo'), parse_string('bar'))('barfoo'))
    [('bar', 'foo')]
    """
    def alternatives(text):
        for candidate in parsers:
            for outcome in candidate(text):
                yield outcome

    return alternatives
# Parser for a single decimal digit: one literal-string alternative per digit.
parse_digit = parse_one_of(*map(parse_string, "0123456789"))
# The parser is a closure, so its doctest is attached by assigning __doc__.
parse_digit.__doc__ = """
>>> list(parse_digit('7'))
[('7', '')]
"""
def parse_all(*parsers):
    """Build a parser that runs ``parsers`` one after another.

    Each parser is applied to the remainder left by the previous one. For
    every way the whole sequence can succeed, the returned parser yields
    ``(values, remainder)`` where ``values`` is always a tuple holding one
    value per parser, in order. With no parsers the sequence matches
    trivially and yields ``((), s)``.

    >>> list(parse_all(parse_digit, parse_digit)('12'))
    [(('1', '2'), '')]
    >>> list(parse_all(parse_digit)('1'))
    [(('1',), '')]
    """
    if not parsers:
        # Base case: an empty sequence consumes nothing and produces ().
        def parse_nothing(s):
            yield ((), s)

        return parse_nothing

    first = parsers[0]
    # Build the parser for the tail once, not once per result of `first`.
    parse_rest = parse_all(*parsers[1:])

    def p(s):
        for value, s1 in first(s):
            for values, s2 in parse_rest(s1):
                yield ((value, *values), s2)

    return p
def parse_map(f, p):
    """Build a parser that applies ``f`` to every value produced by ``p``.

    The remainder of each result is passed through unchanged.

    >>> list(parse_map(int, parse_digit)('7'))
    [(7, '')]
    """
    def mapped(text):
        for parsed, remainder in p(text):
            yield (f(parsed), remainder)

    return mapped
def parse_plus(parser):
    """Build a parser that matches ``parser`` one or more times, greedily.

    Values are collected into a list. For each initial match, the shorter
    result is yielded only when no longer repetition succeeds.

    >>> list(parse_plus(parse_digit)('7'))
    [(['7'], '')]
    >>> list(parse_plus(parse_digit)('713a'))
    [(['7', '1', '3'], 'a')]
    """
    def repeat(text):
        for head, rest in parser(text):
            extended = False
            # Only try to repeat while there is input left to consume.
            if rest:
                for tail, remaining in repeat(rest):
                    extended = True
                    yield [head, *tail], remaining
            # No further repetition matched: the single match stands alone.
            if not extended:
                yield [head], rest

    return repeat
# AST nodes for the binary operations; `n` is the left operand and `m` the right.
Sum = namedtuple("Sum", "n m")
Product = namedtuple("Product", "n m")
# Parser for a non-negative integer: one or more digits joined and converted to int.
parse_num = parse_map(
    lambda chars: int("".join(chars)),
    parse_plus(parse_digit),
)
parse_num.__doc__ = """
>>> list(parse_num('123'))
[(123, '')]
"""
def parse_product(s):
    """Parse ``num * (num | product)`` into a right-nested ``Product`` tree.

    Every possible parse is yielded, including ones that leave part of the
    input unconsumed.

    >>> list(parse_product('1*2'))
    [(Product(n=1, m=2), '')]
    >>> list(parse_product('1*2*3'))
    [(Product(n=1, m=2), '*3'), (Product(n=1, m=Product(n=2, m=3)), '')]
    """
    # The grammar is assembled on each call because it refers to
    # parse_product itself for the right-hand operand.
    right_operand = parse_one_of(parse_num, parse_product)
    grammar = parse_all(parse_num, parse_string("*"), right_operand)

    def to_product(parts):
        left, _star, right = parts
        return Product(left, right)

    return parse_map(to_product, grammar)(s)
def parse_sum(s):
    """Parse ``summand + (summand | sum)`` into a right-nested ``Sum`` tree.

    A summand is a product or a plain number. Every possible parse is
    yielded, including ones that leave part of the input unconsumed.

    >>> list(parse_sum('1+2'))
    [(Sum(n=1, m=2), '')]
    >>> list(parse_sum('1+2+3'))
    [(Sum(n=1, m=2), '+3'), (Sum(n=1, m=Sum(n=2, m=3)), '')]
    >>> list(parse_sum('1+2*3'))
    [(Sum(n=1, m=Product(n=2, m=3)), ''), (Sum(n=1, m=2), '*3')]
    """
    summand = parse_one_of(parse_product, parse_num)
    # The right-hand side may itself be a sum, which makes chains nest rightwards.
    right_operand = parse_one_of(summand, parse_sum)
    grammar = parse_all(summand, parse_string("+"), right_operand)

    def to_sum(parts):
        left, _plus, right = parts
        return Sum(left, right)

    return parse_map(to_sum, grammar)(s)
# Top-level parser: a sum, product or number that must be followed by end of
# input. Only the expression value is kept; the () from parse_end is dropped.
parse_expr = parse_map(
    lambda parts: parts[0],
    parse_all(
        parse_one_of(parse_sum, parse_product, parse_num),
        parse_end,
    ),
)
parse_expr.__doc__ = """
>>> list(parse_expr('1'))
[(1, '')]
>>> list(parse_expr('1*2'))
[(Product(n=1, m=2), '')]
>>> list(parse_expr('1+2'))
[(Sum(n=1, m=2), '')]
>>> list(parse_expr('1+2*3'))
[(Sum(n=1, m=Product(n=2, m=3)), '')]
"""
def eval_expr(e):
    """Evaluate an expression tree made of ints, ``Sum`` and ``Product`` nodes.

    Returns the integer value of the tree. Raises ``ValueError`` when ``e``
    is none of the three supported node kinds.

    >>> eval_expr(Sum(1, Product(2, 3)))
    7
    """
    # Leaves are plain integers.
    if isinstance(e, int):
        return e
    if isinstance(e, Sum):
        left, right = eval_expr(e.n), eval_expr(e.m)
        return left + right
    if isinstance(e, Product):
        left, right = eval_expr(e.n), eval_expr(e.m)
        return left * right
    raise ValueError("{}:{} must be a int|Sum|Product tree".format(e, type(e)))
def run_repl():
    """Run an interactive read-eval-print loop for arithmetic expressions.

    Each iteration reads a line from standard input, removes spaces, parses
    it with ``parse_expr`` and prints the normalized input, the parsed tree
    with its remainder, and the evaluated result. Entering ``q`` or closing
    the input stream ends the loop. Input that does not parse is reported
    and the loop continues.
    """
    while True:
        print('\nEnter expression to evaluate, "q" to quit')
        print("> ", end="")
        try:
            raw = input()
        except EOFError:
            # Input stream closed (Ctrl-D or exhausted pipe): exit quietly.
            print()
            return
        stripped = raw.replace(" ", "")
        if stripped == "q":
            return
        print(stripped)
        # parse_expr yields nothing for invalid input; a bare next() would
        # raise StopIteration and terminate the REPL with a traceback.
        parsed = next(parse_expr(stripped), None)
        if parsed is None:
            print(f"Could not parse expression: {stripped!r}")
            continue
        expr, rem = parsed
        print(expr, rem)
        res = eval_expr(expr)
        print(res)
def print_usage():
    """Print the command-line usage message for this script to stdout."""
    program = sys.argv[0]
    print(f"Usage: {program} test|repl\n")
if __name__ == "__main__":
    # Exactly one argument is expected; anything else falls through to usage.
    command = sys.argv[1] if len(sys.argv) == 2 else None
    if command == "test":
        # Run the doctests embedded in this module.
        doctest.testmod()
    elif command == "repl":
        run_repl()
    else:
        print_usage()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment