Skip to content

Instantly share code, notes, and snippets.

@derrickturk
Last active July 27, 2022 19:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save derrickturk/3c97593399c008a95d76387af28b49a1 to your computer and use it in GitHub Desktop.
Save derrickturk/3c97593399c008a95d76387af28b49a1 to your computer and use it in GitHub Desktop.
Fun with Lark and PHDwin
// Lark grammar for PHDwin formula expressions. Presumably this is the
// content of formula.lark, which main() loads with parser='lalr' -- confirm.
// Lark rule prefixes used below:
//   ?rule   - inlined into its parent when it has a single child
//   !rule   - keeps every token, so ASTBuilder sees the operator text
//   -> name - alias; the ASTBuilder method of that name builds the node
%import common.SIGNED_NUMBER
%import common.CNAME
%import common.WS
%ignore WS
?start: expr
?expr: comp_expr
// Precedence, loosest to tightest: comparison, + and -, x and /, unary
// sign, atoms. The comparison part is optional and not repeated, so
// comparisons do not chain (a < b < c is rejected unless parenthesized).
?comp_expr: addsub_expr [comp_op addsub_expr]
?addsub_expr: muldiv_expr (addsub_op muldiv_expr)*
?muldiv_expr: factor (muldiv_op factor)*
?factor: sign_op factor | atom_expr
// TODO: not equal?
!comp_op: "<" | ">" | "=" | "<=" | ">="
!addsub_op: "+" | "-"
// Multiplication is spelled "x", not "*".
// NOTE(review): "x" also matches CNAME; telling the operator apart from a
// variable named x presumably relies on Lark's contextual LALR lexer.
!muldiv_op: "x" | "/"
// NOTE(review): SIGNED_NUMBER can carry its own sign too, so "-3" may
// lex as one number instead of sign_op + number; both evaluate to -3.0.
!sign_op: "+" | "-"
// Function call: the callee is parsed as an identifier and
// ASTBuilder.funcall keeps only its name. [args] may be absent.
?atom_expr: identifier "(" [args] ")" -> funcall
| atom
// TODO: I see some function calls(?) inside { } - what are these?
?atom: "(" expr ")"
| identifier
| number
// Comma-separated call arguments; a single argument is inlined by "?".
?args: expr ("," expr)*
// TODO: we need to know the escaping rules here...
// Variable: a bare CNAME, or any run of non-"}" characters inside braces.
?identifier: (CNAME | "{" /[^}]+/ "}") -> var
// ASTBuilder.const converts the number token to a float.
?number: SIGNED_NUMBER -> const
import sys
import lark
from enum import Enum, auto
from collections import namedtuple
class UnaryOp(Enum):
    """Prefix sign operators that can appear in a formula."""

    Pos = 1  # unary '+': the operand passes through unchanged
    Neg = 2  # unary '-': arithmetic negation
class BinaryOp(Enum):
    """Infix operators: arithmetic first, then comparisons."""

    # arithmetic
    Add = 1   # '+'
    Sub = 2   # '-'
    Mul = 3   # written 'x' in formulas
    Div = 4   # '/'
    # comparisons (evaluate to Python comparison results)
    Lt = 5    # '<'
    Gt = 6    # '>'
    LtEq = 7  # '<='
    GtEq = 8  # '>='
    Eq = 9    # '=' (equality test, not assignment)
class Constant(namedtuple('Constant', ['val'])):
    """Literal leaf of the formula AST.

    ``val`` is stored as given; the parser (ASTBuilder.const) stores a float.
    """

    __slots__ = ()

    def eval(self, var_dict):
        """Return the literal value.

        ``var_dict`` is accepted only so every node shares the same
        ``eval`` signature; it is ignored here.
        """
        return self.val

    def pprint(self):
        """Render the literal exactly as ``format()`` would (e.g. ``2.0``)."""
        return format(self.val, '')
class Var(namedtuple('Var', ['name'])):
    """Named-variable leaf of the formula AST."""

    __slots__ = ()

    def eval(self, var_dict):
        """Look this variable up in ``var_dict``.

        Raises KeyError when the name is not bound.
        """
        key = self.name
        return var_dict[key]

    def pprint(self):
        """Render the name wrapped in braces, e.g. ``{Oil Rate}``."""
        return '{{{}}}'.format(self.name)
class BinaryOpApply(namedtuple('BinaryOpApply', ['op', 'lhs', 'rhs'])):
    """AST node applying an infix operator to two sub-expressions.

    ``op`` is a ``BinaryOp`` member; ``lhs`` and ``rhs`` are AST nodes
    providing ``eval(var_dict)`` and ``pprint()``.
    """

    __slots__ = ()

    @staticmethod
    def _op_table():
        """Return {BinaryOp member: (implementation, surface symbol)}."""
        import operator
        return {
            BinaryOp.Add: (operator.add, '+'),
            BinaryOp.Sub: (operator.sub, '-'),
            BinaryOp.Mul: (operator.mul, 'x'),
            BinaryOp.Div: (operator.truediv, '/'),
            BinaryOp.Lt: (operator.lt, '<'),
            BinaryOp.Gt: (operator.gt, '>'),
            BinaryOp.LtEq: (operator.le, '<='),
            BinaryOp.GtEq: (operator.ge, '>='),
            BinaryOp.Eq: (operator.eq, '='),
        }

    def _lookup(self):
        """Return ``(implementation, symbol)`` for ``self.op``.

        Raises ValueError if ``op`` is not a known BinaryOp. (Previously
        this path raised NameError because ``ArgumentError`` is undefined.)
        """
        try:
            return self._op_table()[self.op]
        except (KeyError, TypeError):  # TypeError: op is unhashable
            raise ValueError(
                f'invalid binary operator: {self.op!r}') from None

    def eval(self, var_dict):
        """Evaluate both operands against ``var_dict``, then apply ``op``.

        Both sides are always evaluated (no short-circuiting), before the
        operator is checked.
        Raises ValueError if ``op`` is not a BinaryOp.
        """
        lhs_val = self.lhs.eval(var_dict)
        rhs_val = self.rhs.eval(var_dict)
        impl, _ = self._lookup()
        return impl(lhs_val, rhs_val)

    def pprint(self):
        """Render as a fully parenthesized infix expression, e.g. ``(a x b)``.

        Raises ValueError if ``op`` is not a BinaryOp.
        """
        lhs_pp = self.lhs.pprint()
        rhs_pp = self.rhs.pprint()
        _, symbol = self._lookup()
        return f'({lhs_pp} {symbol} {rhs_pp})'
class UnaryOpApply(namedtuple('UnaryOpApply', ['op', 'expr'])):
    """AST node applying a prefix sign operator to one sub-expression.

    ``op`` is a ``UnaryOp`` member; ``expr`` is an AST node providing
    ``eval(var_dict)`` and ``pprint()``.
    """

    __slots__ = ()

    def _invalid(self):
        """Build the error raised for an unknown ``op``.

        (The original raised the undefined ``ArgumentError``, i.e. NameError.)
        """
        return ValueError(f'invalid unary operator: {self.op!r}')

    def eval(self, var_dict):
        """Evaluate ``expr`` against ``var_dict`` and apply the sign.

        Raises ValueError if ``op`` is not a UnaryOp.
        """
        expr_val = self.expr.eval(var_dict)
        if self.op == UnaryOp.Pos:
            # Pass the operand through untouched rather than computing
            # ``+expr_val``, which could change its type (bool -> int).
            return expr_val
        if self.op == UnaryOp.Neg:
            return -expr_val
        raise self._invalid()

    def pprint(self):
        """Render as the sign followed by the operand, e.g. ``-{x}``.

        Raises ValueError if ``op`` is not a UnaryOp.
        """
        expr_pp = self.expr.pprint()
        if self.op == UnaryOp.Pos:
            return f'+{expr_pp}'
        if self.op == UnaryOp.Neg:
            return f'-{expr_pp}'
        raise self._invalid()
class FunCall(namedtuple('FunCall', ['fn', 'args'])):
    """AST node for a built-in function call such as ``If(c, a, b)``.

    ``fn`` is the function name (a string); ``args`` is a list of AST nodes.
    Only ``If`` and ``Abs`` are implemented so far.
    """

    __slots__ = ()

    def _expect_arity(self, n):
        """Raise ValueError unless exactly ``n`` arguments were supplied."""
        if len(self.args) != n:
            raise ValueError(
                f'{self.fn} expects {n} argument(s), got {len(self.args)}')

    def eval(self, var_dict):
        """Evaluate the call against ``var_dict``.

        Raises ValueError on a wrong argument count and NotImplementedError
        for any function name other than ``If`` and ``Abs``.
        """
        if self.fn == 'If':
            # If is lazy: only the branch selected by the condition is
            # evaluated, so errors in the untaken branch are never raised.
            self._expect_arity(3)
            cond, do_if, do_else = self.args
            if cond.eval(var_dict):
                return do_if.eval(var_dict)
            return do_else.eval(var_dict)
        if self.fn == 'Abs':
            self._expect_arity(1)
            val, = self.args
            return abs(val.eval(var_dict))
        # The original raised the undefined ``ArgumentError`` (a NameError).
        raise NotImplementedError(f'function {self.fn} not yet implemented!')

    def pprint(self):
        """Render as ``Name(arg1, arg2, ...)``."""
        args = ', '.join(a.pprint() for a in self.args)
        return f'{self.fn}({args})'
class ASTBuilder(lark.Transformer):
    """Lark transformer that turns the formula parse tree into AST nodes.

    Each public method is named after a grammar rule or alias and receives
    that rule's already-transformed children as the list ``args``.
    Unknown operator tokens raise ValueError.
    """

    # Operator token text -> operator enum, one table per operator rule.
    _SIGN_OPS = {'+': UnaryOp.Pos, '-': UnaryOp.Neg}
    _ADDSUB_OPS = {'+': BinaryOp.Add, '-': BinaryOp.Sub}
    _MULDIV_OPS = {'x': BinaryOp.Mul, '/': BinaryOp.Div}
    _COMP_OPS = {
        '<': BinaryOp.Lt,
        '>': BinaryOp.Gt,
        '<=': BinaryOp.LtEq,
        '>=': BinaryOp.GtEq,
        '=': BinaryOp.Eq,
    }

    @staticmethod
    def _lookup_op(table, args, kind):
        """Map the single operator token in ``args`` through ``table``.

        ``kind`` only names the operator family in the error message.
        Raises ValueError if the token text is not a key of ``table``
        (the original raised the undefined ``ArgumentError``).
        """
        token = str(args[0])  # lark Tokens are str subclasses
        try:
            return table[token]
        except KeyError:
            raise ValueError(f'invalid {kind} operator: {token!r}') from None

    @staticmethod
    def _fold_left(args):
        """Fold ``[e0, op1, e1, op2, e2, ...]`` into left-associative nodes.

        ``x + y + z`` becomes ``(x + y) + z``; a lone operand is returned
        unchanged.
        """
        ex, *rest = args
        pairs = iter(rest)
        # zip(pairs, pairs) walks the rest two at a time: (op1, e1), (op2, e2)...
        for op, next_ex in zip(pairs, pairs):
            ex = BinaryOpApply(op=op, lhs=ex, rhs=next_ex)
        return ex

    def var(self, args):
        """Build a Var from a CNAME token or the text between braces.

        The anonymous "{" and "}" tokens are dropped by lark, so the name
        token is always ``args[0]``.
        """
        return Var(name=args[0].value)

    def const(self, args):
        """Build a float Constant from the SIGNED_NUMBER token."""
        return Constant(float(args[0]))

    def sign_op(self, args):
        """Translate a unary '+'/'-' token into a UnaryOp."""
        return self._lookup_op(self._SIGN_OPS, args, 'unary')

    def addsub_op(self, args):
        """Translate a '+'/'-' token into a BinaryOp."""
        return self._lookup_op(self._ADDSUB_OPS, args, 'add/sub')

    def muldiv_op(self, args):
        """Translate an 'x'/'/' token into a BinaryOp."""
        return self._lookup_op(self._MULDIV_OPS, args, 'mul/div')

    def comp_op(self, args):
        """Translate a comparison token into a BinaryOp."""
        return self._lookup_op(self._COMP_OPS, args, 'comparison')

    def comp_expr(self, args):
        """Build an optional comparison.

        When ``[comp_op addsub_expr]`` is absent, lark (with placeholders
        enabled) supplies None for both parts, and the left operand is
        returned as-is.
        """
        lhs, op, rhs = args
        if op is None and rhs is None:
            return lhs
        return BinaryOpApply(op=op, lhs=lhs, rhs=rhs)

    def addsub_expr(self, args):
        """Build a left-associative chain of '+' and '-'."""
        return self._fold_left(args)

    def muldiv_expr(self, args):
        """Build a left-associative chain of 'x' and '/'."""
        return self._fold_left(args)

    def funcall(self, args):
        """Build a FunCall from the callee and its optional arguments."""
        fn, *rest = args
        # Depending on lark's maybe_placeholders setting, an empty [args]
        # arrives either as a trailing None or not at all.
        fnargs = rest[0] if rest else None
        if fnargs is None:
            fnargs = []
        elif not isinstance(fnargs, list):
            # ?args inlines a single argument, so it arrives as a bare AST
            # node -- a namedtuple, hence the list (not tuple) check.
            fnargs = [fnargs]
        # fn has already been transformed into a Var; functions only have
        # hard-coded names, so keep just the string.
        return FunCall(fn=fn.name, args=fnargs)

    def args(self, args):
        """Pass multiple call arguments through as a plain list."""
        return args
def main(argv):
    """Read formulas from stdin and print each one's AST, value and pprint.

    Each non-blank stdin line is parsed with the grammar in 'formula.lark'
    (looked up in the current working directory) and evaluated with an
    empty variable dict. Any formula that references a variable therefore
    raises KeyError.

    argv: command-line arguments (currently unused).
    Returns 0, used as the process exit status.
    """
    with open('formula.lark', encoding='utf-8') as f:
        parser = lark.Lark(f, parser='lalr', transformer=ASTBuilder())
    for line in sys.stdin:
        # A blank line (e.g. a trailing empty line) is not a valid formula
        # and would otherwise abort the whole run with a parse error.
        if not line.strip():
            continue
        tree = parser.parse(line)
        print(tree)
        print(tree.eval({}))
        print(tree.pprint())
    return 0
if __name__ == '__main__':
    # Script entry point: main()'s return value becomes the exit status.
    raise SystemExit(main(sys.argv))
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment