# Source: GitHub gist xslendix/680d2cf36b85d55afd7fdea1d154f0db by @xslendix
# Last active September 12, 2021 14:27
#############
# Imports
#############
from string_with_arrows import *
import string
#############
# Constants
#############
# Character classes consulted by the lexer.
DIGITS = '0123456789'
LETTERS = string.ascii_letters
# Characters allowed after the first character of an identifier (besides '_').
LETTERS_DIGITS = ''.join((LETTERS, DIGITS))
#############
# Errors
#############
class Error:
    """Base class for every error reported to the user.

    Attributes:
        pos_start: position where the offending source span begins.
        pos_end: position where the offending source span ends.
        error_name: short category label shown before the details.
        details: human-readable description of the problem.
    """

    def __init__(self, pos_start, pos_end, error_name, details):
        self.pos_start = pos_start
        self.pos_end = pos_end
        self.error_name = error_name
        self.details = details

    def as_string(self):
        """Return the message, a file/line header and an arrow-marked source excerpt."""
        headline = f'{self.error_name}: {self.details}\n'
        location = f'File {self.pos_start.fn}, line {self.pos_start.ln + 1}'
        excerpt = string_with_arrows(self.pos_start.ftxt, self.pos_start, self.pos_end)
        return headline + location + '\n' + excerpt
class IllegalCharError(Error):
    """Error raised by the lexer for a character that starts no token."""

    def __init__(self, pos_start, pos_end, details):
        super().__init__(pos_start, pos_end, 'Illegal Character', details)
class ExpectedCharError(Error):
    """Error raised by the lexer when a required follow-up character is missing."""

    def __init__(self, pos_start, pos_end, details):
        super().__init__(pos_start, pos_end, 'Expected Character', details)
class InvalidSyntaxError(Error):
    """Error produced by the parser when the token stream breaks the grammar."""

    def __init__(self, pos_start, pos_end, details=''):
        super().__init__(pos_start, pos_end, 'Invalid Syntax', details)
class RTError(Error):
    """Runtime error that also carries the execution context it occurred in."""

    def __init__(self, pos_start, pos_end, details, context):
        super().__init__(pos_start, pos_end, 'Runtime Error', details)
        self.context = context

    def as_string(self):
        """Return a traceback, the error message and an arrow-marked source excerpt."""
        excerpt = string_with_arrows(self.pos_start.ftxt, self.pos_start, self.pos_end)
        return (
            self.generate_traceback()
            + f'{self.error_name}: {self.details}'
            + '\n\n' + excerpt
        )

    def generate_traceback(self):
        """Build a traceback by walking the context chain up through its parents.

        Frames are collected innermost-first and emitted outermost-first, so
        the most recent call is printed last.
        """
        frames = []
        pos = self.pos_start
        ctx = self.context
        while ctx:
            frames.append(f' File {pos.fn}, line {str(pos.ln + 1)}, in {ctx.display_name}\n')
            # The next frame up is located where this context was entered.
            pos = ctx.parent_entry_pos
            ctx = ctx.parent
        return 'Traceback (most recent call last):\n' + ''.join(reversed(frames))
#############
# Position
#############
class Position:
    """A cursor into the source text.

    Attributes:
        idx: absolute character index.
        ln: zero-based line number.
        col: zero-based column number.
        fn: name of the file the text came from.
        ftxt: the full source text.
    """

    def __init__(self, idx, ln, col, fn, ftxt):
        self.idx = idx
        self.ln = ln
        self.col = col
        self.fn = fn
        self.ftxt = ftxt

    def advance(self, current_char=None):
        """Move one character forward and return ``self``.

        ``current_char`` is the character being left behind; when it is a
        newline the line counter is bumped and the column resets to 0.
        """
        self.idx += 1
        self.col += 1
        if current_char == '\n':
            self.ln += 1
            self.col = 0
        return self

    def copy(self):
        """Return an independent Position with the same coordinates."""
        return Position(self.idx, self.ln, self.col, self.fn, self.ftxt)
#############
# Tokens
#############
# Literals and names
TT_INT, TT_FLOAT = 'INT', 'FLOAT'
TT_IDENTIFIER, TT_KEYWORD = 'IDENTIFIER', 'KEYWORD'
# Arithmetic operators
TT_PLUS, TT_MINUS = 'PLUS', 'MINUS'
TT_MUL, TT_DIV, TT_POW = 'MUL', 'DIV', 'POW'
# Assignment and grouping
TT_EQ = 'EQ'
TT_LPAREN, TT_RPAREN = 'LPAREN', 'RPAREN'
# Comparison operators
TT_EE, TT_NE = 'EE', 'NE'
TT_LT, TT_GT = 'LT', 'GT'
TT_LTE, TT_GTE = 'LTE', 'GTE'
# Punctuation and end-of-input marker
TT_COMMA, TT_ARROW = 'COMMA', 'ARROW'
TT_EOF = 'EOF'

# Identifiers that the lexer tags as TT_KEYWORD instead of TT_IDENTIFIER.
KEYWORDS = [
    'set',
    'and', 'or', 'not',
    'if', 'then', 'elif', 'else',
    'for', 'to', 'step',
    'while',
    'func',
]
class Token:
    """A lexical token: a type tag, an optional value and a source span.

    Args:
        type_: one of the ``TT_*`` type tags.
        value: payload for tokens that carry one (numbers, names, keywords).
        pos_start: position of the token's first character. When given, the
            span defaults to a single character starting there.
        pos_end: position just past the token; overrides the default end.
    """

    def __init__(self, type_, value=None, pos_start=None, pos_end=None):
        self.type = type_
        self.value = value
        if pos_start:
            self.pos_start = pos_start.copy()
            self.pos_end = pos_start.copy()
            self.pos_end.advance()
        if pos_end:
            # Copy: callers pass their live cursor, which keeps moving after
            # the token is created and would otherwise drag pos_end along.
            self.pos_end = pos_end.copy()

    def matches(self, type_, value):
        """Return True when both the type and the value equal the arguments."""
        return self.type == type_ and self.value == value

    def __repr__(self):
        # Compare against None so falsy payloads such as 0 are still shown.
        if self.value is not None:
            return f'{self.type}:{self.value}'
        return f'{self.type}'
#############
# Lexer
#############
class Lexer:
    """Converts source text into a list of ``Token`` objects.

    Args:
        fn: file name used in error messages.
        text: the source text to tokenize.
    """

    def __init__(self, fn, text):
        self.fn = fn
        self.text = text
        # Start one step before the text so the first advance() lands on index 0.
        self.pos = Position(-1, 0, -1, fn, text)
        self.current_char = None
        self.advance()

    def advance(self):
        """Move to the next character; ``current_char`` becomes None at the end."""
        self.pos.advance(self.current_char)
        if self.pos.idx < len(self.text):
            self.current_char = self.text[self.pos.idx]
        else:
            self.current_char = None

    def make_tokens(self):
        """Tokenize the whole text.

        Returns:
            ``(tokens, None)`` on success, where the list ends with a TT_EOF
            token, or ``([], error)`` on the first lexical error.
        """
        # Characters that map directly onto a one-character token.
        single_char_tokens = {
            '+': TT_PLUS,
            '-': TT_MINUS,
            '*': TT_MUL,
            '/': TT_DIV,
            '^': TT_POW,
            '(': TT_LPAREN,
            ')': TT_RPAREN,
            ',': TT_COMMA,
        }
        tokens = []
        while self.current_char is not None:
            char = self.current_char
            if char in ' \t':
                self.advance()
            elif char in DIGITS:
                tokens.append(self.make_number())
            elif char in LETTERS:
                tokens.append(self.make_identifier())
            elif char in single_char_tokens:
                tokens.append(Token(single_char_tokens[char], pos_start=self.pos))
                self.advance()
            elif char == '!':
                tok, error = self.make_not_equals()
                if error:
                    return [], error
                tokens.append(tok)
            elif char == '=':
                tokens.append(self.make_equals_or_arrow())
            elif char == '<':
                tokens.append(self.make_less_than())
            elif char == '>':
                tokens.append(self.make_greater_than())
            else:
                pos_start = self.pos.copy()
                self.advance()
                return [], IllegalCharError(pos_start, self.pos, "'" + char + "'")
        tokens.append(Token(TT_EOF, pos_start=self.pos))
        return tokens, None

    def make_number(self):
        """Read digits with at most one '.' and return a TT_INT or TT_FLOAT token."""
        number_chars = DIGITS + '.'
        num_str = ''
        dot_count = 0
        pos_start = self.pos.copy()
        while self.current_char is not None and self.current_char in number_chars:
            if self.current_char == '.':
                # A second dot ends the number; it is left for the caller.
                if dot_count == 1:
                    break
                dot_count += 1
            num_str += self.current_char
            self.advance()
        if dot_count == 0:
            return Token(TT_INT, int(num_str), pos_start, self.pos)
        return Token(TT_FLOAT, float(num_str), pos_start, self.pos)

    def make_identifier(self):
        """Read letters, digits and '_' and return a keyword or identifier token."""
        identifier_chars = LETTERS_DIGITS + '_'
        id_str = ''
        pos_start = self.pos.copy()
        while self.current_char is not None and self.current_char in identifier_chars:
            id_str += self.current_char
            self.advance()
        tok_type = TT_KEYWORD if id_str in KEYWORDS else TT_IDENTIFIER
        return Token(tok_type, id_str, pos_start, self.pos)

    def make_not_equals(self):
        """Lex '!=' and return ``(token, None)``, or ``(None, error)`` if '=' is missing."""
        pos_start = self.pos.copy()
        self.advance()
        if self.current_char == '=':
            self.advance()
            return Token(TT_NE, pos_start=pos_start, pos_end=self.pos), None
        # Remember the offending character before stepping past it, so the
        # message names what actually followed the '!'.
        if self.current_char is None:
            found = 'end of input'
        else:
            found = f"'{self.current_char}'"
        self.advance()
        return None, ExpectedCharError(
            pos_start, self.pos,
            f"'=' (after '!') expected but got {found}"
        )

    def make_equals_or_arrow(self):
        """Lex '=', '==' or '=>' into TT_EQ, TT_EE or TT_ARROW."""
        tok_type = TT_EQ
        pos_start = self.pos.copy()
        self.advance()
        if self.current_char == '=':
            self.advance()
            tok_type = TT_EE
        elif self.current_char == '>':
            self.advance()
            tok_type = TT_ARROW
        return Token(tok_type, pos_start=pos_start, pos_end=self.pos)

    def make_less_than(self):
        """Lex '<' or '<=' into TT_LT or TT_LTE."""
        tok_type = TT_LT
        pos_start = self.pos.copy()
        self.advance()
        if self.current_char == '=':
            self.advance()
            tok_type = TT_LTE
        return Token(tok_type, pos_start=pos_start, pos_end=self.pos)

    def make_greater_than(self):
        """Lex '>' or '>=' into TT_GT or TT_GTE."""
        tok_type = TT_GT
        pos_start = self.pos.copy()
        self.advance()
        if self.current_char == '=':
            self.advance()
            tok_type = TT_GTE
        return Token(tok_type, pos_start=pos_start, pos_end=self.pos)
#############
# Nodes
#############
class NumberNode:
    """AST leaf for a numeric literal; its span is the span of its token."""

    def __init__(self, tok):
        self.tok = tok
        self.pos_start = self.tok.pos_start
        self.pos_end = self.tok.pos_end

    def __repr__(self):
        return f'{self.tok}'
class VarAccessNode:
    """AST leaf for reading a variable; its span is the span of the name token."""

    def __init__(self, var_name_tok):
        self.var_name_tok = var_name_tok
        self.pos_start = self.var_name_tok.pos_start
        self.pos_end = self.var_name_tok.pos_end
class VarAssignNode:
    """AST node for an assignment; spans from the name token to the end of the value."""

    def __init__(self, var_name_tok, value_node):
        self.var_name_tok = var_name_tok
        self.value_node = value_node
        self.pos_start = self.var_name_tok.pos_start
        self.pos_end = self.value_node.pos_end
class BinOpNode:
    """AST node for a binary operation; spans from the left operand to the right."""

    def __init__(self, left_node, op_tok, right_node):
        self.left_node = left_node
        self.op_tok = op_tok
        self.right_node = right_node
        self.pos_start = self.left_node.pos_start
        self.pos_end = self.right_node.pos_end

    def __repr__(self):
        return f'({self.left_node}, {self.op_tok}, {self.right_node})'
class UnaryOpNode:
    """AST node for a prefix operator; spans from the operator to the end of its operand."""

    def __init__(self, op_tok, node):
        self.op_tok = op_tok
        self.node = node
        self.pos_start = self.op_tok.pos_start
        self.pos_end = node.pos_end

    def __repr__(self):
        return f'({self.op_tok}, {self.node})'
class IfNode:
    """AST node for an if / elif / else expression.

    Args:
        cases: non-empty list of ``(condition_node, body_node)`` pairs, one
            per ``if``/``elif`` branch, in source order.
        else_case: body node of the ``else`` branch, or None.
    """

    def __init__(self, cases, else_case):
        self.cases = cases
        self.else_case = else_case
        self.pos_start = self.cases[0][0].pos_start
        # The expression ends where its last body ends: the else body when
        # present, otherwise the body (not the condition) of the final case.
        self.pos_end = (self.else_case or self.cases[-1][1]).pos_end
class ForNode:
    """AST node for a counted loop.

    ``step_value_node`` may be None when the source has no ``step`` clause.
    The span runs from the loop variable token to the end of the body.
    """

    def __init__(self, var_name_tok, start_value_node, end_value_node, step_value_node, body_node):
        self.var_name_tok = var_name_tok
        self.start_value_node = start_value_node
        self.end_value_node = end_value_node
        self.step_value_node = step_value_node
        self.body_node = body_node
        self.pos_start = self.var_name_tok.pos_start
        self.pos_end = self.body_node.pos_end
class WhileNode:
    """AST node for a while loop; spans from the condition to the end of the body."""

    def __init__(self, condition_node, body_node):
        self.condition_node = condition_node
        self.body_node = body_node
        self.pos_start = condition_node.pos_start
        self.pos_end = body_node.pos_end
class FuncDefNode:
    """AST node for a function definition.

    Args:
        var_name_tok: name token, or None for an anonymous function.
        arg_name_toks: list of parameter name tokens (possibly empty).
        body_node: the expression evaluated when the function is called.
    """

    def __init__(self, var_name_tok, arg_name_toks, body_node):
        self.var_name_tok = var_name_tok
        self.arg_name_toks = arg_name_toks
        self.body_node = body_node
        # Start at the first piece that exists: name, first parameter, or body.
        if self.var_name_tok:
            self.pos_start = self.var_name_tok.pos_start
        elif len(self.arg_name_toks) > 0:
            self.pos_start = self.arg_name_toks[0].pos_start
        else:
            self.pos_start = self.body_node.pos_start
        self.pos_end = self.body_node.pos_end
class CallNode:
    """AST node for a call.

    The span starts at the callee and ends at the last argument, or at the
    callee itself when there are no arguments.
    """

    def __init__(self, node_to_call, arg_nodes):
        self.node_to_call = node_to_call
        self.arg_nodes = arg_nodes
        self.pos_start = self.node_to_call.pos_start
        last_node = self.arg_nodes[-1] if len(self.arg_nodes) > 0 else self.node_to_call
        self.pos_end = last_node.pos_end
##############
# Parse result
##############
class ParseResult:
    """Outcome of a parse step: a node or an error, plus how many tokens were consumed.

    Attributes:
        error: the recorded error, or None.
        node: the parsed node, or None.
        advance_count: total tokens consumed by this step and its sub-steps.
        last_registered_advance_count: tokens consumed by the most recent
            registration (1 for a single advancement).
    """

    def __init__(self):
        self.error = None
        self.node = None
        self.last_registered_advance_count = 0
        self.advance_count = 0

    def register_advancement(self):
        """Record that one token was consumed directly."""
        self.last_registered_advance_count = 1
        self.advance_count += 1

    def register(self, res):
        """Absorb a sub-result: add its advance count, adopt its error, return its node."""
        self.last_registered_advance_count = res.advance_count
        self.advance_count += res.advance_count
        if res.error:
            self.error = res.error
        return res.node

    def success(self, node):
        """Store ``node`` and return ``self``."""
        self.node = node
        return self

    def failure(self, error):
        """Record ``error`` and return ``self``.

        An existing error is kept once tokens have been consumed, so the
        error from deeper in the parse wins over a generic outer one.
        """
        if not self.error or self.advance_count == 0:
            self.error = error
        return self
