Skip to content

Instantly share code, notes, and snippets.

@arthuro555
Created November 7, 2021 01:25
Show Gist options
  • Save arthuro555/3db35d1a7a98319a9bcaf6083ab49d0a to your computer and use it in GitHub Desktop.
Save arthuro555/3db35d1a7a98319a9bcaf6083ab49d0a to your computer and use it in GitHub Desktop.
Simple calculator in Python
from dataclasses import dataclass
from typing import Callable
@dataclass
class Position:
    """
    A span of the source string, described by two character indices.

    Tracking positions is not required to evaluate an expression, but it is
    cheap to carry around in an AST and useful when reporting errors.
    """

    start: int
    end: int

    def __repr__(self) -> str:
        # Compact "(start;end)" form instead of the dataclass default.
        return f"({self.start};{self.end})"
# AST Nodes
@dataclass(init=False)
class BaseNode:
    """
    Common ancestor of every AST node.

    Gives all nodes one shared type to check against, plus a ``position``
    attribute that stays ``None`` until a node assigns it.
    """

    position: Position = None
@dataclass(init=False)
class NumberLiteralNode(BaseNode):
    """
    Holds a number literal.

    Examples: 1, 69, 420, 2021...
    """

    number: int

    def __init__(self, number: int, position: Position):
        """
        Builds a literal node.

        number: the parsed integer value.
        position: where the literal sits in the source string.
        """
        # A bare `super()` only creates a proxy object and discards it;
        # the base initialiser has to be called explicitly.
        super().__init__()
        self.number = number
        self.position = position
@dataclass(init=False)
class SubExpressionNode(BaseNode):
    """
    A bracketed expression, written with parentheses as in mathematics.

    The node found between the brackets is stored in ``child``.
    Examples: (1+1), (1), (+(-1))
    """

    child: BaseNode = None
@dataclass(init=False)
class UnaryOperatorNode(BaseNode):
    """
    An operator applied to a single operand, stored in ``child``.

    ``negative`` is True for ``-`` (the operand gets multiplied by -1) and
    False for ``+`` (which leaves the operand untouched).
    Examples: +1, -2, -(2+2)...
    """

    negative: bool = False
    child: BaseNode = None
@dataclass(init=False)
class TwoSidedOperatorNode(BaseNode):
    """
    Holds a binary operator: an operator with a left hand side and a right
    hand side operand.

    Examples: 1+1, 22-41, (22-1)/(22*2)...
    """

    lhs: BaseNode = None
    operator: str = None
    rhs: BaseNode = None

    def __init__(self, lhs: BaseNode, operator: str, rhs: BaseNode):
        """
        Builds a binary operator node.

        lhs: node on the left of the operator; must carry a position.
        operator: the operator character.
        rhs: node on the right of the operator; must carry a position.
        """
        # A bare `super()` only creates a proxy object and discards it;
        # the base initialiser has to be called explicitly.
        super().__init__()
        self.lhs = lhs
        self.operator = operator
        self.rhs = rhs
        # Automatically compute the position from lhs and rhs position:
        # the node spans from the start of lhs to the end of rhs.
        self.position = Position(lhs.position.start, rhs.position.end)
# Parser utility classes
@dataclass(init=False)
class Expression(BaseNode):
    """Root of a parsed AST, bundled with helpers to print and evaluate it."""

    child: BaseNode
    expression: str

    def __init__(self, child, expression):
        self.child = child
        self.expression = expression
        # The root covers the whole source text: index 0 up to the last index.
        last_index = len(expression) - 1
        self.position = Position(0, last_index)
        # Visitors are created once and reused for every print()/eval() call.
        self.__evaluator = ASTEvaluator()
        self.__printer = ASTPrinter()

    def print(self):
        """Returns the text produced by running the printer over this tree."""
        printer = self.__printer
        return printer.print(self)

    def eval(self):
        """Evaluates the expression, returning the result."""
        evaluator = self.__evaluator
        return evaluator.eval(self)
@dataclass
class Error:
    """
    Plain container describing one parsing problem.

    ``message`` is the human readable description and ``position`` the span
    of the source string the problem relates to.
    """

    message: str
    position: Position
# Parser predicates - these make the parser easier to read and tweak ;)
def IsWhiteSpace(string: str) -> bool:
    """Tells whether ``string`` is exactly a space or a tab character."""
    return string in (" ", "\t")
def IsSubExpressionBegin(string: str) -> bool:
    """Tells whether ``string`` is the opening bracket of a sub expression."""
    opening_bracket = "("
    return opening_bracket == string
def IsSubExpressionEnd(string: str) -> bool:
    """Tells whether ``string`` is the closing bracket of a sub expression."""
    closing_bracket = ")"
    return closing_bracket == string
def IsNumberLiteral(string: str) -> bool:
    """
    Tells whether ``string`` is a single decimal digit, ``0`` through ``9``.

    The empty string and strings longer than one character are rejected.
    """
    # "0" has to be accepted too, otherwise numbers such as 0 or 10 cannot be
    # recognised. The length check keeps "" and substrings such as "12" from
    # slipping through the `in` test.
    return len(string) == 1 and string in "0123456789"
def IsOperator(string: str) -> bool:
    """Tells whether ``string`` is one of the binary operators + - * /."""
    return string in ("+", "-", "*", "/")
def IsUnaryOperator(string: str) -> bool:
    """Tells whether ``string`` is a sign usable as unary operator: + or -."""
    return string in ("+", "-")
def IsInvertOperator(string: str) -> bool:
    """Tells whether ``string`` is the sign-inverting operator, a minus."""
    minus_sign = "-"
    return minus_sign == string
def IsPriorityOperator(string: str) -> bool:
    """Tells whether ``string`` is an operator binding tighter than + and -."""
    return string in ("*", "/")
class Parser:
    """
    Hand-written recursive descent parser for arithmetic expressions.

    Call ``parse`` with the text to analyse. Syntax problems do not raise:
    they are collected as ``Error`` objects in ``self.errors`` and parsing
    carries on with whatever could be recognised.
    """

    # Initialisation
    def parse(self, string: str) -> "Expression":
        """Parses a string into an Expression."""
        self.string: str = string
        self.position: int = 0
        self.errors: list[Error] = []
        node = Expression(self.term(), self.string)
        if not self.is_end():
            self.errors.append(Error(
                "Expression has extra characters after the end of the expression",
                Position(self.position, len(self.string) - 1)))
        return node

    # Grammar
    def term(self) -> "BaseNode":
        """
        Parses a term, one or multiple factors separated by operators.

        Returns the single factor when no operator follows it, otherwise a
        tree of TwoSidedOperatorNodes in which * and / bind tighter than
        + and -, and operators of equal importance associate to the left.
        Returns None when no factor could be parsed at all.
        """
        # First parse a first factor
        self.skip_whitespaces()
        node = self.factor()
        self.skip_whitespaces()

        # Operands and operators still waiting to be combined; only the low
        # priority operators (+ and -) ever end up in `operators`.
        operands = [node]
        operators = []
        while self.check_char(IsOperator):
            operator = self.next()
            self.skip_whitespaces()
            operand = self.factor()
            self.skip_whitespaces()
            if operand is None:
                # factor() already recorded the error; drop the dangling
                # operator instead of building a node without a rhs.
                continue
            if operands[-1] is None:
                # No usable left hand side (input such as "*2"): the error was
                # recorded when the first factor failed, keep the rhs alone.
                operands[-1] = operand
                continue
            if IsPriorityOperator(operator):
                # Multiplications and divisions are bound right away to the
                # operand directly on their left.
                operands[-1] = TwoSidedOperatorNode(
                    operands[-1], operator, operand)
            else:
                operators.append(operator)
                operands.append(operand)

        # Then do the rest of the operators, folding from left to right.
        result = operands[0]
        for operator, operand in zip(operators, operands[1:]):
            result = TwoSidedOperatorNode(result, operator, operand)
        return result

    def factor(self) -> "BaseNode":
        """
        Parses a factor, which can be pretty much anything that is not an operator.
        A factor is one of the parameters of a two sided (binary) operation.

        Returns None, after recording an Error, when the current character
        starts neither a sub expression, a number nor an unary operator.
        """
        if self.check_char(IsSubExpressionBegin):
            return self.subexpression()
        if self.check_char(IsNumberLiteral):
            return self.number()
        if self.check_char(IsUnaryOperator):
            return self.unary()
        # Nothing was consumed: report the problem instead of failing silently.
        self.errors.append(Error(
            "Expected a number, a sub expression or an unary operator",
            Position(self.position, self.position)))
        return None

    def subexpression(self) -> "SubExpressionNode":
        """Parses a sub expression, everything in brackets ()."""
        node = SubExpressionNode()
        start_pos = self.position
        # Skip the (
        self.next()
        # Parse the subexpression term
        node.child = self.term()
        node.position = Position(start_pos, self.position)
        # Check for syntax error
        self.skip_whitespaces()
        if not self.check_char(IsSubExpressionEnd):
            error_start = self.position
            # Stops at the closing bracket, or at the end of the input when
            # the bracket is missing.
            self.skip_until(IsSubExpressionEnd)
            if self.is_end():
                message = "Unexpected end of expression"
            else:
                message = "Unrecognized characters at the end of subexpression (did you mean to add an operator?)"
            self.errors.append(
                Error(message, Position(error_start, self.position)))
        # Skip the ) - a no-op when the end of the input was reached
        self.next()
        return node

    def number(self) -> "NumberLiteralNode":
        """Parses a number literal, everything that's literally a number."""
        start_pos = self.position
        self.skip_while(IsNumberLiteral)
        digits = self.string[start_pos:self.position]
        return NumberLiteralNode(int(digits), Position(start_pos, self.position))

    def unary(self):
        """Parses an unary operator, either the no-op + or the inverting - operator."""
        node = UnaryOperatorNode()
        node.negative = self.check_char(IsInvertOperator)
        node.position = Position(self.position, self.position + 1)
        self.next()
        self.skip_whitespaces()
        node.child = self.factor()
        return node

    # Parsing utility functions
    def peek(self) -> str:
        """Returns the current character, or "" at the end of the string."""
        return self.string[self.position] if not self.is_end() else ""

    def next(self) -> str:
        """Returns the current character and switches to the next one.

        At the end of the string "" is returned and the position is left
        unchanged.
        """
        if self.is_end():
            return ""
        old = self.peek()
        self.position += 1
        return old

    def skip_while(self, predicate: Callable[[str], bool]):
        """Skips characters while they meet a criteria, or until the end."""
        # The end check guards against predicates that accept "": next() does
        # not advance at the end, so the loop would otherwise never finish.
        while not self.is_end() and predicate(self.peek()):
            self.next()

    def skip_until(self, predicate: Callable[[str], bool]):
        """Skips characters until one meets a criteria, or until the end."""
        # Without the end check a criteria that never matches would spin
        # forever, since peek() keeps returning "" at the end of the string.
        while not self.is_end() and not predicate(self.peek()):
            self.next()

    def skip_whitespaces(self):
        """Skips all whitespaces."""
        self.skip_while(IsWhiteSpace)

    def check_char(self, predicate: Callable[[str], bool]):
        """Checks if the current character meets a criteria."""
        return predicate(self.peek())

    def is_end(self) -> bool:
        """Returns true if the end of the string has been reached."""
        return len(self.string) == self.position
class ASTVisitor:
    """
    Base class for walking an AST.

    ``visit`` forwards a node to the ``visit_*`` method matching its class.
    The default methods simply recurse into the children, so subclasses only
    override the node kinds they care about.
    """

    def visit(self, node: BaseNode):
        """Dispatches ``node`` to its handler; unknown nodes yield None."""
        handlers = (
            (Expression, self.visit_expression),
            (NumberLiteralNode, self.visit_number_literal),
            (TwoSidedOperatorNode, self.visit_two_sided_operator),
            (UnaryOperatorNode, self.visit_unary_operator),
            (SubExpressionNode, self.visit_subexpression),
        )
        for node_type, handler in handlers:
            if isinstance(node, node_type):
                return handler(node)
        return None

    def visit_expression(self, exp: Expression):
        """Visits the root by visiting the tree it holds."""
        return self.visit(exp.child)

    def visit_number_literal(self, node: NumberLiteralNode):
        """Leaf node: nothing to do by default."""
        return None

    def visit_two_sided_operator(self, node: TwoSidedOperatorNode):
        """Visits the left then the right operand, returning both results."""
        left_result = self.visit(node.lhs)
        right_result = self.visit(node.rhs)
        return (left_result, right_result)

    def visit_unary_operator(self, node: UnaryOperatorNode):
        """Visits the operand of the unary operator."""
        return self.visit(node.child)

    def visit_subexpression(self, node: SubExpressionNode):
        """Visits the content of the brackets."""
        return self.visit(node.child)
class ASTEvaluator(ASTVisitor):
    """Visitor computing the numeric value of an AST."""

    def eval(self, exp: Expression) -> int:
        """Returns the value obtained by visiting ``exp``."""
        return self.visit(exp)

    def visit_number_literal(self, node: NumberLiteralNode) -> int:
        """A literal evaluates to the number it holds."""
        return node.number

    def visit_two_sided_operator(self, node: TwoSidedOperatorNode) -> int:
        """
        Evaluates both operands, left first, then applies the operator.

        Division is a true division (``/``). An operator outside of + - * /
        yields None.
        """
        left = self.visit(node.lhs)
        right = self.visit(node.rhs)
        operations = {
            "+": lambda a, b: a + b,
            "-": lambda a, b: a - b,
            "*": lambda a, b: a * b,
            "/": lambda a, b: a / b,
        }
        operation = operations.get(node.operator)
        if operation is None:
            return None
        return operation(left, right)

    def visit_unary_operator(self, node: UnaryOperatorNode) -> int:
        """Evaluates the operand, flipping its sign for a negative operator."""
        operand = self.visit(node.child)
        if node.negative:
            return operand * -1
        return operand
class ASTPrinter(ASTVisitor):
    """Visitor rebuilding a normalised text form of an AST in ``self.str``."""

    def print(self, exp: Expression) -> str:
        """Returns the text of ``exp``, starting from an empty buffer."""
        self.str = ""
        self.visit(exp)
        return self.str

    def _emit(self, text):
        # Appends a piece of text to the output buffer.
        self.str += text

    def visit_number_literal(self, node: NumberLiteralNode):
        """Writes the number as text."""
        self._emit(str(node.number))

    def visit_two_sided_operator(self, node: TwoSidedOperatorNode):
        """Writes both operands with the operator, padded by spaces, between."""
        self.visit(node.lhs)
        self._emit(f" {node.operator} ")
        self.visit(node.rhs)

    def visit_unary_operator(self, node: UnaryOperatorNode):
        """Writes the sign directly followed by the operand."""
        if node.negative:
            self._emit("-")
        else:
            self._emit("+")
        self.visit(node.child)

    def visit_subexpression(self, node: SubExpressionNode):
        """Writes the content wrapped in brackets."""
        self._emit("(")
        self.visit(node.child)
        self._emit(")")
if __name__ == "__main__":
    # Interactive loop: read an expression, parse it, show its value, repeat.
    # The same Parser instance is reused for every line.
    parser = Parser()
    while True:
        source = input("Enter a mathematical expression to calculate -> ")
        expression = parser.parse(source)
        print(expression.eval())
import unittest
from calc import Parser
class ParserResilience(unittest.TestCase):
    """Checks that the parser copes with whitespace and stray characters."""

    def test_spaces(self):
        # Spaces around every token must not be reported as errors.
        source = " - ( 13 - 345 ) / 4 * ( + 6 - 7 ) "
        parser = Parser()
        parser.parse(source)
        self.assertEqual(parser.errors, [])

    def test_ignore_garbage(self):
        # Unrecognised characters are skipped and the valid part is kept.
        expression = Parser().parse("1+(2+3saa)+2sacacascs")
        printed = expression.print()
        self.assertEqual(printed, "1 + (2 + 3) + 2")
class ExpressionEvaluation(unittest.TestCase):
    """Checks the value computed for a range of small expressions."""

    def check(self, source, expected):
        # Parses `source` with a fresh parser and compares the evaluated value.
        self.assertEqual(Parser().parse(source).eval(), expected)

    def test_single_literal(self):
        self.check("1", 1)

    def test_addition(self):
        self.check("1+1", 2)

    def test_multiplication(self):
        self.check("2*2", 4)

    def test_sub_expression(self):
        self.check("(1+2)*(3+4)", 21)

    def test_multiplication_before_addition(self):
        # The multiplication wins whichever side it is written on, unless
        # brackets say otherwise.
        self.check("2*3+1", 7)
        self.check("1+2*3", 7)
        self.check("2*(3+1)", 8)

    def test_unary_operator(self):
        self.check("+1", 1)
        self.check("1+-1", 0)
        self.check("2*-(-1+-1)", 4)
if __name__ == "__main__":
    # Run every test case of this module when it is executed as a script.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment