-- Toy calculator in Lua, version 3
-- Make every LPeg constructor (P, R, S, C, Ct, Cc, V, ...) reachable as a bare
-- name by installing LPeg as the fallback for global lookups in this chunk.
-- Prefer an already-present global `lpeg` (some hosts preload it); otherwise
-- load the module explicitly, since a plain `require` on Lua 5.2+ is not
-- guaranteed to create a global named `lpeg`.
local lpeg = lpeg or require("lpeg")
setmetatable(_ENV, { __index = lpeg })

-- Calculator variable store: variable name -> number or (nested) array table.
VARS = {}
-- Binary arithmetic operators understood by eval_expr, keyed by the operator
-- text the grammar captures with expr_op / term_op.
local ARITH_OPS = {
  ['+'] = function(a, b) return a + b end,
  ['-'] = function(a, b) return a - b end,
  ['*'] = function(a, b) return a * b end,
  ['/'] = function(a, b) return a / b end,
}

--- Evaluate an 'expr' or 'term' node by folding its operands left to right.
-- Node layout: { tag, operand, op, operand, op, operand, ... }; slot 1 is the
-- tag, so the first operand lives in slot 2.
-- Every operand is evaluated, even one paired with an unrecognised operator;
-- such an operator simply leaves the running total unchanged.
-- @param expr the AST node
-- @return the accumulated value
function eval_expr(expr)
  local total = eval(expr[2])
  local last = #expr
  for pos = 3, last, 2 do
    local apply = ARITH_OPS[expr[pos]]
    local rhs = eval(expr[pos + 1])
    if apply then
      total = apply(total, rhs)
    end
  end
  return total
end
-- Comparison operators understood by eval_bool, keyed by the operator text the
-- grammar captures with `boolean`. The calculator spells inequality '!=',
-- which maps onto Lua's '~='.
local COMPARISONS = {
  ['<']  = function(a, b) return a < b end,
  ['<='] = function(a, b) return a <= b end,
  ['>']  = function(a, b) return a > b end,
  ['>='] = function(a, b) return a >= b end,
  ['=='] = function(a, b) return a == b end,
  ['!='] = function(a, b) return a ~= b end,
}

--- Evaluate a 'bool' node laid out as { "bool", lhs, operator, rhs }.
-- Both operands are evaluated, left side first.
-- @param expr the AST node
-- @return true/false for a known operator; no value at all otherwise
function eval_bool(expr)
  local lhs = eval(expr[2])
  local rhs = eval(expr[4])
  local compare = COMPARISONS[expr[3]]
  if compare then
    return compare(lhs, rhs)
  end
end
-- Per-tag evaluators for AST tables. Each handler calls the other evaluators
-- through their global names at call time, so later redefinitions are seen.
-- Handlers for 'list', 'while' and a false 'if' deliberately return zero
-- values (not an explicit nil); keep that so results are unchanged when eval
-- is used as an LPeg function capture (`patt / eval`).
local EVAL_HANDLERS = {
  expr = function(node) return eval_expr(node) end,
  term = function(node) return eval_expr(node) end,
  -- Build a fresh Lua table from the evaluated elements held in slot 2.
  array = function(node)
    local items = {}
    for _, item in ipairs(node[2]) do
      table.insert(items, eval(item))
    end
    return items
  end,
  ref = function(node) return lookup(node) end,
  assign = function(node) return assign(node[2], eval(node[3])) end,
  -- Run each statement in order; produces no value.
  list = function(node)
    for i = 2, #node do
      eval(node[i])
    end
  end,
  ['if'] = function(node)
    if eval_bool(node[2]) then
      return eval(node[3])
    end
  end,
  ['while'] = function(node)
    while eval_bool(node[2]) do
      eval(node[3])
    end
  end,
}

--- Evaluate an AST node produced by the `stmt` grammar.
-- Numbers evaluate to themselves; tables are dispatched on their tag in
-- slot 1. Unknown tags produce no value.
-- @param ast a number or a tagged AST table
function eval(ast)
  if type(ast) == 'number' then
    return ast
  end
  local handler = EVAL_HANDLERS[ast[1]]
  if handler then
    return handler(ast)
  end
end
--- Store `value` at the location a 'ref' node designates, inside VARS.
-- A ref node is { "ref", name, index, index, ... }: slot 2 is the variable
-- name and every further slot is an index. Table entries are evaluated with
-- `eval`; anything else is used as the key directly.
-- @param ref the ref AST node
-- @param value the value to store
-- @return value
-- @raise a descriptive error when indexing through something that is not a
--   table (e.g. `d[1] = 5` with d undefined), or when an index evaluates to
--   nil or NaN (neither can be a Lua table key)
function assign(ref, value)
  local current = VARS
  local last = #ref
  for i = 2, last do
    local key = ref[i]
    if type(key) == 'table' then
      key = eval(key)
    end
    if type(current) ~= 'table' then
      error(string.format("cannot index a %s value while assigning to '%s'",
                          type(current), tostring(ref[2])), 0)
    end
    -- key ~= key is true only for NaN (e.g. an index of 0/0).
    if key == nil or key ~= key then
      error(string.format("invalid index (%s) while assigning to '%s'",
                          tostring(key), tostring(ref[2])), 0)
    end
    if i == last then -- last step: store the value
      current[key] = value
      return value
    end
    current = current[key] -- not the last: keep following the chain
  end
end
--- Read the value a 'ref' node designates, starting from VARS.
-- A ref node is { "ref", name, index, index, ... }: slot 2 is the variable
-- name and every further slot is an index. Table entries are evaluated with
-- `eval`; anything else is used as the key directly.
-- @param ref the ref AST node
-- @return the stored value, or nil if the variable / final element is unset
-- @raise a descriptive error when an index is applied to something that is
--   not a table (e.g. `x[1]` with x undefined, or indexing a number)
function lookup(ref)
  local current = VARS
  for i = 2, #ref do
    local key = ref[i]
    if type(key) == 'table' then
      key = eval(key)
    end
    if type(current) ~= 'table' then
      error(string.format("cannot index a %s value while looking up '%s'",
                          type(current), tostring(ref[2])), 0)
    end
    current = current[key]
  end
  return current
end
-- Tokens. Each token consumes the whitespace that follows it, so the grammar
-- only has to skip whitespace explicitly once, at the very start of `stmt`.
spc = S(" \t\n")^0
digit = R('09')
-- Optional sign, then digits with an optional fraction ("7", "7.", "7.25")
-- or a bare fraction (".5", "-.5"). At least one digit is required: a lone
-- "-" must not match, because tonumber("-") is nil and would leave a hole in
-- the AST that later crashes eval.
number = C( P("-")^-1 *
            ( digit^1 * ( P('.') * digit^0 )^-1 + P('.') * digit^1 ) )
         / tonumber * spc
lparen = "(" * spc
rparen = ")" * spc
lbrack = "[" * spc
rbrack = "]" * spc
lcurly = "{" * spc
rcurly = "}" * spc
comma = "," * spc
expr_op = C( S('+-') ) * spc
term_op = C( S('*/') ) * spc
letter = R('AZ','az')
name = C( letter * (digit+letter+"_")^0 ) * spc
-- Reserved words. The negative lookahead makes a keyword match only as a whole
-- word, so identifiers that merely start with one ("iffy", "while_n") are
-- still accepted by `name - keywords` below.
keywords = (P("if") + P("while")) * -(letter + digit + "_") * spc
name = name - keywords
-- Comparison operators. PEG choice is ordered and commits to the first
-- alternative that matches, so the two-character operators must be tried
-- before the single characters; otherwise "<=" is captured as "<" and the
-- following "=" makes the whole comparison fail.
boolean = C( P("<=") + ">=" + "!=" + "==" + S("<>") ) * spc
-- Statement grammar. Produces tagged AST tables:
--   {"expr"|"term", operand, op, operand, ...}, {"ref", name, index...},
--   {"assign", ref, value}, {"array", {items}}, {"bool", lhs, op, rhs},
--   {"list", stmt...}, {"if", bool, body}, {"while", bool, body}.
stmt = spc * P{
  "LIST";
  -- A single statement, or a braced block of ';'-separated statements.
  LIST =
    V("STMT") +
    Ct( Cc("list") *
        lcurly *
        V("STMT") * ( ";" * spc * V("STMT") )^0 *
        rcurly ),
  STMT =
    Ct( Cc("assign") * V("REF") * "=" * spc * V("VAL") ) +
    V("EXPR") +
    V("IF") +
    V("WHILE"),
  EXPR = Ct( Cc("expr") * V("TERM") * ( expr_op * V("TERM") )^0 ),
  TERM = Ct( Cc("term") * V("FACT") * ( term_op * V("FACT") )^0 ),
  -- A variable name followed by zero or more [index] suffixes.
  REF = Ct( Cc("ref") * name * (lbrack * V("EXPR") * rbrack)^0 ),
  FACT =
    number +
    lparen * V("EXPR") * rparen +
    V("REF"),
  ARRAY = Ct( Cc("array") * lbrack * Ct( V("VAL_LIST")^-1 ) * rbrack ),
  VAL_LIST = V("VAL") * (comma * V("VAL"))^0,
  VAL = V("EXPR") + V("ARRAY"),
  BOOL = Ct( Cc("bool") * V("EXPR") * boolean * V("EXPR") ),
  IF = Ct( C("if") * spc * lparen * V("BOOL") * rparen * V("LIST") ),
  WHILE = Ct( C("while") * spc * lparen * V("BOOL") * rparen * V("LIST") )
}
--- Self-test for the calculator.
-- @param stmt the raw statement grammar; it is wrapped with `/ eval` here so
--   each match yields the evaluated result rather than the AST.
-- Assigns variables in VARS and removes each one again once it has been checked.
function test(stmt)
  local calc = stmt / eval

  -- plain arithmetic, including precedence and parentheses
  local arithmetic = {
    { " 1 + 2 ", 3 },
    { "1+2+3+4+5", 15 },
    { "2*3*4 + 5*6*7", 234 },
    { " 1 * 2 + 3", 5 },
    { "( 2 +2) *6", 24 },
  }
  for _, case in ipairs(arithmetic) do
    assert(calc:match(case[1]) == case[2])
  end

  -- scalar variables
  calc:match("a=3")
  assert(VARS.a == 3)
  assert(calc:match("a") == 3)
  assert(calc:match("a * 5") == 15)
  VARS.a = nil

  -- flat array
  calc:match("a = [ 4, 5, 6 ]")
  for i, want in ipairs({ 4, 5, 6 }) do
    assert(VARS.a[i] == want)
  end
  VARS.a = nil

  -- empty array
  calc:match("b = [ ]")
  assert(VARS.b[1] == nil)
  VARS.b = nil

  -- nested arrays, computed indices and element assignment
  calc:match("c = [[1,2], [3,4]]")
  for row = 1, 2 do
    for col = 1, 2 do
      assert(VARS.c[row][col] == (row - 1) * 2 + col)
    end
  end
  assert(calc:match("c[4/2][1]") == 3)
  calc:match("c[3] = 5")
  assert(VARS.c[3] == 5)
  VARS.c = nil

  -- control flow
  calc:match("if(1 < 0) b = 5")
  assert(VARS.b ~= 5)
  VARS.n, VARS.x = 0, 1
  calc:match("while(n < 8) { x = x * 2; n = n + 1 }")
  assert(VARS.x == 256)
  VARS.n, VARS.x = nil, nil
end
--- Read-eval-print loop over the lines of a file.
-- @param file optional open file handle; defaults to the current default
--   input file (io.input()).
-- Each non-blank line is parsed with the `stmt` grammar and the resulting AST
-- is evaluated with `eval`; the result is printed. A line that does not parse
-- prints "syntax error". A runtime error (e.g. indexing an undefined array) is
-- reported as "error: ..." and the loop carries on with the next line.
function repl(file)
  file = file or io.input()
  for line in file:lines() do
    if line:find("%S") then -- skip blank lines
      local ast = stmt:match(line)
      if ast == nil then
        print("syntax error")
      else
        local ok, result = pcall(eval, ast)
        if ok then
          print(result)
        else
          print("error: " .. tostring(result))
        end
      end
    end
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment