Skip to content

Instantly share code, notes, and snippets.

@suica
Created February 4, 2021 04:46
Show Gist options
  • Save suica/575a0d7065ec811678a0a20f4b1d0f0b to your computer and use it in GitHub Desktop.
Save suica/575a0d7065ec811678a0a20f4b1d0f0b to your computer and use it in GitHub Desktop.
A simplistic Scheme interpreter
from functools import reduce
from typing import List, Tuple
import operator
class SchemeList:
    """A singly linked cons cell used to represent Scheme pairs and lists.

    ``value`` is the car and ``next`` is the cdr. A proper list ends with a
    ``next`` of ``None`` (the interpreter's ``nil``); ``cons`` may also store
    any other value in ``next``, e.g. ``(cons 1 2)`` gives an improper pair.
    """

    # Default cdr for a freshly constructed cell; ``cons`` or direct
    # assignment overwrites it per instance.
    next = None

    def __init__(self, value):
        self.value = value

    def get_nested(self):
        """Return the chain as nested Python lists ``[v1, [v2, ... [vn, tail]]]``.

        ``tail`` is the final cdr (``None`` for a proper list). Element values
        are not converted. The chain is walked iteratively, so long lists do
        not hit the recursion limit.
        """
        values = []
        node = self
        while isinstance(node, SchemeList):
            values.append(node.get_value())
            node = node.get_next()
        nested = node  # the final cdr
        for value in reversed(values):
            nested = [value, nested]
        return nested

    def get_next(self):
        """Return the cdr of this cell."""
        return self.next

    def get_value(self):
        """Return the car of this cell."""
        return self.value

    def __repr__(self):
        """Render exactly like ``str(self.get_nested())``, e.g. ``[1, [2, None]]``.

        The string is built iteratively rather than via ``str()`` of the nested
        Python lists, whose repr would recurse once per cell.
        """
        pieces = []
        node = self
        while isinstance(node, SchemeList):
            pieces.append('[{!r}, '.format(node.get_value()))
            node = node.get_next()
        return ''.join(pieces) + repr(node) + ']' * len(pieces)

    @staticmethod
    def cons(a, b):
        """Return a new cell whose car is ``a`` and whose cdr is ``b``."""
        root = SchemeList(a)
        root.next = b
        return root
class SchemeQuoteList:
    """Wrapper marking a quoted list produced by the parser.

    The parser stores a quoted ``'(a b ...)`` here as the form
    ``('list', a, b, ...)``. Evaluating the wrapper yields the wrapper itself;
    ``car``/``cdr`` evaluate the stored form when they need the list.
    """

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        # Include the wrapped form; without a placeholder every quote printed
        # identically as "quote".
        return '(quote {!r})'.format(self.value)
def tokenize(s: str) -> List[str]:
    """Split Scheme source text into a flat list of tokens.

    Parentheses are always tokens of their own. Any whitespace character
    (space, tab, newline, carriage return, ...) separates tokens. Every other
    character is accumulated into the current token, so a quote directly
    before ``(`` becomes its own ``"'"`` token, while ``'abc`` stays one token.

    Args:
        s: program text.

    Returns:
        The non-empty tokens in source order.
    """
    tokens = []
    current_token = ''
    for c in s:
        if c in '()':
            tokens.append(current_token)
            tokens.append(c)
            current_token = ''
        elif c.isspace():
            # Newlines and tabs must separate tokens too; skipping them would
            # glue "1\n2" into the single token "12".
            tokens.append(current_token)
            current_token = ''
        else:
            current_token += c
    tokens.append(current_token)
    return [token for token in tokens if token != '']
def parse(tokens: List[str], start=0):
    """Parse a flat token list into nested tuples.

    Parsing starts at ``start`` and stops at the first unmatched ``)`` or at
    the end of ``tokens``. Each parenthesised group becomes a tuple, atoms stay
    strings, and a quoted list ``'( ... )`` becomes a ``SchemeQuoteList``
    wrapping the form ``('list', ...)``.

    Args:
        tokens: output of ``tokenize``.
        start: index of the first token to parse.

    Returns:
        ``(tree, index)``: the list of parsed items at this nesting level and
        the index of the ``)`` that ended it, or ``len(tokens)``.

    Raises:
        Exception: if a quote is not immediately followed by a list.
    """
    tree = []
    index = start
    while index < len(tokens):
        current_token = tokens[index]
        if current_token == ')':
            return tree, index
        elif current_token == '(':
            subtree, end_i = parse(tokens, start=index + 1)
            tree.append(tuple(subtree))
            index = end_i + 1
        elif current_token == "'":
            # Only the single list right after the quote belongs to it. Parsing
            # the rest of the enclosing form here would swallow (and drop) any
            # siblings that follow the quoted list.
            if index + 1 >= len(tokens) or tokens[index + 1] != '(':
                raise Exception('quote must be followed by a list')
            subtree, end_i = parse(tokens, start=index + 2)
            tree.append(SchemeQuoteList(('list',) + tuple(subtree)))
            index = end_i + 1
        else:
            tree.append(current_token)
            index += 1
    return tree, index
def apply(expr: Tuple, context):
    """Evaluate one compound form ``(op operand ...)``.

    ``op`` decides how the operands are handled:

    * a special-form keyword (``define``, ``if``, ``lambda``, ``and``, ``or``,
      ``list``, ``car``, ``cdr``; ``cond`` is reserved but not implemented)
      decides itself which operands are evaluated, and when;
    * any other string names a procedure -- the built-in ``+``, ``-``, ``*``
      or a callable bound in ``context`` -- called with all operands
      evaluated left to right;
    * a nested tuple is evaluated to a procedure, which is then called with
      the evaluated operands.

    Args:
        expr: a parsed form (tuple) as produced by ``parse``.
        context: dict of name -> value bindings; ``define`` writes into it.

    Returns:
        The value of the form; ``define`` returns the marker ``'def:<name>'``.

    Raises:
        NotImplementedError: for ``cond``, ``/`` and a ``define`` whose body is
            not exactly one expression.
        TypeError: when a computed operator is not callable.
        Exception: for malformed forms, arity errors and unknown operators.
    """
    if len(expr) == 0:
        raise Exception('empty expression to apply')
    op = expr[0]
    operands = expr[1:]

    if isinstance(op, str) and op in ('define', 'if', 'lambda', 'cond', 'and', 'or', 'list', 'car', 'cdr'):
        # Special forms: operands arrive unevaluated.
        if op == 'define':
            middle, *body = operands
            if len(body) != 1:
                raise NotImplementedError
            body = body[0]
            if isinstance(middle, tuple):
                # (define (name param ...) body): a procedure definition.
                name, *parameters = middle
                # Close over the live context (not a copy) so that, once the
                # binding below is stored, the body can call itself.
                context[name] = apply(('lambda', parameters, body), context)
            else:
                # (define name expr): a constant definition.
                name = middle
                context[name] = evaluate(body, context)
            return 'def:{}'.format(name)
        if op == 'if':
            if len(operands) != 3:
                raise Exception('if expression is malformed')
            predicate, consequent, alternative = operands
            # Scheme truthiness: every value except #f counts as true.
            if evaluate(predicate, context) is not False:
                return evaluate(consequent, context)
            return evaluate(alternative, context)
        if op == 'cond':
            raise NotImplementedError
        if op == 'lambda':
            if len(operands) != 2:
                raise Exception('lambda expression is malformed: {}'.format(expr))
            [*parameters], body = operands

            def _lambda(*args):
                if len(args) != len(parameters):
                    raise Exception('arity error, expect {} but given {}'.format(len(parameters), len(args)))
                # A fresh frame per call: one shared dict would let a recursive
                # call overwrite its caller's parameters. Copying the defining
                # context at call time (not creation time) also exposes names
                # bound after the lambda was created, e.g. its own name.
                frame = dict(context)
                frame.update(zip(parameters, args))
                return evaluate(body, frame)

            return _lambda
        if op == 'and':
            # Stop at the first #f; otherwise the value is the last operand's
            # value (#t when there are no operands).
            result = True
            for unevaluated in operands:
                result = evaluate(unevaluated, context)
                if result is False:
                    return False
            return result
        if op == 'or':
            # The value is the first operand value that is not #f, else #f.
            for unevaluated in operands:
                result = evaluate(unevaluated, context)
                if result is not False:
                    return result
            return False
        if op == 'list':
            # Build the chain behind a dummy head cell; (list) yields nil (None).
            head = SchemeList(None)
            tail = head
            for operand in operands:
                tail.next = SchemeList(evaluate(operand, context))
                tail = tail.next
            return head.next
        # The remaining special forms are car and cdr.
        if len(operands) != 1:
            raise Exception("arity error for car/cdr")
        target = evaluate(operands[0], context)
        if isinstance(target, SchemeQuoteList):
            # A quoted list is stored as a ('list', ...) form; build it now.
            target = evaluate(target.value, context)
        if isinstance(target, SchemeList):
            return target.get_value() if op == 'car' else target.get_next()
        raise Exception('car/cdr error: {} is not a list'.format(target))

    if isinstance(op, str):
        # op names a procedure, so every operand can be evaluated up front.
        args = [evaluate(operand, context) for operand in operands]
        if op == '+':
            return sum(args)
        if op == '-':
            if not args:
                raise Exception('arity error for -')
            if len(args) == 1:
                # (- x) is negation in Scheme.
                return -args[0]
            return args[0] - sum(args[1:])
        if op == '*':
            return reduce(operator.mul, args, 1)
        if op == '/':
            raise NotImplementedError
        if op in context:
            func = context[op]
            if callable(func):
                return func(*args)
            raise Exception('{} is not callable'.format(op))
        raise Exception('unrecognised form {}'.format(op))

    if isinstance(op, tuple):
        args = [evaluate(operand, context) for operand in operands]
        func = apply(op, context)
        if not callable(func):
            raise TypeError('{} is not callable'.format(func))
        # Errors raised inside the procedure propagate unchanged.
        return func(*args)

    raise Exception('unrecognised form {}'.format(op))
def evaluate(expr, context):
    """Evaluate a parsed expression in ``context``.

    * A list (a whole program, as returned by ``parse``) has each item
      evaluated in order, and the list of results is returned.
    * A tuple is a compound form and is handed to ``apply``.
    * A string is an integer literal if ``int()`` accepts it; otherwise it is
      a name looked up in ``context``.
    * A ``SchemeQuoteList`` evaluates to itself.

    Raises:
        Exception: for an empty program or an unbound name.
        TypeError: for any other kind of ``expr``.
    """
    if isinstance(expr, list):
        if len(expr) == 0:
            raise Exception('attempt to evaluate an empty expression')
        return [evaluate(item, context) for item in expr]
    if isinstance(expr, tuple):
        return apply(expr, context)
    if isinstance(expr, str):
        try:
            return int(expr)
        except ValueError:
            pass
        # Looked up outside the handler so an unbound-name error is not
        # chained to the irrelevant int() failure.
        if expr in context:
            return context[expr]
        raise Exception("{} is not in context {}".format(expr, context))
    if isinstance(expr, SchemeQuoteList):
        return expr
    raise TypeError(expr)
def format_result(lis):
    """Convert a list of evaluation results into printable values.

    Procedures (any callable) are shown as the string ``'lambda'``; every
    other value is passed through unchanged. Tuples are not supported and
    raise ``NotImplementedError``.
    """
    def _render(value):
        if isinstance(value, tuple):
            raise NotImplementedError
        return 'lambda' if callable(value) else value

    return [_render(value) for value in lis]
def equal(list_a: SchemeList, list_b: SchemeList):
    """Structural equality backing ``equal?``.

    Two ``SchemeList`` chains are equal when they have the same length, their
    elements are pairwise ``equal`` (so nested lists compare by structure,
    not identity), and their final cdrs are equal. ``None`` (``nil``) equals
    only ``None``, and a ``SchemeList`` never equals a non-list value. Any
    other pair of values is compared with ``==``.
    """
    # Walk both chains in lockstep without recursing on the cdr.
    while isinstance(list_a, SchemeList) and isinstance(list_b, SchemeList):
        if not equal(list_a.get_value(), list_b.get_value()):
            return False
        list_a, list_b = list_a.get_next(), list_b.get_next()
    if list_a is None or list_b is None:
        return list_a is None and list_b is None
    if isinstance(list_a, SchemeList) or isinstance(list_b, SchemeList):
        # Previously list_a was type-checked twice, so a list compared with a
        # non-list crashed calling get_value() on the non-list.
        return False
    return list_a == list_b
# Bindings every program starts with: boolean constants, comparisons and the
# list primitives. Procedures here receive already-evaluated arguments.
predefined_context = {
    'true': True,
    'false': False,
    '<': operator.lt,
    '<=': operator.le,
    '>': operator.gt,
    '>=': operator.ge,
    '=': operator.eq,
    # Scheme `not`: only #f is false, so (not 0) and (not nil) are #f.
    'not': lambda x: x is False,
    'null?': lambda x: x is None,
    'equal?': equal,
    'cons': lambda a, b: SchemeList.cons(a, b),
    'nil': None,
}
# entry
def tokenize_and_evaluate(s, context=None):
    """Tokenize, parse and evaluate the program ``s``; return formatted results.

    Args:
        s: Scheme source text; may contain several top-level expressions.
        context: optional dict of bindings to evaluate in. ``define`` writes
            into it, so passing the same dict across calls keeps definitions.
            When omitted, a fresh copy of ``predefined_context`` is used, so
            one program's definitions cannot leak into the next.

    Returns:
        One formatted result per top-level expression.

    Raises:
        Exception: on an unmatched ``)``, or anything raised during evaluation.
    """
    if context is None:
        context = dict(predefined_context)
    tokens = tokenize(s)
    tree, end = parse(tokens)
    # parse stops at an unmatched ')'; without this check the rest of the
    # program would be silently ignored.
    if end != len(tokens):
        raise Exception('unexpected ")" at token {}'.format(end))
    return format_result(evaluate(tree, context))
import unittest
from unittest import skip
import scheme_interpreter
class SchemeInterpreterTestCases(unittest.TestCase):
    """Regression tests exercising the ``scheme_interpreter`` module end to end."""

    def _check_evaluations(self, cases):
        # Each case is (source, expected); compared in (expected, actual) order.
        for source, expected in cases:
            actual = scheme_interpreter.tokenize_and_evaluate(source)
            self.assertEqual(expected, actual)

    @staticmethod
    def _parse_source(source):
        # Only the parse tree (first element of parse's result) is inspected.
        return scheme_interpreter.parse(scheme_interpreter.tokenize(source))[0]

    def test_tokenize(self):
        expectations = (
            ('0', ['0']),
            ('(0)', ['(', '0', ')']),
            ('(+ 1 (- 1 2))', ['(', '+', '1', '(', '-', '1', '2', ')', ')']),
            ('(+ 1 22222)', ['(', '+', '1', '22222', ')']),
            ("' (1 2 )", ["'", '(', '1', '2', ')']),
        )
        for text, tokens in expectations:
            produced = scheme_interpreter.tokenize(text)
            self.assertEqual(tokens, produced)

    def test_parse(self):
        parse_src = self._parse_source
        self.assertEqual(['0', '1', '2'], parse_src('0 1 2'))
        self.assertEqual([('+', '91', '9')], parse_src('(+ 91 9)'))
        self.assertEqual([('+', '91', ('*', '3', '3'))],
                         parse_src('(+ 91 (* 3 3))'))
        self.assertEqual([('+', '91', ('*', '3', '3')), ('+', '1', '2')],
                         parse_src('(+ 91 (* 3 3)) (+ 1 2)'))
        two_forms = parse_src("""
(car (list 1))
(car (list 1))
""")
        self.assertEqual(2, len(two_forms))

    def test_parse_with_quote(self):
        single = self._parse_source("""
(car '(1))
""")
        self.assertEqual(1, len(single))
        double = self._parse_source("""
(car '(1))
(car '(1))
""")
        self.assertEqual(2, len(double))
        self.assertEqual(2, len(double[0]))
        self.assertEqual(2, len(double[1]))

    def test_evaluate(self):
        self._check_evaluations((
            ('0', [0]),
            ('0 114514 2', [0, 114514, 2]),
            ('(+ 1 2)', [3]),
            ('(* 100 (+ 1 2))', [300]),
            ('(+ (* 3 (+ (* 2 4) (+ 3 5))) (+ (- 10 7) 6))', [57]),
        ))

    def test_lambda(self):
        self._check_evaluations((
            ('((lambda (f x) (+ x f)) 1 2)', [3]),
            ('((lambda () 1))', [1]),
            ('(lambda () 1)', ['lambda']),
            ('((lambda (wocao ssd c) (+ ssd 3)) 1 2 3)', [5]),
        ))

    def test_define(self):
        self._check_evaluations((
            ('(define (square x) (* x x)) (square 10)', ['def:square', 100]),
            ('(define money 1) money', ['def:money', 1]),
            ('(define (square x) (* x x))', ['def:square']),
            ('((lambda () 1))', [1]),
            ('((lambda (wocao ssd c) (+ ssd 3)) 1 2 3)', [5]),
            ('(define pi 314) ((lambda (x) (* 2 x)) pi)', ['def:pi', 628]),
            ('(define (infinite_loop x) (infinite_loop))', ['def:infinite_loop']),
            ('(define (infinite_loop) (infinite_loop))', ['def:infinite_loop']),
        ))

    def test_multiline_program(self):
        self._check_evaluations((
            ("(define (double x) (* x 2)) (double 2333)", ['def:double', 4666]),
            ("""
( define pi 314)
((lambda ( x) ( * 2 x)) pi)
""", ['def:pi', 628]),
            ("""
(define ( double
x) ( * x 2 ))
(double 2333)
(define pi 314)
((lambda (x) (* 2 x)) pi)
""", ['def:double', 4666, 'def:pi', 628]),
            ("""
(define pi 314) \n\n\n ((\n lambda\n (x)\n (*\n 2\n x)) pi)
pi
""", ['def:pi', 628, 314]),
        ))

    def test_compare_operators(self):
        checks = (
            # <
            ("""
(< 1 2)
(< 2 1)
""", [True, False]),
            # not
            ("""
(not (< 1 2))
(not false)
""", [False, True]),
            # and / or short-circuit: the nonsense forms are never evaluated
            ("""
(and true (> (+ 1 2) 3))
(define (x) (x))
(or true (x) (哈哈哈哈随便什么都行!因为这个根本不求值))
(and false (还行吧))
(or (and true false) (not false))
""", [False, 'def:x', True, False, True]),
        )
        for source, expected in checks:
            self.assertEqual(scheme_interpreter.tokenize_and_evaluate(source), expected)

    def test_if(self):
        self._check_evaluations((
            ('true', [True]),
            ('false', [False]),
            ('(if true 1 2)', [1]),
            ('(if false 1 2)', [2]),
            ("""
(define (infinite_loop) (infinite_loop))
(if false (infinite_loop) 2)
""", ['def:infinite_loop', 2]),
            (("(define (double x) (* x 2))\n"
              "(if false (what ever 什么都哈哈哈哈行???) (double 2))\n"),
             ['def:double', 4]),
            ("""
(if
(< 1 2) (+ (if (= 1 2) 1 2) 1)
(aksdjhaskjdhakshdkj))
""", [3]),
        ))

    def test_list(self):
        equality_program = ("(equal? (list 1111 23 4) (list 1111 23 4))"
                            "(equal? (cons 1 2) (list 1 2))"
                            "(equal? (cons 1 (cons 2 nil)) (list 1 2))"
                            "(equal? (cons 1 nil) (list 1))")
        self.assertEqual([True, False, True, True],
                         scheme_interpreter.tokenize_and_evaluate(equality_program))
        car_cdr_program = ("(car (cons 1 nil))"
                           "(cdr (cons 1 nil))")
        self.assertEqual([1, None],
                         scheme_interpreter.tokenize_and_evaluate(car_cdr_program))

    def test_quote(self):
        quote_program = ("(car '(1(define(x)(x)) 3))"
                         "(car '(1)) (car '(2222 1)) "
                         "(equal? (cdr '(222 1)) (cons 1 nil))")
        self.assertEqual([1, 1, 2222, True],
                         scheme_interpreter.tokenize_and_evaluate(quote_program))
if __name__ == '__main__':
    # Run the unittest suite defined above when this file is executed directly.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment