Created
November 7, 2021 01:25
-
-
Save arthuro555/3db35d1a7a98319a9bcaf6083ab49d0a to your computer and use it in GitHub Desktop.
Simple calculator in python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass
from typing import Callable
@dataclass
class Position:
    """
    A span inside the parsed string, stored as a start and an end index.

    Nodes and errors carry one of these so that they can be traced back to
    the text they came from.
    """

    start: int
    end: int

    def __repr__(self):
        # Compact "(start;end)" form instead of the generated dataclass repr.
        return "(" + str(self.start) + ";" + str(self.end) + ")"
# AST Nodes
@dataclass(init=False)
class BaseNode:
    """
    Common ancestor of every AST node.

    It exists for type checking and so that every node has a ``position``.
    """

    # Span of the node in the source string; None until something assigns it.
    position: Position = None
@dataclass(init=False)
class NumberLiteralNode(BaseNode):
    """
    Holds a number literal.

    Examples: 1, 69, 420, 2021...
    """

    number: int

    def __init__(self, number, position):
        """
        Store the literal's value and the span it was read from.

        number: the value of the literal.
        position: the Position of the literal in the source string.
        """
        # A bare ``super()`` only builds a proxy object and throws it away;
        # the parent initialiser has to be called explicitly.
        super().__init__()
        self.number = number
        self.position = position
@dataclass(init=False)
class SubExpressionNode(BaseNode):
    """
    A bracketed sub expression, written with parentheses as in mathematics.

    Examples: (1+1), (1), (+(-1))
    """

    # The node found between the brackets; None until it is assigned.
    child: BaseNode = None
@dataclass(init=False)
class UnaryOperatorNode(BaseNode):
    """
    An unary operator: an operator applied to a single value.

    It is either + (no operation) or - (multiplies by -1).
    Examples: +1, -2, -(2+2)...
    """

    # True when the operator is "-", False when it is "+".
    negative: bool = False
    # The node the operator applies to; None until it is assigned.
    child: BaseNode = None
@dataclass(init=False)
class TwoSidedOperatorNode(BaseNode):
    """
    Holds a binary operator: an operator with a left hand side and a right
    hand side operand.

    Examples: 1+1, 22-41, (22-1)/(22*2)...
    """

    lhs: BaseNode = None
    operator: str = None
    rhs: BaseNode = None

    def __init__(self, lhs, operator, rhs):
        """
        Build the node from its two operands and the operator between them.

        lhs: node on the left of the operator (must have a ``position``).
        operator: the operator character.
        rhs: node on the right of the operator (must have a ``position``).
        """
        # A bare ``super()`` only builds a proxy object and throws it away;
        # the parent initialiser has to be called explicitly.
        super().__init__()
        self.lhs = lhs
        self.operator = operator
        self.rhs = rhs
        # Automatically compute the position from lhs and rhs position
        self.position = Position(lhs.position.start, rhs.position.end)
# Parser utility classes
@dataclass(init=False)
class Expression(BaseNode):
    """Root of a parsed AST, bundled with helpers to print and evaluate it."""

    child: BaseNode
    expression: str

    def __init__(self, child, expression):
        # Helper visitors, one pair per expression.
        self.__printer = ASTPrinter()
        self.__evaluator = ASTEvaluator()
        self.expression = expression
        self.child = child
        # The root spans the whole source string.
        last_index = len(expression) - 1
        self.position = Position(0, last_index)

    def print(self):
        """Returns the text produced by the ASTPrinter for this expression."""
        printer = self.__printer
        return printer.print(self)

    def eval(self):
        """Evaluates the expression, returning the result."""
        evaluator = self.__evaluator
        return evaluator.eval(self)
@dataclass
class Error:
    """
    Container for one parsing error: what went wrong and where.
    """

    # Human readable description of the problem.
    message: str
    # Span of the source string the problem relates to.
    position: Position
# Parser predicates - these make the parser easier to read and tweak ;)
def IsWhiteSpace(string: str) -> bool:
    """Tells whether the given character is a space or a tab."""
    return string in (" ", "\t")
def IsSubExpressionBegin(string: str) -> bool:
    """Tells whether the given character opens a sub expression."""
    opening_bracket = "("
    return opening_bracket == string
def IsSubExpressionEnd(string: str) -> bool:
    """Tells whether the given character closes a sub expression."""
    closing_bracket = ")"
    return closing_bracket == string
def IsNumberLiteral(string: str) -> bool:
    """
    Tells whether the given character is a decimal digit (0 to 9).

    "0" is included: without it a literal such as 10 or 205 stops being
    read at the first zero.
    """
    return string in ("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")
def IsOperator(string: str) -> bool:
    """Tells whether the given character is one of the operators + - * /."""
    return string in ("+", "-", "*", "/")
def IsUnaryOperator(string: str) -> bool:
    """Tells whether the given character can be used as an unary operator (+ or -)."""
    return string in ("+", "-")
def IsInvertOperator(string: str) -> bool:
    """Tells whether the given character is the inverting operator (-)."""
    minus_sign = "-"
    return minus_sign == string
def IsPriorityOperator(string: str) -> bool:
    """Tells whether the given character is an operator applied first (* or /)."""
    return string in ("*", "/")
class Parser:
    """
    Error-tolerant parser that turns an arithmetic string into an Expression.

    Syntax problems are not raised: they are appended to ``self.errors`` and
    parsing continues with a best-effort tree. ``self.string``,
    ``self.position`` and ``self.errors`` are reset by every ``parse`` call.
    """

    # Initialisation
    def parse(self, string: str) -> Expression:
        """
        Parses a string into an Expression.

        Any problem met on the way is recorded in ``self.errors``.
        """
        self.string: str = string
        self.position: int = 0
        self.errors: list[Error] = []
        node = Expression(self.term(), self.string)
        if not self.is_end():
            self.errors.append(Error(
                "Expression has extra characters after the end of the expression",
                Position(self.position, len(self.string) - 1)))
        return node

    # Grammar
    def term(self) -> BaseNode:
        "Parses a term, one or multiple factors separated by operators."
        # First parse a first factor
        self.skip_whitespaces()
        operands = [self.factor()]
        operators = []
        self.skip_whitespaces()

        # Then collect every following "operator factor" pair.
        while self.check_char(IsOperator):
            operators.append(self.next())
            self.skip_whitespaces()
            operands.append(self.factor())
            self.skip_whitespaces()

        # First convert the priority operators (multiplications > additions):
        # each one is folded, left to right, into the operand on its left.
        folded = [operands[0]]
        pending = []
        for operator, rhs in zip(operators, operands[1:]):
            if IsPriorityOperator(operator):
                folded[-1] = TwoSidedOperatorNode(folded[-1], operator, rhs)
            else:
                pending.append(operator)
                folded.append(rhs)

        # Then do the rest of the operators, also left to right.
        # Without any operator this leaves the first factor untouched.
        node = folded[0]
        for operator, rhs in zip(pending, folded[1:]):
            node = TwoSidedOperatorNode(node, operator, rhs)
        return node

    def factor(self) -> BaseNode:
        """
        Parses a factor, which can be pretty much anything that is not an operator.
        A factor is one of the parameters of a two sided (binary) operation.

        When nothing parseable is found, an Error is recorded and a zero
        literal of empty width is returned, so that callers always receive a
        node with a position instead of None.
        """
        if self.check_char(IsSubExpressionBegin):
            return self.subexpression()
        if self.check_char(IsNumberLiteral):
            return self.number()
        if self.check_char(IsUnaryOperator):
            return self.unary()
        # Missing operand (end of input, garbage, dangling operator...).
        here = Position(self.position, self.position)
        self.errors.append(
            Error("Expected a number, a sub expression or an unary operator", here))
        return NumberLiteralNode(0, here)

    def subexpression(self) -> SubExpressionNode:
        "Parses a sub expression, everything in brackets ()."
        node = SubExpressionNode()
        start_pos = self.position
        # Skip the (
        self.next()
        # Parse the subexpression term
        node.child = self.term()
        # The position is recorded before the closing bracket is consumed.
        node.position = Position(start_pos, self.position)
        # Check for syntax error
        self.skip_whitespaces()
        if not self.check_char(IsSubExpressionEnd):
            error_start = self.position
            # Drop everything up to the closing bracket (or the end of input).
            self.skip_until(IsSubExpressionEnd)
            if self.is_end():
                message = "Unexpected end of expression"
            else:
                message = "Unrecognized characters at the end of subexpression (did you mean to add an operator?)"
            self.errors.append(
                Error(message, Position(error_start, self.position)))
        # Skip the )
        self.next()
        return node

    def number(self) -> NumberLiteralNode:
        "Parses a number literal, everything that's literally a number."
        start_pos = self.position
        self.skip_while(IsNumberLiteral)
        digits = self.string[start_pos:self.position]
        return NumberLiteralNode(int(digits), Position(start_pos, self.position))

    def unary(self):
        "Parses an unary operator, either the no-op + or the inverting - operator."
        node = UnaryOperatorNode()
        node.negative = self.check_char(IsInvertOperator)
        # The position covers the operator character only.
        node.position = Position(self.position, self.position + 1)
        self.next()
        self.skip_whitespaces()
        node.child = self.factor()
        return node

    # Parsing utility functions
    def peek(self) -> str:
        "Returns the current character, or an empty string at the end."
        return self.string[self.position] if not self.is_end() else ""

    def next(self) -> str:
        "Return the current character and switch to the next."
        if self.is_end():
            return ""
        old = self.peek()
        self.position += 1
        return old

    def skip_while(self, predicate: Callable[[str], bool]):
        "Skip characters while they meet a criteria."
        # The end check stops the loop even if the predicate accepts "".
        while not self.is_end() and predicate(self.peek()):
            self.next()

    def skip_until(self, predicate: Callable[[str], bool]):
        "Skip characters until getting to one meeting a criteria, or to the end."
        # Without the end check this never terminates when no character
        # matches: next() stops advancing once the end is reached.
        while not self.is_end() and not predicate(self.peek()):
            self.next()

    def skip_whitespaces(self):
        "Skip all whitespaces."
        self.skip_while(IsWhiteSpace)

    def check_char(self, predicate: Callable[[str], bool]) -> bool:
        "Checks if the current character meets a criteria."
        return predicate(self.peek())

    def is_end(self) -> bool:
        "Returns true if the end of the string has been reached."
        return self.position >= len(self.string)
class ASTVisitor:
    """
    Base tree walker: ``visit`` forwards a node to the ``visit_*`` method
    matching its class. The default methods just walk down to the children.
    """

    def visit(self, node: BaseNode):
        """Dispatches on the node class; unknown nodes yield None."""
        # Same order as the checks are meant to be tried in.
        dispatch = (
            (Expression, self.visit_expression),
            (NumberLiteralNode, self.visit_number_literal),
            (TwoSidedOperatorNode, self.visit_two_sided_operator),
            (UnaryOperatorNode, self.visit_unary_operator),
            (SubExpressionNode, self.visit_subexpression),
        )
        for node_class, handler in dispatch:
            if isinstance(node, node_class):
                return handler(node)
        return None

    def visit_expression(self, exp: Expression):
        root = exp.child
        return self.visit(root)

    def visit_number_literal(self, node: NumberLiteralNode):
        # Leaf node: nothing to walk into.
        return None

    def visit_two_sided_operator(self, node: TwoSidedOperatorNode):
        # Left operand first, then the right one.
        left_result = self.visit(node.lhs)
        right_result = self.visit(node.rhs)
        return (left_result, right_result)

    def visit_unary_operator(self, node: UnaryOperatorNode):
        operand = node.child
        return self.visit(operand)

    def visit_subexpression(self, node: SubExpressionNode):
        inner = node.child
        return self.visit(inner)
class ASTEvaluator(ASTVisitor):
    """Visitor that computes the numeric value of a tree."""

    def eval(self, exp: Expression) -> int:
        """Returns the value obtained by visiting the expression."""
        result = self.visit(exp)
        return result

    def visit_number_literal(self, node: NumberLiteralNode) -> int:
        # A literal evaluates to the number it stores.
        return node.number

    def visit_two_sided_operator(self, node: TwoSidedOperatorNode) -> int:
        # Both sides are evaluated, left first, before looking at the operator.
        left = self.visit(node.lhs)
        right = self.visit(node.rhs)
        operations = {
            "+": lambda a, b: a + b,
            "-": lambda a, b: a - b,
            "*": lambda a, b: a * b,
            "/": lambda a, b: a / b,
        }
        operation = operations.get(node.operator)
        if operation is None:
            # Any other operator yields None.
            return None
        return operation(left, right)

    def visit_unary_operator(self, node: UnaryOperatorNode) -> int:
        operand = self.visit(node.child)
        if node.negative:
            return operand * -1
        return operand
class ASTPrinter(ASTVisitor):
    """Visitor that writes a tree back as text, accumulated in ``self.str``."""

    def print(self, exp: Expression) -> str:
        """Returns the text form of the expression."""
        # Start from an empty buffer on every call.
        self.str = ""
        self.visit(exp)
        text = self.str
        return text

    def visit_number_literal(self, node: NumberLiteralNode):
        self.str = self.str + str(node.number)

    def visit_two_sided_operator(self, node: TwoSidedOperatorNode):
        # Left operand, the operator between single spaces, right operand.
        self.visit(node.lhs)
        self.str = self.str + f" {node.operator} "
        self.visit(node.rhs)

    def visit_unary_operator(self, node: UnaryOperatorNode):
        if node.negative:
            sign = "-"
        else:
            sign = "+"
        self.str = self.str + sign
        self.visit(node.child)

    def visit_subexpression(self, node: SubExpressionNode):
        # Wrap the child between brackets.
        self.str = self.str + "("
        self.visit(node.child)
        self.str = self.str + ")"
if __name__ == "__main__":
    # Small interactive loop: read an expression, print its value, repeat.
    parser = Parser()
    while True:
        try:
            line = input("Enter a mathematical expression to calculate -> ")
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C: leave the loop instead of showing a
            # traceback. The bare print() ends the prompt line.
            print()
            break
        print(parser.parse(line).eval())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from calc import Parser | |
class ParserResilience(unittest.TestCase):
    # Checks that odd-looking input does not upset the parser.

    def test_spaces(self):
        # Whitespace between tokens must not produce any error.
        spaced_parser = Parser()
        spaced_parser.parse(" - ( 13 - 345 ) / 4 * ( + 6 - 7 ) ")
        recorded_errors = spaced_parser.errors
        self.assertEqual(recorded_errors, [])

    def test_ignore_garbage(self):
        # Stray characters are dropped from the printed expression.
        parsed = Parser().parse("1+(2+3saa)+2sacacascs")
        printed = parsed.print()
        self.assertEqual(printed, "1 + (2 + 3) + 2")
class ExpressionEvaluation(unittest.TestCase):
    # Checks the values obtained by parsing then evaluating expressions.

    def test_single_literal(self):
        value = Parser().parse("1").eval()
        self.assertEqual(value, 1)

    def test_addition(self):
        value = Parser().parse("1+1").eval()
        self.assertEqual(value, 2)

    def test_multiplication(self):
        value = Parser().parse("2*2").eval()
        self.assertEqual(value, 4)

    def test_sub_expression(self):
        value = Parser().parse("(1+2)*(3+4)").eval()
        self.assertEqual(value, 21)

    def test_multiplication_before_addition(self):
        # Multiplication binds tighter than addition unless brackets say otherwise.
        cases = (
            ("2*3+1", 7),
            ("1+2*3", 7),
            ("2*(3+1)", 8),
        )
        for source, expected in cases:
            self.assertEqual(Parser().parse(source).eval(), expected)

    def test_unary_operator(self):
        # Unary operators alone, after a binary operator, and nested.
        cases = (
            ("+1", 1),
            ("1+-1", 0),
            ("2*-(-1+-1)", 4),
        )
        for source, expected in cases:
            self.assertEqual(Parser().parse(source).eval(), expected)
# Run every test of this module when the file is executed directly.
if "__main__" == __name__:
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment