Last active
July 27, 2022 19:16
-
-
Save derrickturk/3c97593399c008a95d76387af28b49a1 to your computer and use it in GitHub Desktop.
fun with lark and PHDwin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Lark grammar for the formula expression language handled by the
// accompanying Python transformer.
//
// Lark conventions used below:
//   - a rule whose name starts with "?" is inlined into its parent when it
//     matches exactly one child, so single-operand levels add no tree nodes;
//   - a rule whose name starts with "!" keeps its literal tokens, which lets
//     the transformer see which operator was actually matched;
//   - "-> name" labels the resulting node (and transformer callback) "name".

%import common.SIGNED_NUMBER
%import common.CNAME
%import common.WS
%ignore WS

?start: expr

// Precedence, loosest to tightest binding:
//   comparison, addition/subtraction, multiplication/division, unary sign.
?expr: comp_expr

// At most one comparison operator per expression; comparisons do not chain.
?comp_expr: addsub_expr [comp_op addsub_expr]

?addsub_expr: muldiv_expr (addsub_op muldiv_expr)*

?muldiv_expr: factor (muldiv_op factor)*

// A sign may be stacked any number of times in front of an atom.
?factor: sign_op factor | atom_expr

// TODO: not equal?
!comp_op: "<" | ">" | "=" | "<=" | ">="

!addsub_op: "+" | "-"

// Multiplication is spelled with a lowercase "x", not "*".
// NOTE(review): "x" is also a valid start of a CNAME - confirm how bare
// identifiers beginning with "x" are tokenized.
!muldiv_op: "x" | "/"

!sign_op: "+" | "-"

// A function call is an identifier followed by a parenthesized, optional,
// comma-separated argument list.
?atom_expr: identifier "(" [args] ")" -> funcall
          | atom

// TODO: I see some function calls(?) inside { } - what are these?
?atom: "(" expr ")"
     | identifier
     | number

?args: expr ("," expr)*

// An identifier is either a bare C-style name or any run of characters
// other than "}" wrapped in braces.
// TODO: we need to know the escaping rules here...
?identifier: (CNAME | "{" /[^}]+/ "}") -> var

?number: SIGNED_NUMBER -> const
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator
import sys
from collections import namedtuple
from enum import Enum, auto

import lark
class UnaryOp(Enum):
    """Prefix sign operators that apply to a single operand."""

    Pos = auto()  # unary plus
    Neg = auto()  # unary minus
class BinaryOp(Enum):
    """Infix operators: four arithmetic operators and five comparisons."""

    # Arithmetic
    Add = auto()
    Sub = auto()
    Mul = auto()
    Div = auto()

    # Comparison
    Lt = auto()
    Gt = auto()
    LtEq = auto()
    GtEq = auto()
    Eq = auto()
class Constant(namedtuple('Constant', ['val'])):
    """Leaf AST node that wraps a literal value in its ``val`` field."""

    __slots__ = ()

    def eval(self, var_dict):
        """Return the wrapped value; ``var_dict`` is accepted but not read."""
        return self.val

    def pprint(self):
        """Return the wrapped value rendered with default formatting."""
        return format(self.val)
class Var(namedtuple('Var', ['name'])):
    """Leaf AST node that refers to a variable by ``name``."""

    __slots__ = ()

    def eval(self, var_dict):
        """Look the variable up in ``var_dict``.

        Raises:
            KeyError: if ``var_dict`` has no entry for ``name``.
        """
        return var_dict[self.name]

    def pprint(self):
        """Return the variable name wrapped in curly braces."""
        return '{' + format(self.name) + '}'
class BinaryOpApply(namedtuple('BinaryOpApply', ['op', 'lhs', 'rhs'])):
    """AST node that applies a binary operator to two sub-expressions.

    ``op`` is expected to be a ``BinaryOp`` member; ``lhs`` and ``rhs`` are
    AST nodes that provide ``eval(var_dict)`` and ``pprint()``.
    """

    __slots__ = ()

    def eval(self, var_dict):
        """Evaluate both operands, then combine them according to ``op``.

        Both operands are always evaluated (left first) before the operator
        is inspected. Errors raised by the operands or by the operation
        itself (for example division by zero) propagate unchanged.

        Raises:
            ValueError: if ``op`` is not a recognised ``BinaryOp`` member.
        """
        lhs_val = self.lhs.eval(var_dict)
        rhs_val = self.rhs.eval(var_dict)
        implementations = {
            BinaryOp.Add: operator.add,
            BinaryOp.Sub: operator.sub,
            BinaryOp.Mul: operator.mul,
            BinaryOp.Div: operator.truediv,
            BinaryOp.Lt: operator.lt,
            BinaryOp.Gt: operator.gt,
            BinaryOp.LtEq: operator.le,
            BinaryOp.GtEq: operator.ge,
            BinaryOp.Eq: operator.eq,
        }
        try:
            combine = implementations[self.op]
        except (KeyError, TypeError):
            # TypeError covers an unhashable ``op``; either way the operator
            # is not one this node knows how to apply.
            raise ValueError('invalid binary operator') from None
        return combine(lhs_val, rhs_val)

    def pprint(self):
        """Return a fully parenthesised ``(lhs <symbol> rhs)`` rendering.

        Multiplication is written ``x`` and equality ``=``.

        Raises:
            ValueError: if ``op`` is not a recognised ``BinaryOp`` member.
        """
        lhs_pp = self.lhs.pprint()
        rhs_pp = self.rhs.pprint()
        symbols = {
            BinaryOp.Add: '+',
            BinaryOp.Sub: '-',
            BinaryOp.Mul: 'x',
            BinaryOp.Div: '/',
            BinaryOp.Lt: '<',
            BinaryOp.Gt: '>',
            BinaryOp.LtEq: '<=',
            BinaryOp.GtEq: '>=',
            BinaryOp.Eq: '=',
        }
        try:
            symbol = symbols[self.op]
        except (KeyError, TypeError):
            raise ValueError('invalid binary operator') from None
        return f'({lhs_pp} {symbol} {rhs_pp})'
class UnaryOpApply(namedtuple('UnaryOpApply', ['op', 'expr'])):
    """AST node that applies a prefix sign operator to one sub-expression.

    ``op`` is expected to be a ``UnaryOp`` member; ``expr`` is an AST node
    that provides ``eval(var_dict)`` and ``pprint()``.
    """

    __slots__ = ()

    def eval(self, var_dict):
        """Evaluate the operand and apply the sign.

        ``Pos`` returns the operand's value untouched; ``Neg`` negates it.

        Raises:
            ValueError: if ``op`` is not a recognised ``UnaryOp`` member.
        """
        expr_val = self.expr.eval(var_dict)
        if self.op == UnaryOp.Pos:
            return expr_val
        if self.op == UnaryOp.Neg:
            return -expr_val
        raise ValueError('invalid unary operator')

    def pprint(self):
        """Return the operand's rendering prefixed with ``+`` or ``-``.

        Raises:
            ValueError: if ``op`` is not a recognised ``UnaryOp`` member.
        """
        expr_pp = self.expr.pprint()
        if self.op == UnaryOp.Pos:
            return f'+{expr_pp}'
        if self.op == UnaryOp.Neg:
            return f'-{expr_pp}'
        raise ValueError('invalid unary operator')
class FunCall(namedtuple('FunCall', ['fn', 'args'])):
    """AST node for a call to a named built-in function.

    ``fn`` is the function name as a string and ``args`` is a sequence of
    AST nodes that provide ``eval(var_dict)`` and ``pprint()``.
    """

    __slots__ = ()

    def _require_args(self, count):
        """Raise ValueError unless exactly ``count`` arguments were given."""
        if len(self.args) != count:
            raise ValueError(
                f'function {self.fn} expects {count} argument(s), '
                f'got {len(self.args)}'
            )

    def eval(self, var_dict):
        """Evaluate the call against ``var_dict``.

        Supported functions:
            ``If(cond, a, b)``: evaluates ``cond`` and then only the
                selected branch.
            ``Abs(x)``: absolute value of ``x``.

        Raises:
            ValueError: if a supported function gets the wrong number of
                arguments.
            NotImplementedError: if ``fn`` names an unsupported function.
        """
        if self.fn == 'If':
            # if has special rules - it only evaluates one or the
            # other of its arguments!
            self._require_args(3)
            cond, do_if, do_else = self.args
            if cond.eval(var_dict):
                return do_if.eval(var_dict)
            return do_else.eval(var_dict)
        if self.fn == 'Abs':
            self._require_args(1)
            val, = self.args
            return abs(val.eval(var_dict))
        raise NotImplementedError(
            f'function {self.fn} not yet implemented!'
        )

    def pprint(self):
        """Return ``name(arg1, arg2, ...)`` with each argument rendered."""
        args = ', '.join(a.pprint() for a in self.args)
        return f'{self.fn}({args})'
class ASTBuilder(lark.Transformer):
    """Lark transformer that converts parse results into AST node objects.

    Each public method is the callback for the grammar rule (or ``->``
    alias) of the same name and receives that rule's already-transformed
    children as a list. Operator rules yield ``UnaryOp`` / ``BinaryOp``
    members; expression rules yield ``Constant``, ``Var``, ``UnaryOpApply``,
    ``BinaryOpApply`` or ``FunCall`` nodes.
    """

    # Operator spelling -> enum member, one table per operator rule.
    _SIGN_OPS = {'+': UnaryOp.Pos, '-': UnaryOp.Neg}
    _ADDSUB_OPS = {'+': BinaryOp.Add, '-': BinaryOp.Sub}
    _MULDIV_OPS = {'x': BinaryOp.Mul, '/': BinaryOp.Div}
    _COMP_OPS = {
        '<': BinaryOp.Lt,
        '>': BinaryOp.Gt,
        '<=': BinaryOp.LtEq,
        '>=': BinaryOp.GtEq,
        '=': BinaryOp.Eq,
    }

    @staticmethod
    def _operator_for(table, args, message):
        """Map the operator token in ``args[0]`` through ``table``.

        Raises:
            ValueError: with ``message`` if the token is not in ``table``.
        """
        try:
            return table[str(args[0])]
        except KeyError:
            raise ValueError(message) from None

    @staticmethod
    def _fold_left(args):
        """Fold ``[e0, op1, e1, op2, e2, ...]`` into a left-leaning tree.

        This turns x + y + z into (x + y) + z and so on. Indexing is used
        instead of repeatedly re-slicing the list so the fold stays linear
        in the number of operands.

        Raises:
            ValueError: if an operator is not followed by an operand.
        """
        if len(args) % 2 == 0:
            raise ValueError('operator without a right-hand operand')
        ex = args[0]
        for i in range(1, len(args), 2):
            ex = BinaryOpApply(op=args[i], lhs=ex, rhs=args[i + 1])
        return ex

    def var(self, args):
        """Build a ``Var`` from the identifier token's text."""
        return Var(name=args[0].value)

    def const(self, args):
        """Build a ``Constant`` holding the number token as a float."""
        return Constant(float(args[0]))

    def sign_op(self, args):
        """Translate a ``+`` / ``-`` sign token into a ``UnaryOp``."""
        return self._operator_for(
            self._SIGN_OPS, args, 'invalid unary operator')

    def addsub_op(self, args):
        """Translate a ``+`` / ``-`` token into a ``BinaryOp``."""
        return self._operator_for(
            self._ADDSUB_OPS, args, 'invalid add/sub operator')

    def muldiv_op(self, args):
        """Translate an ``x`` / ``/`` token into a ``BinaryOp``."""
        return self._operator_for(
            self._MULDIV_OPS, args, 'invalid mul/div operator')

    def comp_op(self, args):
        """Translate a comparison token into a ``BinaryOp``."""
        return self._operator_for(
            self._COMP_OPS, args, 'invalid comparison operator')

    def factor(self, args):
        """Build a ``UnaryOpApply`` for a signed factor.

        Without this callback a ``sign_op factor`` match would be left as a
        raw ``lark.Tree`` and could not be evaluated or pretty-printed.
        """
        op, operand = args
        return UnaryOpApply(op=op, expr=operand)

    def comp_expr(self, args):
        """Build a comparison node, or pass a lone expression through."""
        lhs, *rest = args
        # An absent optional comparison arrives either as ``None``
        # placeholders or as no extra children at all, depending on Lark's
        # ``maybe_placeholders`` setting; both mean "just an expression".
        if all(item is None for item in rest):
            return lhs
        # or a comparison between two expressions
        op, rhs = rest
        return BinaryOpApply(op=op, lhs=lhs, rhs=rhs)

    def addsub_expr(self, args):
        """Fold a chain of additions/subtractions left-associatively."""
        return self._fold_left(args)

    def muldiv_expr(self, args):
        """Fold a chain of multiplications/divisions left-associatively."""
        return self._fold_left(args)

    def funcall(self, args):
        """Build a ``FunCall`` from a function identifier and its arguments."""
        fn, *rest = args
        fnargs = rest[0] if rest else None
        # we have to account for the way Lark handles [args], and ensure
        # that we always end up with a list: no arguments gives None (or
        # nothing), a single argument gives the bare node, several give the
        # list produced by ``args`` below
        if fnargs is None:
            fnargs = []
        elif not isinstance(fnargs, list):
            fnargs = [fnargs]
        # fn will be a Var, having been already transformed;
        # however, functions can only have certain hard-coded names,
        # so we just want the string
        return FunCall(fn=fn.name, args=fnargs)

    def args(self, args):
        """Return the argument nodes unchanged.

        Function args should just be in-lined into the FunCall tuple.
        """
        return args
def main(argv):
    """Parse, evaluate and pretty-print each formula read from stdin.

    The grammar is loaded from ``formula.lark`` in the current working
    directory. For every non-blank input line this prints the AST, its
    value evaluated with no variables defined, and its pretty-printed
    form. Parse and evaluation errors propagate to the caller.

    Args:
        argv: command-line arguments; currently unused.

    Returns:
        0, for use as the process exit status.
    """
    with open('formula.lark') as f:
        p = lark.Lark(f, parser='lalr', transformer=ASTBuilder())
    for line in sys.stdin:
        # A blank line (e.g. trailing newline in piped input) is not a
        # formula; skip it instead of letting the parser reject it.
        if not line.strip():
            continue
        tree = p.parse(line)
        print(tree)
        print(tree.eval({}))
        print(tree.pprint())
    return 0
# Script entry point: run the driver and use its return value as the
# process exit status.
if __name__ == '__main__':
    raise SystemExit(main(sys.argv))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment