@rfong
Lisp AST parser for Recurse Center interview
from ast.ast import main

if __name__ == '__main__':
    main()

#!/usr/bin/env python
'''
Write code that takes some Lisp code and returns an abstract syntax tree.
The AST should represent the structure of the code and the meaning of each
token. For example, if your code is given "(first (list 1 (+ 2 3) 9))",
it could return a nested array like ["first", ["list", 1, ["+", 2, 3], 9]].
'''
import re
import string
import sys
from ast.operators import OPERATORS
# Regexes for extracting quoted strings.
STRING_REGEXES = [re.compile(rexp) for rexp in [
    "^'([^']*)'$",
    "^\"([^\"]*)\"$",
]]

def main():
    '''
    Entry point. Currently handles single line expressions.
    '''
    input = get_input(sys.argv)
    print(evaluate(input))

def assert_type(x, expectedType):
    if type(x) != expectedType:
        raise TypeError(
            "Expected %s, but received %s: %s" %
            (expectedType.__name__, type(x).__name__, x.__repr__()))

def evaluate(input_str):
    assert_type(input_str, str)
    ast = parse(input_str)
    return _evaluate_ast(ast)

def _evaluate_ast(ast):
    '''
    Takes a validated AST and evaluates the expression.
    '''
    # Preprocess expr (recursively collapse it)
    for i, expr in enumerate(ast):
        if type(expr) == list:
            ast[i] = _evaluate_ast(expr)
    # We can assume this is a list because it's already been validated
    if ast[0] not in OPERATORS:
        return ast  # pass up to caller
    op_fn = OPERATORS.get(ast[0], None)
    assert op_fn is not None
    return op_fn(*ast[1:])

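# Illustrative trace of the recursive collapse in _evaluate_ast above
# (comment only, not executed): nested lists are evaluated innermost-first,
# then the operator at the head of each list is applied to the collapsed
# operands.
#   _evaluate_ast(['+', 1, ['*', 2, 3]])
#     -> ['*', 2, 3] collapses to OPERATORS['*'](2, 3) == 6
#     -> OPERATORS['+'](1, 6) == 7
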
def parse(lisp_str):
    '''
    Takes an unvalidated Lisp code input string and returns a validated
    parse tree.
    '''
    assert_type(lisp_str, str)
    return _validate(lex(lisp_str))

def _validate(parse_tree):
    '''
    Takes a nested-list-based parse tree of tokens and converts implicit
    values / validates LISP expressions. Recursive.
    '''
    for i, x in enumerate(parse_tree):
        if type(x) == list:
            parse_tree[i] = _validate(x)
            continue
        # Otherwise, must be operator, string, or implicit number
        if x in OPERATORS:
            continue
        # Is it a string?
        val = get_quoted_string_value(x)
        if val is not None:
            parse_tree[i] = val
            continue
        # Is it an implicit integer?
        try:
            val = int(x)
            parse_tree[i] = val
            continue
        except ValueError:
            pass
        # Is it an implicit float?
        try:
            val = float(x)
            parse_tree[i] = val
            continue
        except ValueError:
            pass
        # Not parseable; raise a syntax error
        raise ValueError("Could not parse '%s'" % x)
    return parse_tree

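# Illustrative example of _validate's implicit conversions (comment only):
# operators are kept as-is, numbers become int/float, quotes are stripped.
#   _validate(['+', '1', '-3.0', "'foo'"]) -> ['+', 1, -3.0, 'foo']
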
def get_quoted_string_value(s):
    '''
    Returns the unquoted contents if `s` is a quoted string token;
    otherwise implicitly returns None.
    '''
    for sre in STRING_REGEXES:
        m = sre.match(s)
        if m is not None and m.group(1) is not None:
            return m.group(1)

def lex(lisp_str):
    '''
    Takes an unvalidated input string and returns a nested-list-based parse
    tree.
    '''
    assert_type(lisp_str, str)
    tokens = tokenize(lisp_str)
    if len(tokens) > 0 and tokens[0] != '(':
        raise SyntaxError("Missing open paren; expr begins with token '%s'" % tokens[0])
    parse_tree = _lex_tokens(iter(tokens), top_level=True)
    assert len(parse_tree) == 1
    return parse_tree[0]

def _lex_tokens(token_iter, top_level=False):
    '''
    Takes a token iterator and returns a nested-list-based parse tree.
    Purely lexical. Does NOT validate tokens.
    The outermost call should pass top_level=True; running out of tokens is
    then treated as the end of input rather than an unclosed paren.
    Quick and dirty recursive approach.
    CAVEAT: If worried the input code is nested deep enough to run out of
    stack frames (the default is 1000 in Python), switch to a nonrecursive
    approach, like keeping track of OOP AST nodes with pointers. (Or cheat
    and raise the limit.)
    '''
    x = None  # Current expression or token.
    expr = []
    while x != ')':
        try:
            x = next(token_iter)
        # Reached end of tokens.
        except StopIteration:
            # Exhausting the tokens is only legal at the top level, after at
            # least one complete expression has been read.
            if not top_level or not expr:  # TODO: Should we handle nils?
                raise SyntaxError("Unclosed paren")
            break
        # Start of new expression. Go down a level.
        if x == '(':
            x = _lex_tokens(token_iter)
        expr.append(x)
    # Expression ended. Go up a level.
    if x == ')':
        if len(expr) == 1:
            raise SyntaxError("Closing paren does not match an open paren")
        return expr[:-1]  # Drop close paren
    return expr

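# A minimal sketch of the nonrecursive alternative mentioned in the CAVEAT
# above, assuming the same flat token-string format that tokenize() emits.
# It is not wired into lex(); the name _lex_tokens_iterative is hypothetical.
# An explicit stack replaces the call stack, so deeply nested input cannot
# hit Python's recursion limit.
def _lex_tokens_iterative(tokens):
    root = []        # plays the role of the top-level expr list
    stack = [root]   # stack[-1] is the expression currently being built
    for tok in tokens:
        if tok == '(':
            child = []
            stack[-1].append(child)
            stack.append(child)
        elif tok == ')':
            if len(stack) == 1:
                raise SyntaxError("Closing paren does not match an open paren")
            stack.pop()
        else:
            stack[-1].append(tok)
    if len(stack) > 1:
        raise SyntaxError("Unclosed paren")
    # Returns the same nested-list shape as _lex_tokens for well-formed input.
    return root
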
def tokenize(lisp_str):
    '''
    Tokenize a Lisp code string. Do not check for validity yet.
    Currently handles basic expressions with arbitrary whitespace. Does not
    handle comments.
    TODO:
    + to support syntax debugger, include the start index of each token
    '''
    assert_type(lisp_str, str)
    tokens = []
    token = ''
    for ch in lisp_str:
        # Whitespace = end of token
        if ch in string.whitespace:
            if token:
                tokens.append(token)
            token = ''
        elif ch in '()':
            if token:
                tokens.append(token)
            tokens.append(ch)
            token = ''
        else:
            token += ch
    # Flush a trailing token that isn't followed by whitespace or a paren.
    if token:
        tokens.append(token)
    return tokens

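# A hypothetical sketch of the docstring TODO above: record each token's
# start index alongside its text so a future syntax debugger can point at
# the offending character. Not used elsewhere in this module; it reuses
# the module's `string` import and assert_type helper.
def tokenize_with_indices(lisp_str):
    assert_type(lisp_str, str)
    tokens = []  # list of (token, start_index) pairs
    token, start = '', 0
    for i, ch in enumerate(lisp_str):
        if ch in string.whitespace or ch in '()':
            if token:
                tokens.append((token, start))
                token = ''
            if ch in '()':
                tokens.append((ch, i))
        else:
            if not token:
                start = i
            token += ch
    if token:  # flush a trailing token
        tokens.append((token, start))
    return tokens
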
def get_input(sys_argv):
    '''
    Grab input from file or interactive terminal
    '''
    if len(sys_argv) > 1:
        try:
            with open(sys_argv[1], 'r') as f:
                return f.read()
        except FileNotFoundError:
            print(
                'File %s was not found. Switching to manual input.'
                % sys_argv[1]
            )
    return input('Type some Lisp code: ')

Lisp AST parser

I picked this one because I haven't used Lisp before, and I haven't written an AST since school and remember rather enjoying them. (I did write it in a language I know, though.)

Usage

Go up a directory to run, because python module imports are the devil.

Run with interactive input:

python -m ast
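
For example, an interactive run might look like this (illustrative; assumes only the operators currently defined in ast/operators.py):

python -m ast
Type some Lisp code: (+ 1 (* 2 3))
7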

Run with file input:

python -m ast ast/test.lisp

Run unit tests:

python -m unittest discover

Potential todos

  • interpreter
  • support multiline expressions
  • syntax debugger with character indices
# Preemptively move this over here so I can write the operator functions
# in a separate file so they don't clog up the main file
from functools import reduce

def plus(*operands):
    return sum(operands)

def mult(*operands):
    return reduce((lambda x, y: x * y), operands)

def first(*operands):
    if len(operands) == 0:
        return None
    return operands[0]

# Map of operators to functions which operate on operands. TODO
OPERATORS = {
    '+': plus,
    '*': mult,
    'first': first,
    # 'list': None,
}
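
The commented-out 'list' entry above is the natural next operator to add. A minimal sketch (hypothetical, not part of the gist) could look like:

def list_fn(*operands):
    # Collect the operands into a Python list, mirroring Lisp's list form.
    return list(operands)

# and then register it:
# OPERATORS['list'] = list_fn
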
(first (list 1 (+ +2 -3.0) 9))
import unittest

from ast.ast import evaluate
from ast.ast import lex
from ast.ast import parse
from ast.ast import tokenize


class TestAST(unittest.TestCase):

    def test_ignore_whitespace(self):
        s = " \t(+ -2 ( + 3 9 ) )"
        expectedTokens = ['(', '+', '-2', '(', '+', '3', '9', ')', ')']
        self.assertEqual(tokenize(s), expectedTokens)

    def test_bad_parens(self):
        bad_exprs = [
            "",
            "(",
            ")",
            "(first",
            "(2",
            "first)",
            "2)",
            "2))",
            "((",
            "))",
            "(()",
            "())",
        ]
        for s in bad_exprs:
            with self.assertRaises(
                SyntaxError,
                msg="input '%s' should have raised a syntax error" % s
            ) as e:
                print("'%s' lexed to:" % s, lex(s))

    def test_lex_ok(self):
        testCases = [
            # (input_string, expected_value)
            ('(1)', ['1']),
            ('(first (list 1 (+ 2 -3.0) 9))',
             ['first', ['list', '1', ['+', '2', '-3.0'], '9']]),
            ('(+ 1 (* 2 3) 40 (first 3 2938))',
             ['+', '1', ['*', '2', '3'], '40', ['first', '3', '2938']]),
        ]
        for input, expected in testCases:
            self.assertEqual(lex(input), expected)

    def test_parse_quoted_string_values(self):
        self.assertEqual(parse("('foo')"), ["foo"])
        self.assertEqual(parse("(\"foo\")"), ["foo"])
        self.assertEqual(parse("(\"'\")"), ["'"])
        badValues = ["'''", "\"\"\"", "'", "\""]
        for s in badValues:
            with self.assertRaises(
                ValueError, msg="Bad value %s should have raised exception" % s
            ) as e:
                parse("(%s)" % s)

    def test_parse_numbers(self):
        self.assertEqual(parse("(5)"), [5])
        self.assertEqual(parse("(-2)"), [-2])
        self.assertEqual(parse("(3.14159)"), [3.14159])
        self.assertEqual(parse("(-3.14159)"), [-3.14159])
        badValues = ['123abc', '-123.abc']
        for s in badValues:
            with self.assertRaises(
                ValueError, msg="Bad value %s should have raised exception" % s
            ) as e:
                parse("(%s)" % s)

    def test_eval(self):
        testCases = [
            # (input_string, expected_value)
            ('(1 2)', [1, 2]),
            ('(+ 1 2)', 3),
            ('(+ 1 (* 2 3) 40 (first 3 1 6 9 2938))', 50),
        ]
        for input, expected in testCases:
            self.assertEqual(evaluate(input), expected)