#############
# Parser
#############
class Parser:
    """Recursive-descent parser that turns a token list into an AST.

    Every grammar method returns a ``ParseResult`` that holds either the
    parsed node or a syntax error.

    Args:
        tokens: list of tokens to parse; ``parse`` expects it to end with a
            TT_EOF token.
    """

    def __init__(self, tokens):
        self.tokens = tokens
        self.tok_idx = -1
        self.advance()

    def advance(self):
        """Step to the next token, if there is one, and return the current token."""
        self.tok_idx += 1
        if self.tok_idx < len(self.tokens):
            self.current_tok = self.tokens[self.tok_idx]
        return self.current_tok

    def _consume(self, res):
        """Record one consumed token on ``res`` and move to the next token."""
        res.register_advancement()
        self.advance()

    def _fail(self, res, details):
        """Fail ``res`` with a syntax error located at the current token."""
        return res.failure(InvalidSyntaxError(
            self.current_tok.pos_start, self.current_tok.pos_end,
            details
        ))

    def parse(self):
        """Parse one expression and require that it is followed by end of input."""
        res = self.expr()
        if not res.error and self.current_tok.type != TT_EOF:
            return self._fail(
                res,
                "Expected '+', '-', '*', '/', '^', '==', '!=', '<', '>', '<=', '>=', 'and' or 'or'"
            )
        return res

    ###################################

    def if_expr(self):
        """Parse ``if <expr> then <expr> {elif <expr> then <expr>} [else <expr>]``."""
        res = ParseResult()
        cases = []
        else_case = None
        if not self.current_tok.matches(TT_KEYWORD, 'if'):
            return self._fail(res, f"Expected 'if'")
        # The first pass handles 'if'; each further pass handles one 'elif'.
        while True:
            self._consume(res)
            condition = res.register(self.expr())
            if res.error:
                return res
            if not self.current_tok.matches(TT_KEYWORD, 'then'):
                return self._fail(res, f"Expected 'then'")
            self._consume(res)
            body = res.register(self.expr())
            if res.error:
                return res
            cases.append((condition, body))
            if not self.current_tok.matches(TT_KEYWORD, 'elif'):
                break
        if self.current_tok.matches(TT_KEYWORD, 'else'):
            self._consume(res)
            else_case = res.register(self.expr())
            if res.error:
                return res
        return res.success(IfNode(cases, else_case))

    def for_expr(self):
        """Parse ``for <identifier> = <expr> to <expr> [step <expr>] then <expr>``."""
        res = ParseResult()
        if not self.current_tok.matches(TT_KEYWORD, 'for'):
            return self._fail(res, "Expected 'for'")
        self._consume(res)
        if self.current_tok.type != TT_IDENTIFIER:
            return self._fail(res, "Expected identifier")
        var_name = self.current_tok
        self._consume(res)
        if self.current_tok.type != TT_EQ:
            return self._fail(res, "Expected '='")
        self._consume(res)
        start_value = res.register(self.expr())
        if res.error:
            return res
        if not self.current_tok.matches(TT_KEYWORD, 'to'):
            return self._fail(res, "Expected 'to'")
        self._consume(res)
        end_value = res.register(self.expr())
        if res.error:
            return res
        step_value = None
        if self.current_tok.matches(TT_KEYWORD, 'step'):
            self._consume(res)
            step_value = res.register(self.expr())
            if res.error:
                return res
        if not self.current_tok.matches(TT_KEYWORD, 'then'):
            return self._fail(res, "Expected 'then'")
        self._consume(res)
        body = res.register(self.expr())
        if res.error:
            return res
        return res.success(ForNode(var_name, start_value, end_value, step_value, body))

    def while_expr(self):
        """Parse ``while <expr> then <expr>``."""
        res = ParseResult()
        if not self.current_tok.matches(TT_KEYWORD, 'while'):
            return self._fail(res, "Expected 'while'")
        self._consume(res)
        condition = res.register(self.expr())
        if res.error:
            return res
        if not self.current_tok.matches(TT_KEYWORD, 'then'):
            return self._fail(res, "Expected 'then'")
        self._consume(res)
        body = res.register(self.expr())
        if res.error:
            return res
        return res.success(WhileNode(condition, body))

    def call(self):
        """Parse an atom optionally followed by a parenthesized argument list."""
        res = ParseResult()
        atom = res.register(self.atom())
        if res.error:
            return res
        if self.current_tok.type != TT_LPAREN:
            return res.success(atom)
        self._consume(res)
        arg_nodes = []
        if self.current_tok.type == TT_RPAREN:
            self._consume(res)
        else:
            arg_nodes.append(res.register(self.expr()))
            if res.error:
                return self._fail(
                    res,
                    "Expected ')', 'set', 'if', 'for', 'while', 'func', int, float, identifier, '+', '-', '(' or 'not'"
                )
            while self.current_tok.type == TT_COMMA:
                self._consume(res)
                arg_nodes.append(res.register(self.expr()))
                if res.error:
                    return res
            # The argument list must be closed by ')'. Compare the token's
            # type; the token object itself never equals a type tag.
            if self.current_tok.type != TT_RPAREN:
                return self._fail(res, "Expected ',' or ')'")
            self._consume(res)
        return res.success(CallNode(atom, arg_nodes))

    def atom(self):
        """Parse a literal, a name, a parenthesized expression or a keyword construct."""
        res = ParseResult()
        tok = self.current_tok
        if tok.type in (TT_INT, TT_FLOAT):
            self._consume(res)
            return res.success(NumberNode(tok))
        if tok.type == TT_IDENTIFIER:
            self._consume(res)
            return res.success(VarAccessNode(tok))
        if tok.type == TT_LPAREN:
            self._consume(res)
            expr = res.register(self.expr())
            if res.error:
                return res
            if self.current_tok.type != TT_RPAREN:
                return self._fail(res, "Expected ')'")
            self._consume(res)
            return res.success(expr)
        # Constructs introduced by a keyword each have their own grammar method.
        keyword_parsers = (
            ('if', self.if_expr),
            ('for', self.for_expr),
            ('while', self.while_expr),
            ('func', self.func_def),
        )
        for keyword, parse_construct in keyword_parsers:
            if tok.matches(TT_KEYWORD, keyword):
                node = res.register(parse_construct())
                if res.error:
                    return res
                return res.success(node)
        return res.failure(InvalidSyntaxError(
            tok.pos_start, tok.pos_end,
            "Expected int, float, identifier, '+', '-', '(', 'if', 'for', 'while', 'func'"
        ))

    def power(self):
        """Parse ``call {^ factor}``; the right operand is a factor so it may carry a sign."""
        return self.bin_op(self.call, (TT_POW, ), self.factor)

    def factor(self):
        """Parse an optionally signed power expression."""
        res = ParseResult()
        tok = self.current_tok
        if tok.type in (TT_PLUS, TT_MINUS):
            self._consume(res)
            factor = res.register(self.factor())
            if res.error:
                return res
            return res.success(UnaryOpNode(tok, factor))
        return self.power()

    def term(self):
        """Parse ``factor {(*|/) factor}``."""
        return self.bin_op(self.factor, (TT_MUL, TT_DIV))

    def arith_expr(self):
        """Parse ``term {(+|-) term}``."""
        return self.bin_op(self.term, (TT_PLUS, TT_MINUS))

    def comp_expr(self):
        """Parse ``not comp_expr`` or a chain of comparisons between arithmetic expressions."""
        res = ParseResult()
        if self.current_tok.matches(TT_KEYWORD, 'not'):
            op_tok = self.current_tok
            self._consume(res)
            node = res.register(self.comp_expr())
            if res.error:
                return res
            return res.success(UnaryOpNode(op_tok, node))
        node = res.register(self.bin_op(
            self.arith_expr, (TT_EE, TT_NE, TT_LT, TT_GT, TT_LTE, TT_GTE)
        ))
        if res.error:
            return self._fail(res, "Expected int, float, identifier, '+', '-', '(' or 'not'")
        return res.success(node)

    def expr(self):
        """Parse ``set <identifier> = <expr>`` or a chain of comparisons joined by and/or."""
        res = ParseResult()
        if self.current_tok.matches(TT_KEYWORD, 'set'):
            self._consume(res)
            if self.current_tok.type != TT_IDENTIFIER:
                return self._fail(res, 'Expected identifier')
            var_name = self.current_tok
            self._consume(res)
            if self.current_tok.type != TT_EQ:
                return self._fail(res, "Expected '='")
            self._consume(res)
            value = res.register(self.expr())
            if res.error:
                return res
            return res.success(VarAssignNode(var_name, value))
        node = res.register(self.bin_op(self.comp_expr, (
            (TT_KEYWORD, "and"),
            (TT_KEYWORD, "or")
        )))
        if res.error:
            return self._fail(res, "Expected 'set', 'not', int, float, identifier, '+', '-' or '('")
        return res.success(node)

    def func_def(self):
        """Parse ``func [<identifier>] ( [<identifier> {, <identifier>}] ) => <expr>``."""
        res = ParseResult()
        if not self.current_tok.matches(TT_KEYWORD, 'func'):
            return self._fail(res, "Expected 'func'")
        self._consume(res)
        if self.current_tok.type == TT_IDENTIFIER:
            var_name_tok = self.current_tok
            self._consume(res)
            if self.current_tok.type != TT_LPAREN:
                return self._fail(res, "Expected '('")
        else:
            var_name_tok = None
            if self.current_tok.type != TT_LPAREN:
                return self._fail(res, "Expected identifier or '('")
        self._consume(res)
        arg_name_toks = []
        if self.current_tok.type == TT_IDENTIFIER:
            arg_name_toks.append(self.current_tok)
            self._consume(res)
            while self.current_tok.type == TT_COMMA:
                self._consume(res)
                if self.current_tok.type != TT_IDENTIFIER:
                    return self._fail(res, "Expected identifier")
                arg_name_toks.append(self.current_tok)
                self._consume(res)
            if self.current_tok.type != TT_RPAREN:
                return self._fail(res, "Expected ',' or ')'")
        else:
            # No parameters: the only valid token here is the closing ')'.
            if self.current_tok.type != TT_RPAREN:
                return self._fail(res, "Expected identifier or ')'")
        self._consume(res)
        if self.current_tok.type != TT_ARROW:
            return self._fail(res, "Expected '=>'")
        self._consume(res)
        node_to_return = res.register(self.expr())
        if res.error:
            return res
        return res.success(FuncDefNode(var_name_tok, arg_name_toks, node_to_return))

    ###################################

    def bin_op(self, func_a, ops, func_b = None):
        """Parse a left-associative chain ``func_a {op func_b}``.

        Args:
            func_a: grammar method for the left-most operand.
            ops: accepted operators, given either as token types or as
                ``(type, value)`` pairs.
            func_b: grammar method for the right operands; defaults to ``func_a``.
        """
        if func_b is None:
            func_b = func_a
        res = ParseResult()
        left = res.register(func_a())
        if res.error:
            return res
        while (
            self.current_tok.type in ops
            or (self.current_tok.type, self.current_tok.value) in ops
        ):
            op_tok = self.current_tok
            self._consume(res)
            right = res.register(func_b())
            if res.error:
                return res
            left = BinOpNode(left, op_tok, right)
        return res.success(left)
################
# Runtime result
################
class RTResult:
    """Outcome of evaluating a node: a value or a runtime error."""

    def __init__(self):
        self.value = None
        self.error = None

    def register(self, res):
        """Adopt the error of a sub-result, if any, and return its value."""
        if res.error:
            self.error = res.error
        return res.value

    def success(self, value):
        """Store ``value`` and return ``self``."""
        self.value = value
        return self

    def failure(self, error):
        """Store ``error`` and return ``self``."""
        self.error = error
        return self
#############
# Values
#############
class Value:
    """Base class for runtime values.

    Every operation defaults to an "Illegal operation" runtime error;
    subclasses override the operations they support. Binary operations
    return a ``(result, error)`` pair.
    """

    def __init__(self):
        self.set_pos()
        self.set_context()

    def set_pos(self, pos_start=None, pos_end=None):
        """Set the source span of this value and return ``self``."""
        self.pos_start = pos_start
        self.pos_end = pos_end
        return self

    def set_context(self, context=None):
        """Set the execution context of this value and return ``self``."""
        self.context = context
        return self

    def added_to(self, other):
        return None, self.illegal_operation(other)

    def subbed_by(self, other):
        return None, self.illegal_operation(other)

    def multed_by(self, other):
        return None, self.illegal_operation(other)

    def dived_by(self, other):
        return None, self.illegal_operation(other)

    def powed_by(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_eq(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_ne(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_lt(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_gt(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_lte(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_gte(self, other):
        return None, self.illegal_operation(other)

    def anded_by(self, other):
        return None, self.illegal_operation(other)

    def ored_by(self, other):
        return None, self.illegal_operation(other)

    def notted(self, other=None):
        # ``other`` is optional: negation is unary and is invoked without
        # arguments, so a required parameter would raise TypeError instead of
        # producing the runtime error.
        return None, self.illegal_operation(other)

    def execute(self, args):
        """Call this value; the base implementation fails with an RTResult error."""
        return RTResult().failure(self.illegal_operation())

    def copy(self):
        """Return a copy of this value; subclasses must override."""
        raise Exception('No copy method defined')

    def is_true(self):
        """Truthiness used by conditions; the base value is false."""
        return False

    def illegal_operation(self, other=None):
        """Build an RTError spanning from this value to ``other`` (or to itself)."""
        if other is None:
            other = self
        return RTError(
            self.pos_start, other.pos_end,
            'Illegal operation',
            self.context
        )
class Number(Value):
    """Runtime numeric value wrapping a Python number.

    Binary operations return ``(Number, None)`` on success, or
    ``(None, RTError)`` when the other operand is not a Number or the
    operation is undefined.
    """

    def __init__(self, value):
        super().__init__()
        self.value = value

    def _binary_op(self, other, operation):
        """Apply ``operation`` to both raw values, or fail if ``other`` is not a Number."""
        if not isinstance(other, Number):
            return None, Value.illegal_operation(self, other)
        result = operation(self.value, other.value)
        return Number(result).set_context(self.context), None

    def added_to(self, other):
        return self._binary_op(other, lambda a, b: a + b)

    def subbed_by(self, other):
        return self._binary_op(other, lambda a, b: a - b)

    def multed_by(self, other):
        return self._binary_op(other, lambda a, b: a * b)

    def dived_by(self, other):
        if isinstance(other, Number) and other.value == 0:
            return None, RTError(
                other.pos_start, other.pos_end,
                'Division by zero',
                self.context
            )
        return self._binary_op(other, lambda a, b: a / b)

    def powed_by(self, other):
        try:
            return self._binary_op(other, lambda a, b: a ** b)
        except ZeroDivisionError:
            # Python raises this for zero raised to a negative power; report
            # it as a runtime error instead of crashing the interpreter.
            return None, RTError(
                other.pos_start, other.pos_end,
                'Division by zero',
                self.context
            )

    def get_comparison_eq(self, other):
        return self._binary_op(other, lambda a, b: int(a == b))

    def get_comparison_ne(self, other):
        return self._binary_op(other, lambda a, b: int(a != b))

    def get_comparison_lt(self, other):
        return self._binary_op(other, lambda a, b: int(a < b))

    def get_comparison_gt(self, other):
        return self._binary_op(other, lambda a, b: int(a > b))

    def get_comparison_lte(self, other):
        return self._binary_op(other, lambda a, b: int(a <= b))

    def get_comparison_gte(self, other):
        return self._binary_op(other, lambda a, b: int(a >= b))

    def anded_by(self, other):
        # Return the selected operand untruncated: wrapping it in int() would
        # turn a truthy fraction such as 0.5 into 0, i.e. into "false".
        return self._binary_op(other, lambda a, b: a and b)

    def ored_by(self, other):
        # See anded_by: no int() truncation, so fractional operands stay truthy.
        return self._binary_op(other, lambda a, b: a or b)

    def notted(self):
        return Number(1 if self.value == 0 else 0).set_context(self.context), None

    def copy(self):
        """Return a new Number with the same value, span and context."""
        duplicate = Number(self.value)
        duplicate.set_pos(self.pos_start, self.pos_end)
        duplicate.set_context(self.context)
        return duplicate

    def is_true(self):
        """A number is true when it is not zero."""
        return self.value != 0

    def __repr__(self):
        return str(self.value)
class Function(Value):
    """Runtime function value.

    Args:
        name: function name; falsy names are shown as ``<anonymous>``.
        body_node: AST node evaluated on each call.
        arg_names: list of parameter names.
    """

    def __init__(self, name, body_node, arg_names):
        super().__init__()
        self.name = name or "<anonymous>"
        self.body_node = body_node
        self.arg_names = arg_names

    def execute(self, args):
        """Call the function with ``args`` and return an RTResult.

        The body runs in a new context whose symbol table is created with the
        defining context's table as parent. A wrong number of arguments
        yields a runtime error.
        """
        res = RTResult()
        interpreter = Interpreter()
        new_context = Context(self.name, self.context, self.pos_start)
        new_context.symbol_table = SymbolTable(new_context.parent.symbol_table)

        expected = len(self.arg_names)
        received = len(args)
        if received > expected:
            return res.failure(RTError(
                self.pos_start, self.pos_end,
                f"{received - expected} too many args passed into '{self.name}'",
                self.context
            ))
        if received < expected:
            return res.failure(RTError(
                self.pos_start, self.pos_end,
                f"{expected - received} too few args passed into '{self.name}'",
                self.context
            ))

        # Bind each argument to its parameter name inside the call's scope.
        for arg_name, arg_value in zip(self.arg_names, args):
            arg_value.set_context(new_context)
            new_context.symbol_table.set(arg_name, arg_value)

        value = res.register(interpreter.visit(self.body_node, new_context))
        if res.error:
            return res
        return res.success(value)

    def copy(self):
        """Return a new Function sharing the body, with the same context and span."""
        duplicate = Function(self.name, self.body_node, self.arg_names)
        duplicate.set_context(self.context)
        duplicate.set_pos(self.pos_start, self.pos_end)
        return duplicate

    def __repr__(self):
        return f"<function {self.name}>"
#############
# Context
#############
class Context:
    """One frame of execution, used for scoping and for tracebacks.

    Attributes:
        display_name: name shown for this frame in tracebacks.
        parent: the enclosing Context, or None.
        parent_entry_pos: position used to locate the parent's traceback frame.
        symbol_table: the frame's symbol table; assigned after construction.
    """

    def __init__(self, display_name, parent=None, parent_entry_pos=None):
        self.display_name = display_name
        self.parent = parent
        self.parent_entry_pos = parent_entry_pos
        self.symbol_table = None
#############
# Symbol table
#############
class SymbolTable:
    """Mapping from names to values with an optional parent scope.

    Args:
        parent: table consulted by ``get`` when a name is missing here.
    """

    def __init__(self, parent=None):
        self.symbols = {}
        # Keep the parent so lookups can fall back to the enclosing scope.
        self.parent = parent

    def get(self, name):
        """Return the value bound to ``name`` here or in a parent scope, else None."""
        value = self.symbols.get(name, None)
        if value is None and self.parent is not None:
            return self.parent.get(name)
        return value

    def set(self, name, value):
        """Bind ``name`` to ``value`` in this table."""
        self.symbols[name] = value

    def remove(self, name):
        """Delete ``name`` from this table; raises KeyError if it is not bound here."""
        del self.symbols[name]
#############
# Interpreter
#############
class Interpreter:
def visit(self, node, context):
method_name = f'visit_{type(node).__name__}'
method = getattr(self, method_name, self.no_visit_method)
return method(node, context)
def no_visit_method(self, node, context):
raise Exception(f'No visit_{type(node).__name__} method defined')
###################################
def visit_NumberNode(self, node, context):
return RTResult().success(
Number(node.tok.value).set_context(context).set_pos(node.pos_start, node.pos_end)
)
def visit_VarAccessNode(self, node, context):
res = RTResult()
var_name = node.var_name_tok.value
value = context.symbol_table.get(var_name)
if not value:
return res.failure(RTError(
node.pos_start, node.pos_end,
f"{var_name} is not defined",
context
))
value = value.copy().set_pos(node.pos_start, node.pos_end)
return res.success(value)
def visit_VarAssignNode(self, node, context):
res = RTResult()
var_name = node.var_name_tok.value
value = res.register(self.visit(node.value_node, context))
if res.error: return res
context.symbol_table.set(var_name, value)
return res.success(value)
def visit_BinOpNode(self, node, context):
res = RTResult()
left = res.register(self.visit(node.left_node, context))
if res.error: return res
right = res.register(self.visit(node.right_node, context))
if res.error: return res
if node.op_tok.type == TT_PLUS:
result, error = left.added_to(right)
elif node.op_tok.type == TT_MINUS:
result, error = left.subbed_by(right)
elif node.op_tok.type == TT_MUL:
result, error = left.multed_by(right)
elif node.op_tok.type == TT_DIV:
result, error = left.dived_by(right)
elif node.op_tok.type == TT_POW:
result, error = left.powed_by(right)
####################################
elif node.op_tok.type == TT_EE:
result, error = left.get_comparison_eq(right)
elif node.op_tok.type == TT_NE:
result, error = left.get_comparison_ne(right)
elif node.op_tok.type == TT_LT:
result, error = left.get_comparison_lt(right)
elif node.op_tok.type == TT_GT:
result, error = left.get_comparison_gt(right)
elif node.op_tok.type == TT_LTE:
result, error = left.get_comparison_lte(right)
elif node.op_tok.type == TT_GTE:
result, error = left.get_comparison_gte(right)
####################################
elif node.op_tok.matches(TT_KEYWORD, 'and'):
result, error = left.anded_by(right)
elif node.op_tok.matches(TT_KEYWORD, 'or'):
result, error = left.ored_by(right)
if error:
return res.failure(error)
else:
return res.success(result.set_pos(node.pos_start, node.pos_end))
def visit_UnaryOpNode(self, node, context):
res = RTResult()
number = res.register(self.visit(node.node, context))
if res.error: return res
error = None
if node.op_tok.type == TT_MINUS:
number, error = number.multed_by(Number(-1))
elif node.op_tok.matches(TT_KEYWORD, 'not'):
number, error = number.notted()
if error:
return res.failure(error)
else:
return res.success(number.set_pos(node.pos_start, node.pos_end))
def visit_IfNode(self, node, context):
res = RTResult()
for condition, expr in node.cases:
condition_value = res.register(self.visit(condition, context))
if res.error: return res
if condition_value.is_true():
expr_value = res.register(self.visit(expr, context))
if res.error: return res
return res.success(expr_value)
if node.else_case:
else_value = res.register(self.visit(node.else_case, context))
if res.error: return res
return res.success(else_value)
return res.success(None)
def visit_ForNode(self, node, context):
res = RTResult()
start_value = res.register(self.visit(node.start_value_node, context))
if res.error: return res
end_value = res.register(self.visit(node.end_value_node, context))
if res.error: return res
if node.step_value_node:
step_value = res.register(self.visit(node.step_value_node, context))
if res.error: return res
else:
step_value = Number(1)
i = start_value.value
if step_value.value > 0:
condition = lambda: i < end_value.value
else:
condition = lambda: i > end_value.value
while condition():
context.symbol_table.set(node.var_name_tok.value, Number(i))
i += step_value.value
res.register(self.visit(node.body_node, context))
if res.error: return res
return res.success(None)
def visit_WhileNode(self, node, context):
    """Re-evaluate the condition and run the body until the condition is false.

    Returns an RTResult holding ``None`` on completion, or the first error
    raised by the condition or the body.
    """
    outcome = RTResult()
    keep_going = True

    while keep_going:
        test_value = outcome.register(self.visit(node.condition_node, context))
        if outcome.error:
            return outcome

        keep_going = test_value.is_true()
        if keep_going:
            outcome.register(self.visit(node.body_node, context))
            if outcome.error:
                return outcome

    return outcome.success(None)
def visit_FuncDefNode(self, node, context):
    """Build a Function value from a definition node.

    Named functions are also stored in the current symbol table under their
    name; anonymous ones are only returned.
    """
    outcome = RTResult()

    name_tok = node.var_name_tok
    func_name = name_tok.value if name_tok else None

    params = []
    for param_tok in node.arg_name_toks:
        params.append(param_tok.value)

    function = Function(func_name, node.body_node, params)
    function.set_context(context)
    function.set_pos(node.pos_start, node.pos_end)

    if name_tok:
        context.symbol_table.set(func_name, function)
    return outcome.success(function)
def visit_CallNode(self, node, context):
    """Evaluate the callee and its arguments, then execute the call."""
    outcome = RTResult()

    callee = outcome.register(self.visit(node.node_to_call, context))
    if outcome.error:
        return outcome
    # A copy of the callee is positioned at the call site.
    callee = callee.copy().set_pos(node.pos_start, node.pos_end)

    arg_values = []
    for arg_node in node.arg_nodes:
        arg_values.append(outcome.register(self.visit(arg_node, context)))
        if outcome.error:
            return outcome

    returned = outcome.register(callee.execute(arg_values))
    if outcome.error:
        return outcome
    return outcome.success(returned)
#############
# Run
#############
# Symbol table shared by every call to run(); pre-populated with built-in names.
global_symbol_table = SymbolTable()
# NOTE(review): 'NaN' is bound to the number 0 — confirm this is intended.
global_symbol_table.set('NaN', Number(0))
# Booleans are represented as the numbers 1 and 0.
global_symbol_table.set('True', Number(1))
global_symbol_table.set('False', Number(0))
def run(fn, text):
    """Lex, parse and interpret *text*, returning a ``(value, error)`` pair.

    ``fn`` is the file name handed to the lexer. A lexing or parsing error
    stops the pipeline early and is returned as ``(None, error)``.
    """
    # Lexing
    tokens, lex_error = Lexer(fn, text).make_tokens()
    if lex_error:
        return None, lex_error

    # Parsing
    parsed = Parser(tokens).parse()
    if parsed.error:
        return None, parsed.error

    # Evaluation against the shared global symbol table
    interpreter = Interpreter()
    program_context = Context('<program>')
    program_context.symbol_table = global_symbol_table
    outcome = interpreter.visit(parsed.node, program_context)
    return outcome.value, outcome.error
#############
# Imports
#############
from string_with_arrows import *
import string
#############
# Constants
#############
# Characters that start a numeric literal in the lexer.
DIGITS = '0123456789'
# Characters that may start an identifier or keyword.
LETTERS = string.ascii_letters
# Characters accepted (together with '_') while reading an identifier.
LETTERS_DIGITS = LETTERS + DIGITS
#############
# Errors
#############
class Error:
    """Base class for reported errors: a name, a detail message and a source span."""

    def __init__(self, pos_start, pos_end, error_name, details):
        self.pos_start, self.pos_end = pos_start, pos_end
        self.error_name, self.details = error_name, details

    def as_string(self):
        """Return the headline, the file/line location and the arrow-annotated source."""
        start = self.pos_start
        sections = [
            f'{self.error_name}: {self.details}',
            f'File {start.fn}, line {start.ln + 1}',
            string_with_arrows(start.ftxt, start, self.pos_end),
        ]
        return '\n'.join(sections)
class IllegalCharError(Error):
    """Error named 'Illegal Character'."""

    def __init__(self, pos_start, pos_end, details):
        super().__init__(
            pos_start, pos_end,
            error_name='Illegal Character',
            details=details,
        )
class ExpectedCharError(Error):
    """Error named 'Expected Character'."""

    def __init__(self, pos_start, pos_end, details):
        super().__init__(
            pos_start, pos_end,
            error_name='Expected Character',
            details=details,
        )
class InvalidSyntaxError(Error):
    """Error named 'Invalid Syntax'; the detail message is optional."""

    def __init__(self, pos_start, pos_end, details=''):
        super().__init__(
            pos_start, pos_end,
            error_name='Invalid Syntax',
            details=details,
        )
class RTError(Error):
    """Error named 'Runtime Error' that also carries the context it occurred in."""

    def __init__(self, pos_start, pos_end, details, context):
        super().__init__(
            pos_start, pos_end,
            error_name='Runtime Error',
            details=details,
        )
        self.context = context

    def as_string(self):
        """Return the traceback, the headline and the arrow-annotated source."""
        pieces = [
            self.generate_traceback(),
            f'{self.error_name}: {self.details}',
            '\n\n',
            string_with_arrows(self.pos_start.ftxt, self.pos_start, self.pos_end),
        ]
        return ''.join(pieces)

    def generate_traceback(self):
        """Walk the context chain outwards and list it outermost-first."""
        frames = []
        location = self.pos_start
        frame_ctx = self.context
        while frame_ctx:
            frames.append(
                f' File {location.fn}, line {location.ln + 1}, in {frame_ctx.display_name}\n'
            )
            location = frame_ctx.parent_entry_pos
            frame_ctx = frame_ctx.parent
        # Frames were collected innermost-first; the header goes on top.
        frames.append('Traceback (most recent call last):\n')
        return ''.join(reversed(frames))
#############
# Position
#############
class Position:
    """Mutable cursor: character index, line, column, file name and file text."""

    def __init__(self, idx, ln, col, fn, ftxt):
        self.idx, self.ln, self.col = idx, ln, col
        self.fn, self.ftxt = fn, ftxt

    def advance(self, current_char=None):
        """Move one character forward and return ``self``.

        When ``current_char`` is a newline the line counter is incremented
        and the column restarts at 0.
        """
        self.idx += 1
        if current_char == '\n':
            self.ln += 1
            self.col = 0
        else:
            self.col += 1
        return self

    def copy(self):
        """Return an independent Position with the same field values."""
        return Position(
            idx=self.idx, ln=self.ln, col=self.col,
            fn=self.fn, ftxt=self.ftxt,
        )
#############
# Tokens
#############
# Token type tags assigned by the lexer.
TT_INT = 'INT'
TT_FLOAT = 'FLOAT'
TT_IDENTIFIER = 'IDENTIFIER'
TT_KEYWORD = 'KEYWORD'
TT_PLUS = 'PLUS'
TT_MINUS = 'MINUS'
TT_MUL = 'MUL'
TT_DIV = 'DIV'
TT_POW = 'POW'
# '=' (assignment)
TT_EQ = 'EQ'
TT_LPAREN = 'LPAREN'
TT_RPAREN = 'RPAREN'
# '==' and '!='
TT_EE = 'EE'
TT_NE = 'NE'
# '<', '>', '<=' and '>='
TT_LT = 'LT'
TT_GT = 'GT'
TT_LTE = 'LTE'
TT_GTE = 'GTE'
TT_COMMA = 'COMMA'
# '=>'
TT_ARROW = 'ARROW'
# End-of-input marker appended after the last token.
TT_EOF = 'EOF'
# Identifiers found in this list are lexed as TT_KEYWORD instead of TT_IDENTIFIER.
KEYWORDS = [
'set',
'and',
'or',
'not',
'if',
'then',
'elif',
'else',
'for',
'to',
'step',
'while',
'func'
]
class Token:
    """A lexical token: a type tag, an optional value and its source span.

    ``pos_start`` and ``pos_end`` are position objects exposing ``copy()``
    and ``advance()``. When only ``pos_start`` is given the token spans a
    single character. Both attributes are ``None`` when no position is given.
    """

    def __init__(self, type_, value=None, pos_start=None, pos_end=None):
        self.type = type_
        self.value = value
        # Always define both attributes so position-less tokens can be inspected.
        self.pos_start = None
        self.pos_end = None

        if pos_start is not None:
            self.pos_start = pos_start.copy()
            self.pos_end = pos_start.copy()
            self.pos_end.advance()

        if pos_end is not None:
            # Copy so later mutation of the caller's position (for example a
            # cursor that keeps advancing) cannot move this token's end.
            self.pos_end = pos_end.copy()

    def matches(self, type_, value):
        """Return True when both the type tag and the value are equal."""
        return self.type == type_ and self.value == value

    def __repr__(self):
        # Compare against None so falsy values such as 0 are still shown.
        if self.value is not None:
            return f'{self.type}:{self.value}'
        return f'{self.type}'
#############
# Lexer
#############
class Lexer:
    """Turns source text into a list of Token objects."""

    def __init__(self, fn, text):
        self.fn = fn
        self.text = text
        # Start one step before the text; the first advance() lands on index 0.
        self.pos = Position(-1, 0, -1, fn, text)
        self.current_char = None
        self.advance()

    def advance(self):
        """Move the cursor one character forward; ``current_char`` is None at the end."""
        self.pos.advance(self.current_char)
        self.current_char = self.text[self.pos.idx] if self.pos.idx < len(self.text) else None

    def make_tokens(self):
        """Tokenize the whole text.

        Returns ``(tokens, None)`` with a trailing EOF token on success, or
        ``([], error)`` at the first character that cannot be tokenized.
        """
        single_char_types = {
            '+': TT_PLUS,
            '-': TT_MINUS,
            '*': TT_MUL,
            '/': TT_DIV,
            '^': TT_POW,
            '(': TT_LPAREN,
            ')': TT_RPAREN,
            ',': TT_COMMA,
        }
        tokens = []

        while self.current_char is not None:
            char = self.current_char
            if char in ' \t':
                self.advance()
            elif char in DIGITS:
                tokens.append(self.make_number())
            elif char in LETTERS:
                tokens.append(self.make_identifier())
            elif char in single_char_types:
                tokens.append(Token(single_char_types[char], pos_start=self.pos))
                self.advance()
            elif char == '!':
                tok, error = self.make_not_equals()
                if error:
                    return [], error
                tokens.append(tok)
            elif char == '=':
                tokens.append(self.make_equals_or_arrow())
            elif char == '<':
                tokens.append(self.make_less_than())
            elif char == '>':
                tokens.append(self.make_greater_than())
            else:
                pos_start = self.pos.copy()
                self.advance()
                return [], IllegalCharError(pos_start, self.pos, "'" + char + "'")

        tokens.append(Token(TT_EOF, pos_start=self.pos))
        return tokens, None

    def make_number(self):
        """Read an INT, or a FLOAT when the literal contains one '.'."""
        num_str = ''
        dot_count = 0
        pos_start = self.pos.copy()
        number_chars = DIGITS + '.'

        while self.current_char is not None and self.current_char in number_chars:
            if self.current_char == '.':
                # A second dot ends the literal.
                if dot_count == 1:
                    break
                dot_count += 1
            num_str += self.current_char
            self.advance()

        # The end position is copied because self.pos keeps moving afterwards.
        if dot_count == 0:
            return Token(TT_INT, int(num_str), pos_start, self.pos.copy())
        return Token(TT_FLOAT, float(num_str), pos_start, self.pos.copy())

    def make_identifier(self):
        """Read letters, digits and underscores; yield a KEYWORD or IDENTIFIER token."""
        id_str = ''
        pos_start = self.pos.copy()
        identifier_chars = LETTERS_DIGITS + '_'

        while self.current_char is not None and self.current_char in identifier_chars:
            id_str += self.current_char
            self.advance()

        tok_type = TT_KEYWORD if id_str in KEYWORDS else TT_IDENTIFIER
        return Token(tok_type, id_str, pos_start, self.pos.copy())

    def make_not_equals(self):
        """Read '!='; return ``(token, None)`` or ``(None, ExpectedCharError)``."""
        pos_start = self.pos.copy()
        self.advance()

        if self.current_char == '=':
            self.advance()
            return Token(TT_NE, pos_start=pos_start, pos_end=self.pos.copy()), None

        # Remember the offending character before stepping past it, so the
        # message names the character that actually followed '!'.
        found = self.current_char
        self.advance()
        got = 'end of input' if found is None else f"'{found}'"
        return None, ExpectedCharError(
            pos_start, self.pos,
            f"'=' (after '!') expected but got {got}"
        )

    def make_equals_or_arrow(self):
        """Read '=', '==' or '=>'."""
        tok_type = TT_EQ
        pos_start = self.pos.copy()
        self.advance()

        if self.current_char == '=':
            self.advance()
            tok_type = TT_EE
        elif self.current_char == '>':
            self.advance()
            tok_type = TT_ARROW

        return Token(tok_type, pos_start=pos_start, pos_end=self.pos.copy())

    def make_less_than(self):
        """Read '<' or '<='."""
        tok_type = TT_LT
        pos_start = self.pos.copy()
        self.advance()

        if self.current_char == '=':
            self.advance()
            tok_type = TT_LTE

        return Token(tok_type, pos_start=pos_start, pos_end=self.pos.copy())

    def make_greater_than(self):
        """Read '>' or '>='."""
        tok_type = TT_GT
        pos_start = self.pos.copy()
        self.advance()

        if self.current_char == '=':
            self.advance()
            tok_type = TT_GTE

        return Token(tok_type, pos_start=pos_start, pos_end=self.pos.copy())
#############
# Nodes
#############
class NumberNode:
    """AST node wrapping a single token; its span is the token's span."""

    def __init__(self, tok):
        self.tok = tok
        self.pos_start, self.pos_end = tok.pos_start, tok.pos_end

    def __repr__(self):
        return '{}'.format(self.tok)
class VarAccessNode:
    """AST node for reading a variable; spans the variable-name token."""

    def __init__(self, var_name_tok):
        self.var_name_tok = var_name_tok
        self.pos_start, self.pos_end = var_name_tok.pos_start, var_name_tok.pos_end
class VarAssignNode:
    """AST node for an assignment; spans from the name token to the end of the value."""

    def __init__(self, var_name_tok, value_node):
        self.var_name_tok, self.value_node = var_name_tok, value_node
        self.pos_start = var_name_tok.pos_start
        self.pos_end = value_node.pos_end
class BinOpNode:
    """AST node for a binary operation; spans from the left operand to the right one."""

    def __init__(self, left_node, op_tok, right_node):
        self.left_node, self.op_tok, self.right_node = left_node, op_tok, right_node
        self.pos_start = left_node.pos_start
        self.pos_end = right_node.pos_end

    def __repr__(self):
        return '({}, {}, {})'.format(self.left_node, self.op_tok, self.right_node)
class UnaryOpNode:
    """AST node for a unary operation; spans from the operator token to the operand's end."""

    def __init__(self, op_tok, node):
        self.op_tok, self.node = op_tok, node
        self.pos_start = op_tok.pos_start
        self.pos_end = node.pos_end

    def __repr__(self):
        return '({}, {})'.format(self.op_tok, self.node)
class IfNode:
    """AST node for an if/elif/else expression.

    ``cases`` is a non-empty list of ``(condition_node, expr_node)`` pairs
    and ``else_case`` is the optional fallback node.
    """

    def __init__(self, cases, else_case):
        self.cases = cases
        self.else_case = else_case
        self.pos_start = self.cases[0][0].pos_start
        # The expression ends with the else branch when there is one,
        # otherwise with the body (not the condition) of the last case.
        last_body = self.cases[-1][1]
        self.pos_end = (self.else_case or last_body).pos_end
class ForNode:
    """AST node for a counted loop; spans from the loop variable to the end of the body."""

    def __init__(self, var_name_tok, start_value_node, end_value_node, step_value_node, body_node):
        self.var_name_tok = var_name_tok
        self.start_value_node, self.end_value_node = start_value_node, end_value_node
        self.step_value_node = step_value_node
        self.body_node = body_node

        self.pos_start = var_name_tok.pos_start
        self.pos_end = body_node.pos_end
class WhileNode:
    """AST node for a while loop; spans from the condition to the end of the body."""

    def __init__(self, condition_node, body_node):
        self.condition_node, self.body_node = condition_node, body_node
        self.pos_start, self.pos_end = condition_node.pos_start, body_node.pos_end
class FuncDefNode:
    """AST node for a function definition.

    The span starts at the name token when there is one, else at the first
    argument token, else at the body; it always ends with the body.
    """

    def __init__(self, var_name_tok, arg_name_toks, body_node):
        self.var_name_tok = var_name_tok
        self.arg_name_toks = arg_name_toks
        self.body_node = body_node

        if var_name_tok:
            first = var_name_tok
        elif len(arg_name_toks) > 0:
            first = arg_name_toks[0]
        else:
            first = body_node
        self.pos_start = first.pos_start
        self.pos_end = body_node.pos_end
class CallNode:
    """AST node for a call; ends at the last argument, or at the callee when there are none."""

    def __init__(self, node_to_call, arg_nodes):
        self.node_to_call, self.arg_nodes = node_to_call, arg_nodes
        self.pos_start = node_to_call.pos_start

        last_node = arg_nodes[-1] if len(arg_nodes) > 0 else node_to_call
        self.pos_end = last_node.pos_end
##############
# Parse result
##############
class ParseResult:
    """Accumulates the node, error and token-advance counts of one parse step."""

    def __init__(self):
        self.error = None
        self.node = None
        self.last_registered_advance_count = 0
        self.advance_count = 0

    def register_advancement(self):
        """Note that a single token was consumed."""
        self.advance_count += 1
        self.last_registered_advance_count = 1

    def register(self, res):
        """Fold a nested result into this one and return the nested node."""
        consumed = res.advance_count
        self.last_registered_advance_count = consumed
        self.advance_count += consumed
        if res.error:
            self.error = res.error
        return res.node

    def success(self, node):
        """Store *node* and return ``self``."""
        self.node = node
        return self

    def failure(self, error):
        """Store *error* and return ``self``.

        An already recorded error is kept when the last registration
        consumed at least one token.
        """
        if self.error and self.last_registered_advance_count != 0:
            return self
        self.error = error
        return self
#############
# Parser
#############
class Parser:
    """Recursive-descent parser that turns a token list into an AST.

    Each grammar method returns a ParseResult carrying either the parsed
    node or an InvalidSyntaxError positioned at the offending token.
    """

    def __init__(self, tokens):
        self.tokens = tokens
        self.tok_idx = -1
        self.advance()

    def advance(self):
        """Step to the next token and return the current one.

        Past the end of the list the current token is left unchanged.
        """
        self.tok_idx += 1
        if self.tok_idx < len(self.tokens):
            self.current_tok = self.tokens[self.tok_idx]
        return self.current_tok

    def parse(self):
        """Parse one expression and require that it is followed by EOF."""
        res = self.expr()
        if not res.error and self.current_tok.type != TT_EOF:
            return self._fail(
                res,
                "Expected '+', '-', '*', '/', '^', '==', '!=', '<', '>', '<=', '>=', 'and' or 'or'"
            )
        return res

    ###################################

    def _consume(self, res):
        """Record one advancement on *res* and move to the next token."""
        res.register_advancement()
        self.advance()

    def _fail(self, res, details):
        """Fail *res* with an InvalidSyntaxError spanning the current token."""
        return res.failure(InvalidSyntaxError(
            self.current_tok.pos_start, self.current_tok.pos_end,
            details
        ))

    def _then_clause(self, res):
        """Parse ``<expr> then <expr>``.

        Returns the ``(condition, body)`` pair, or None after recording the
        error on *res*.
        """
        condition = res.register(self.expr())
        if res.error:
            return None

        if not self.current_tok.matches(TT_KEYWORD, 'then'):
            self._fail(res, "Expected 'then'")
            return None
        self._consume(res)

        body = res.register(self.expr())
        if res.error:
            return None
        return condition, body

    ###################################

    def if_expr(self):
        """Parse ``if <expr> then <expr> {elif <expr> then <expr>} [else <expr>]``."""
        res = ParseResult()

        if not self.current_tok.matches(TT_KEYWORD, 'if'):
            return self._fail(res, "Expected 'if'")
        self._consume(res)

        first_case = self._then_clause(res)
        if first_case is None:
            return res
        cases = [first_case]

        while self.current_tok.matches(TT_KEYWORD, 'elif'):
            self._consume(res)
            next_case = self._then_clause(res)
            if next_case is None:
                return res
            cases.append(next_case)

        else_case = None
        if self.current_tok.matches(TT_KEYWORD, 'else'):
            self._consume(res)
            else_case = res.register(self.expr())
            if res.error:
                return res

        return res.success(IfNode(cases, else_case))

    def for_expr(self):
        """Parse ``for <identifier> = <expr> to <expr> [step <expr>] then <expr>``."""
        res = ParseResult()

        if not self.current_tok.matches(TT_KEYWORD, 'for'):
            return self._fail(res, "Expected 'for'")
        self._consume(res)

        if self.current_tok.type != TT_IDENTIFIER:
            return self._fail(res, "Expected identifier")
        var_name = self.current_tok
        self._consume(res)

        if self.current_tok.type != TT_EQ:
            return self._fail(res, "Expected '='")
        self._consume(res)

        start_value = res.register(self.expr())
        if res.error:
            return res

        if not self.current_tok.matches(TT_KEYWORD, 'to'):
            return self._fail(res, "Expected 'to'")
        self._consume(res)

        end_value = res.register(self.expr())
        if res.error:
            return res

        step_value = None
        if self.current_tok.matches(TT_KEYWORD, 'step'):
            self._consume(res)
            step_value = res.register(self.expr())
            if res.error:
                return res

        if not self.current_tok.matches(TT_KEYWORD, 'then'):
            return self._fail(res, "Expected 'then'")
        self._consume(res)

        body = res.register(self.expr())
        if res.error:
            return res

        return res.success(ForNode(var_name, start_value, end_value, step_value, body))

    def while_expr(self):
        """Parse ``while <expr> then <expr>``."""
        res = ParseResult()

        if not self.current_tok.matches(TT_KEYWORD, 'while'):
            return self._fail(res, "Expected 'while'")
        self._consume(res)

        clause = self._then_clause(res)
        if clause is None:
            return res
        condition, body = clause

        return res.success(WhileNode(condition, body))

    def call(self):
        """Parse an atom optionally followed by a parenthesized argument list."""
        res = ParseResult()
        atom = res.register(self.atom())
        if res.error:
            return res

        if self.current_tok.type != TT_LPAREN:
            return res.success(atom)
        self._consume(res)

        arg_nodes = []
        if self.current_tok.type == TT_RPAREN:
            self._consume(res)
            return res.success(CallNode(atom, arg_nodes))

        arg_nodes.append(res.register(self.expr()))
        if res.error:
            return self._fail(
                res,
                "Expected ')', 'set', 'if', 'for', 'while', 'func', int, float, identifier, '+', '-', '(' or 'not'"
            )

        while self.current_tok.type == TT_COMMA:
            self._consume(res)
            arg_nodes.append(res.register(self.expr()))
            if res.error:
                return res

        # The argument list must be closed by ')'; compare the token TYPE.
        if self.current_tok.type != TT_RPAREN:
            return self._fail(res, "Expected ',' or ')'")
        self._consume(res)

        return res.success(CallNode(atom, arg_nodes))

    def atom(self):
        """Parse a number, identifier, parenthesized expression or compound expression."""
        res = ParseResult()
        tok = self.current_tok

        if tok.type in (TT_INT, TT_FLOAT):
            self._consume(res)
            return res.success(NumberNode(tok))

        if tok.type == TT_IDENTIFIER:
            self._consume(res)
            return res.success(VarAccessNode(tok))

        if tok.type == TT_LPAREN:
            self._consume(res)
            expr = res.register(self.expr())
            if res.error:
                return res
            if self.current_tok.type != TT_RPAREN:
                return self._fail(res, "Expected ')'")
            self._consume(res)
            return res.success(expr)

        # Keyword-introduced expressions are parsed by their own methods.
        compound_parsers = {
            'if': self.if_expr,
            'for': self.for_expr,
            'while': self.while_expr,
            'func': self.func_def,
        }
        if tok.type == TT_KEYWORD and tok.value in compound_parsers:
            node = res.register(compound_parsers[tok.value]())
            if res.error:
                return res
            return res.success(node)

        return self._fail(
            res,
            "Expected int, float, identifier, '+', '-', '(', 'if', 'for', 'while', 'func'"
        )

    def power(self):
        """Parse ``call {^ factor}``."""
        # The right-hand side is parsed by factor, so a signed exponent works.
        return self.bin_op(self.call, (TT_POW, ), self.factor)

    def factor(self):
        """Parse a signed factor (``+``/``-`` prefix) or a power expression."""
        res = ParseResult()
        tok = self.current_tok

        if tok.type in (TT_PLUS, TT_MINUS):
            self._consume(res)
            factor = res.register(self.factor())
            if res.error:
                return res
            return res.success(UnaryOpNode(tok, factor))

        return self.power()

    def term(self):
        """Parse ``factor {(*|/) factor}``."""
        return self.bin_op(self.factor, (TT_MUL, TT_DIV))

    def arith_expr(self):
        """Parse ``term {(+|-) term}``."""
        return self.bin_op(self.term, (TT_PLUS, TT_MINUS))

    def comp_expr(self):
        """Parse ``not comp_expr`` or a chain of comparisons."""
        res = ParseResult()

        if self.current_tok.matches(TT_KEYWORD, 'not'):
            op_tok = self.current_tok
            self._consume(res)
            node = res.register(self.comp_expr())
            if res.error:
                return res
            return res.success(UnaryOpNode(op_tok, node))

        node = res.register(self.bin_op(
            self.arith_expr, (TT_EE, TT_NE, TT_LT, TT_GT, TT_LTE, TT_GTE)
        ))
        if res.error:
            return self._fail(res, "Expected int, float, '+', '-' or '(', 'not'")
        return res.success(node)

    def expr(self):
        """Parse ``set <identifier> = <expr>`` or an and/or chain of comparisons."""
        res = ParseResult()

        if self.current_tok.matches(TT_KEYWORD, 'set'):
            self._consume(res)

            if self.current_tok.type != TT_IDENTIFIER:
                return self._fail(res, 'Expected identifier')
            var_name = self.current_tok
            self._consume(res)

            if self.current_tok.type != TT_EQ:
                return self._fail(res, "Expected '='")
            self._consume(res)

            value = res.register(self.expr())
            if res.error:
                return res
            return res.success(VarAssignNode(var_name, value))

        node = res.register(self.bin_op(self.comp_expr, (
            (TT_KEYWORD, "and"),
            (TT_KEYWORD, "or")
        )))
        if res.error:
            return self._fail(
                res,
                "Expected 'set', 'not', int, float, identifier, '+', '-' or '('"
            )
        return res.success(node)

    def func_def(self):
        """Parse ``func [<identifier>] ( [<identifier> {, <identifier>}] ) => <expr>``."""
        res = ParseResult()

        if not self.current_tok.matches(TT_KEYWORD, 'func'):
            return self._fail(res, "Expected 'func'")
        self._consume(res)

        if self.current_tok.type == TT_IDENTIFIER:
            var_name_tok = self.current_tok
            self._consume(res)
            if self.current_tok.type != TT_LPAREN:
                return self._fail(res, "Expected '('")
        else:
            var_name_tok = None
            if self.current_tok.type != TT_LPAREN:
                return self._fail(res, "Expected identifier or '('")
        self._consume(res)

        arg_name_toks = []
        if self.current_tok.type == TT_IDENTIFIER:
            arg_name_toks.append(self.current_tok)
            self._consume(res)

            while self.current_tok.type == TT_COMMA:
                self._consume(res)
                if self.current_tok.type != TT_IDENTIFIER:
                    return self._fail(res, "Expected identifier")
                arg_name_toks.append(self.current_tok)
                self._consume(res)

            if self.current_tok.type != TT_RPAREN:
                return self._fail(res, "Expected ',' or ')'")
        elif self.current_tok.type != TT_RPAREN:
            # An empty parameter list must be closed by ')'.
            return self._fail(res, "Expected identifier or ')'")
        self._consume(res)

        if self.current_tok.type != TT_ARROW:
            return self._fail(res, "Expected '=>'")
        self._consume(res)

        node_to_return = res.register(self.expr())
        if res.error:
            return res

        return res.success(FuncDefNode(var_name_tok, arg_name_toks, node_to_return))

    ###################################

    def bin_op(self, func_a, ops, func_b=None):
        """Parse ``func_a {op func_b}`` into left-nested BinOpNodes.

        ``ops`` holds token types and/or ``(type, value)`` pairs; ``func_b``
        defaults to ``func_a``.
        """
        if func_b is None:
            func_b = func_a

        res = ParseResult()
        left = res.register(func_a())
        if res.error:
            return res

        while self.current_tok.type in ops or (self.current_tok.type, self.current_tok.value) in ops:
            op_tok = self.current_tok
            self._consume(res)
            right = res.register(func_b())
            if res.error:
                return res
            left = BinOpNode(left, op_tok, right)

        return res.success(left)
################
# Runtime result
################
class RTResult:
    """Holds the value or the error produced by one interpretation step."""

    def __init__(self):
        self.value, self.error = None, None

    def register(self, res):
        """Adopt the error of a nested result (if any) and return its value."""
        nested_error = res.error
        if nested_error:
            self.error = nested_error
        return res.value

    def success(self, value):
        """Store *value* and return ``self``."""
        self.value = value
        return self

    def failure(self, error):
        """Store *error* and return ``self``."""
        self.error = error
        return self
#############
# Values
#############
class Value:
    """Base class for runtime values.

    Every operation defaults to an 'Illegal operation' RTError; subclasses
    override the operations they support. Binary operations return a
    ``(result, error)`` pair.
    """

    def __init__(self):
        self.set_pos()
        self.set_context()

    def set_pos(self, pos_start=None, pos_end=None):
        """Set the source span and return ``self``."""
        self.pos_start = pos_start
        self.pos_end = pos_end
        return self

    def set_context(self, context=None):
        """Set the owning context and return ``self``."""
        self.context = context
        return self

    def added_to(self, other):
        return None, self.illegal_operation(other)

    def subbed_by(self, other):
        return None, self.illegal_operation(other)

    def multed_by(self, other):
        return None, self.illegal_operation(other)

    def dived_by(self, other):
        return None, self.illegal_operation(other)

    def powed_by(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_eq(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_ne(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_lt(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_gt(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_lte(self, other):
        return None, self.illegal_operation(other)

    def get_comparison_gte(self, other):
        return None, self.illegal_operation(other)

    def anded_by(self, other):
        return None, self.illegal_operation(other)

    def ored_by(self, other):
        return None, self.illegal_operation(other)

    def notted(self, other=None):
        # ``other`` is optional: negation is unary, so it can be called with
        # no argument and still yield an RTError instead of a TypeError.
        return None, self.illegal_operation(other)

    def execute(self, args):
        """Calling a plain value fails with an illegal-operation error."""
        return RTResult().failure(self.illegal_operation())

    def copy(self):
        raise Exception('No copy method defined')

    def is_true(self):
        return False

    def illegal_operation(self, other=None):
        """Build the RTError spanning from this value to *other* (or itself)."""
        if other is None:
            other = self
        return RTError(
            self.pos_start, other.pos_end,
            'Illegal operation',
            self.context
        )
class Number(Value):
    """Numeric runtime value wrapping a Python int or float.

    Binary operations return ``(Number, None)`` on success or
    ``(None, RTError)`` when the other operand is not a Number or the
    operation is undefined. Comparisons and logic yield 1 or 0.
    """

    def __init__(self, value):
        super().__init__()
        self.value = value

    def _apply(self, other, operation):
        """Apply ``operation(self.value, other.value)`` after checking the operand type."""
        if not isinstance(other, Number):
            return None, Value.illegal_operation(self, other)
        return Number(operation(self.value, other.value)).set_context(self.context), None

    def added_to(self, other):
        return self._apply(other, lambda a, b: a + b)

    def subbed_by(self, other):
        return self._apply(other, lambda a, b: a - b)

    def multed_by(self, other):
        return self._apply(other, lambda a, b: a * b)

    def dived_by(self, other):
        if isinstance(other, Number) and other.value == 0:
            return None, RTError(
                other.pos_start, other.pos_end,
                'Division by zero',
                self.context
            )
        return self._apply(other, lambda a, b: a / b)

    def powed_by(self, other):
        if not isinstance(other, Number):
            return None, Value.illegal_operation(self, other)
        try:
            raised = self.value ** other.value
        except ZeroDivisionError:
            # Zero raised to a negative power is a division by zero; report
            # it as a runtime error rather than crashing the interpreter.
            return None, RTError(
                other.pos_start, other.pos_end,
                'Division by zero',
                self.context
            )
        return Number(raised).set_context(self.context), None

    def get_comparison_eq(self, other):
        return self._apply(other, lambda a, b: int(a == b))

    def get_comparison_ne(self, other):
        return self._apply(other, lambda a, b: int(a != b))

    def get_comparison_lt(self, other):
        return self._apply(other, lambda a, b: int(a < b))

    def get_comparison_gt(self, other):
        return self._apply(other, lambda a, b: int(a > b))

    def get_comparison_lte(self, other):
        return self._apply(other, lambda a, b: int(a <= b))

    def get_comparison_gte(self, other):
        return self._apply(other, lambda a, b: int(a >= b))

    def anded_by(self, other):
        return self._apply(other, lambda a, b: int(a and b))

    def ored_by(self, other):
        return self._apply(other, lambda a, b: int(a or b))

    def notted(self):
        """Return 1 for a zero value and 0 for anything else."""
        return Number(1 if self.value == 0 else 0).set_context(self.context), None

    def copy(self):
        """Return a new Number with the same value, span and context."""
        duplicate = Number(self.value)
        duplicate.set_pos(self.pos_start, self.pos_end)
        duplicate.set_context(self.context)
        return duplicate

    def is_true(self):
        return self.value != 0

    def __repr__(self):
        return str(self.value)
class Function(Value):
    """Callable runtime value: a name, a body node and its parameter names."""

    def __init__(self, name, body_node, arg_names):
        super().__init__()
        self.name = name or "<anonymous>"
        self.body_node = body_node
        self.arg_names = arg_names

    def execute(self, args):
        """Bind *args* to the parameter names in a new context and evaluate the body.

        Returns an RTResult with the body's value, or an RTError when the
        argument count does not match the parameter count.
        """
        outcome = RTResult()
        interpreter = Interpreter()
        call_context = Context(self.name, self.context, self.pos_start)
        call_context.symbol_table = SymbolTable(call_context.parent.symbol_table)

        expected, received = len(self.arg_names), len(args)
        if received > expected:
            return outcome.failure(RTError(
                self.pos_start, self.pos_end,
                f"{received - expected} too many args passed into '{self.name}'",
                self.context
            ))
        if received < expected:
            return outcome.failure(RTError(
                self.pos_start, self.pos_end,
                f"{expected - received} too few args passed into '{self.name}'",
                self.context
            ))

        # Counts are equal here, so zip pairs every parameter with its argument.
        for param_name, arg_value in zip(self.arg_names, args):
            arg_value.set_context(call_context)
            call_context.symbol_table.set(param_name, arg_value)

        returned = outcome.register(interpreter.visit(self.body_node, call_context))
        if outcome.error:
            return outcome
        return outcome.success(returned)

    def copy(self):
        """Return a new Function with the same name, body, parameters, context and span."""
        duplicate = Function(self.name, self.body_node, self.arg_names)
        duplicate.set_context(self.context)
        duplicate.set_pos(self.pos_start, self.pos_end)
        return duplicate

    def __repr__(self):
        return "<function {}>".format(self.name)
#############
# Context
#############
class Context:
    """Execution scope: display name, parent scope, entry position and symbol table."""

    def __init__(self, display_name, parent=None, parent_entry_pos=None):
        self.display_name = display_name
        self.parent, self.parent_entry_pos = parent, parent_entry_pos
        # Assigned by whoever creates the context.
        self.symbol_table = None
#############
# Symbol table
#############
class SymbolTable:
    """Name-to-value mapping with an optional parent table for outer scopes."""

    def __init__(self, parent=None):
        self.symbols = {}
        # Keep the parent so lookups can fall back to enclosing scopes.
        self.parent = parent

    def get(self, name):
        """Return the value bound to *name* here or in a parent table, else None."""
        value = self.symbols.get(name, None)
        if value is None and self.parent:
            return self.parent.get(name)
        return value

    def set(self, name, value):
        """Bind *name* to *value* in this table."""
        self.symbols[name] = value

    def remove(self, name):
        """Delete *name* from this table; raises KeyError when it is not bound here."""
        del self.symbols[name]
#############
# Interpreter
#############
class Interpreter:
def visit(self, node, context):
    """Dispatch *node* to ``visit_<NodeClassName>``, or to ``no_visit_method`` when absent."""
    handler = getattr(self, 'visit_' + type(node).__name__, self.no_visit_method)
    return handler(node, context)
def no_visit_method(self, node, context):
    """Fallback for node types that have no visit method: always raises."""
    node_kind = type(node).__name__
    raise Exception(f'No visit_{node_kind} method defined')
###################################
def visit_NumberNode(self, node, context):
    """Wrap the literal token's value in a Number carrying the node's span and context."""
    outcome = RTResult()
    number = Number(node.tok.value)
    number.set_context(context)
    number.set_pos(node.pos_start, node.pos_end)
    return outcome.success(number)
def visit_VarAccessNode(self, node, context):
    """Look the variable up and return a copy positioned at the access site.

    Fails with an RTError when the lookup yields nothing.
    """
    outcome = RTResult()
    name = node.var_name_tok.value
    stored = context.symbol_table.get(name)

    if stored:
        positioned = stored.copy().set_pos(node.pos_start, node.pos_end)
        return outcome.success(positioned)

    return outcome.failure(RTError(
        node.pos_start, node.pos_end,
        f"{name} is not defined",
        context
    ))
def visit_VarAssignNode(self, node, context):
    """Evaluate the value node, bind it to the variable name and return it."""
    outcome = RTResult()
    target = node.var_name_tok.value

    assigned = outcome.register(self.visit(node.value_node, context))
    if outcome.error:
        return outcome

    context.symbol_table.set(target, assigned)
    return outcome.success(assigned)
def visit_BinOpNode(self, node, context):
    """Evaluate both operands, then apply the operator named by ``node.op_tok``.

    Returns an RTResult with the re-positioned result, the operation's
    error, or an 'Illegal operation' RTError for an unsupported operator.
    """
    res = RTResult()
    left = res.register(self.visit(node.left_node, context))
    if res.error:
        return res
    right = res.register(self.visit(node.right_node, context))
    if res.error:
        return res

    # Operator token type -> name of the method called on the left operand.
    methods_by_type = {
        TT_PLUS: 'added_to',
        TT_MINUS: 'subbed_by',
        TT_MUL: 'multed_by',
        TT_DIV: 'dived_by',
        TT_POW: 'powed_by',
        TT_EE: 'get_comparison_eq',
        TT_NE: 'get_comparison_ne',
        TT_LT: 'get_comparison_lt',
        TT_GT: 'get_comparison_gt',
        TT_LTE: 'get_comparison_lte',
        TT_GTE: 'get_comparison_gte',
    }
    methods_by_keyword = {
        'and': 'anded_by',
        'or': 'ored_by',
    }

    op_tok = node.op_tok
    method_name = methods_by_type.get(op_tok.type)
    if method_name is None and op_tok.type == TT_KEYWORD:
        method_name = methods_by_keyword.get(op_tok.value)

    if method_name is None:
        # Report an unsupported operator instead of falling through with
        # unbound result/error variables.
        return res.failure(RTError(
            op_tok.pos_start, op_tok.pos_end,
            'Illegal operation',
            context
        ))

    result, error = getattr(left, method_name)(right)
    if error:
        return res.failure(error)
    return res.success(result.set_pos(node.pos_start, node.pos_end))
def visit_UnaryOpNode(self, node, context):
    """Evaluate a unary operation on the visited operand.

    A MINUS operator multiplies the operand by -1 and the ``not`` keyword
    calls ``notted()``; any other operator token leaves the operand as is.
    Returns an RTResult with the re-positioned value or the error.
    """
    outcome = RTResult()
    operand = outcome.register(self.visit(node.node, context))
    if outcome.error:
        return outcome

    failure = None
    operator = node.op_tok
    if operator.type == TT_MINUS:
        # Negation is expressed as multiplication by -1.
        operand, failure = operand.multed_by(Number(-1))
    elif operator.matches(TT_KEYWORD, 'not'):
        operand, failure = operand.notted()

    if failure:
        return outcome.failure(failure)
    return outcome.success(operand.set_pos(node.pos_start, node.pos_end))
def visit_IfNode(self, node, context):
    """Evaluate an if / else-if / else node.

    Each (condition, expression) pair in node.cases is tried in order; the
    expression of the first condition whose value is_true() is evaluated
    and returned. If none match, node.else_case is evaluated when present.

    Returns an RTResult carrying the chosen branch's value, None when no
    branch ran, or the first error encountered.
    """
    outcome = RTResult()

    for test_node, branch_node in node.cases:
        test_value = outcome.register(self.visit(test_node, context))
        if outcome.error:
            return outcome
        if not test_value.is_true():
            continue

        branch_value = outcome.register(self.visit(branch_node, context))
        if outcome.error:
            return outcome
        return outcome.success(branch_value)

    if not node.else_case:
        # No case matched and there is no else branch.
        return outcome.success(None)

    fallback_value = outcome.register(self.visit(node.else_case, context))
    if outcome.error:
        return outcome
    return outcome.success(fallback_value)
def visit_ForNode(self, node, context):
    """Evaluate a counted for-loop node.

    The start, end and (optional) step expressions are each evaluated once,
    in that order; the step defaults to Number(1). With a positive step the
    loop runs while the counter is below the end value, with a negative
    step while it is above it. On every iteration the loop variable is
    stored in the context's symbol table as a fresh Number before the body
    is visited.

    Returns an RTResult carrying None on normal completion, or the first
    error from the bound expressions or the body. A step of zero is
    reported as a runtime error.
    """
    res = RTResult()

    start_value = res.register(self.visit(node.start_value_node, context))
    if res.error: return res

    end_value = res.register(self.visit(node.end_value_node, context))
    if res.error: return res

    if node.step_value_node:
        step_value = res.register(self.visit(node.step_value_node, context))
        if res.error: return res
        if step_value.value == 0:
            # A zero step never moves the counter. The original treated it
            # as a descending loop, which spun forever when start > end
            # and silently did nothing otherwise; reject it up front.
            return res.failure(RTError(
                node.step_value_node.pos_start, node.step_value_node.pos_end,
                'Step value of a for loop cannot be zero',
                context
            ))
    else:
        step_value = Number(1)

    i = start_value.value
    ascending = step_value.value > 0

    while (i < end_value.value) if ascending else (i > end_value.value):
        context.symbol_table.set(node.var_name_tok.value, Number(i))
        # The counter is advanced before the body runs; the loop variable
        # visible to the body was already stored above.
        i += step_value.value

        res.register(self.visit(node.body_node, context))
        if res.error: return res

    return res.success(None)
def visit_WhileNode(self, node, context):
    """Evaluate a while-loop node.

    The condition is re-evaluated before every iteration; the body is
    visited for as long as the condition's value is_true().

    Returns an RTResult carrying None once the condition is false, or the
    first error raised by the condition or the body.
    """
    outcome = RTResult()

    while True:
        test_value = outcome.register(self.visit(node.condition_node, context))
        if outcome.error:
            return outcome

        if not test_value.is_true():
            # Condition no longer holds: the loop itself yields no value.
            return outcome.success(None)

        outcome.register(self.visit(node.body_node, context))
        if outcome.error:
            return outcome
def visit_FuncDefNode(self, node, context):
    """Evaluate a function definition node.

    Builds a Function from the node's name (None when the definition has
    no name token), body and argument names, attaches the defining context
    and the node's position, and, for named functions, stores it in the
    context's symbol table under that name.

    Returns a successful RTResult carrying the Function value.
    """
    outcome = RTResult()

    name_tok = node.var_name_tok
    if name_tok:
        name = name_tok.value
    else:
        name = None

    param_names = []
    for param_tok in node.arg_name_toks:
        param_names.append(param_tok.value)

    function = Function(name, node.body_node, param_names)
    function = function.set_context(context)
    function = function.set_pos(node.pos_start, node.pos_end)

    # Only named definitions are bound in the symbol table.
    if name_tok:
        context.symbol_table.set(name, function)

    return outcome.success(function)
def visit_CallNode(self, node, context):
    """Evaluate a call node.

    The callee expression is evaluated first, then copied and positioned at
    the call site. The argument expressions are evaluated in order, and the
    copy is executed with the resulting argument values.

    Returns an RTResult carrying the value produced by execute(), or the
    first error from the callee, an argument or the call itself.
    """
    outcome = RTResult()

    callee = outcome.register(self.visit(node.node_to_call, context))
    if outcome.error:
        return outcome
    # Work on a copy so the call-site position is not written onto the
    # value returned by visiting the callee expression.
    callee = callee.copy().set_pos(node.pos_start, node.pos_end)

    arg_values = []
    for arg_node in node.arg_nodes:
        arg_value = outcome.register(self.visit(arg_node, context))
        if outcome.error:
            return outcome
        arg_values.append(arg_value)

    returned = outcome.register(callee.execute(arg_values))
    if outcome.error:
        return outcome
    return outcome.success(returned)
#############
# Run
#############
# Symbol table installed on the '<program>' context by run(), pre-populated
# with the built-in constants below.
# NOTE(review): 'NaN' is bound to the number 0, same as 'False' - confirm
# that this is the intended value.
global_symbol_table = SymbolTable()
for _builtin_name, _builtin_number in (('NaN', 0), ('True', 1), ('False', 0)):
    global_symbol_table.set(_builtin_name, Number(_builtin_number))
# Do not leave the loop variables behind in the module namespace.
del _builtin_name, _builtin_number
def run(fn, text):
    """Lex, parse and interpret a piece of source text.

    fn is passed to the Lexer together with text. The program is evaluated
    in a fresh '<program>' Context whose symbol table is the module-level
    global_symbol_table.

    Returns a (value, error) pair: (None, error) if lexing or parsing
    failed, otherwise the value and error of the interpreter's result.
    """
    # Stage 1: tokenise.
    tokens, lex_error = Lexer(fn, text).make_tokens()
    if lex_error:
        return None, lex_error

    # Stage 2: build the syntax tree.
    parse_result = Parser(tokens).parse()
    if parse_result.error:
        return None, parse_result.error

    # Stage 3: evaluate the tree against the shared global symbols.
    evaluator = Interpreter()
    program_context = Context('<program>')
    program_context.symbol_table = global_symbol_table
    outcome = evaluator.visit(parse_result.node, program_context)

    return outcome.value, outcome.error
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment