Toy calculator in Lua, version 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Make LPeg's constructors (P, S, R, C, Ct, Cc, V, ...) usable as bare global
-- names: any global not found in _ENV falls back to the lpeg table.
-- Reuse an already-loaded `lpeg` global (e.g. from `lua -llpeg`) when there is
-- one, otherwise load the module, so the script also runs under a plain `lua`
-- invocation. Relies on _ENV, i.e. Lua 5.2 or later.
setmetatable(_ENV, { __index = lpeg or require("lpeg") })

-- Variable store shared by lookup and assign: maps variable names to numbers
-- or (possibly nested) array tables.
VARS = {}
-- Arithmetic operators that can appear between the operands of an 'expr'
-- node (+, -) or a 'term' node (*, /), keyed by their source text.
local ARITH = {
  ['+'] = function(lhs, rhs) return lhs + rhs end,
  ['-'] = function(lhs, rhs) return lhs - rhs end,
  ['*'] = function(lhs, rhs) return lhs * rhs end,
  ['/'] = function(lhs, rhs) return lhs / rhs end,
}

--- Evaluate an 'expr' or 'term' node.
-- The node is {tag, operand, op, operand, op, operand, ...}. Operands are
-- evaluated with eval and folded strictly left to right, so "8/2/2" is
-- (8/2)/2. With a single operand its value is returned unchanged (it may be
-- an array table, e.g. for a bare variable reference). An operator missing
-- from ARITH is skipped: its right operand is still evaluated but ignored.
-- @param expr the AST node
-- @return the accumulated value
function eval_expr(expr)
  local result = eval(expr[2])  -- expr[1] is the tag, not an operand
  local last = #expr
  local pos = 3
  while pos <= last do
    local op = expr[pos]
    local rhs = eval(expr[pos + 1])
    local apply = ARITH[op]
    if apply then
      result = apply(result, rhs)
    end
    pos = pos + 2
  end
  return result
end
-- Comparison operators understood by eval_bool, keyed by their source text.
local COMPARE = {
  ['<']  = function(a, b) return a < b end,
  ['<='] = function(a, b) return a <= b end,
  ['>']  = function(a, b) return a > b end,
  ['>='] = function(a, b) return a >= b end,
  ['=='] = function(a, b) return a == b end,
  ['!='] = function(a, b) return a ~= b end,
}

--- Evaluate a 'bool' node of the form {"bool", lhs, op, rhs}.
-- Both operands are evaluated with eval (left first) before the operator is
-- looked up.
-- @param expr the 'bool' AST node
-- @return the comparison result; no value at all when op is not one of the
--   six operators in COMPARE
function eval_bool(expr)
  local lhs = eval(expr[2])
  local op = expr[3]
  local rhs = eval(expr[4])
  local cmp = COMPARE[op]
  if cmp then
    return cmp(lhs, rhs)
  end
end
-- Node handlers for eval, keyed by an AST node's tag (its first element).
-- Each handler calls eval/eval_expr/eval_bool/lookup/assign through their
-- global names at call time, so definitions made later in the file are used.
local eval_handlers = {}

eval_handlers.expr = function(node) return eval_expr(node) end
eval_handlers.term = eval_handlers.expr

-- {"array", {elem, ...}} -> a fresh table of evaluated elements.
eval_handlers.array = function(node)
  local items = {}
  for _, element in ipairs(node[2]) do
    table.insert(items, eval(element))
  end
  return items
end

eval_handlers.ref = function(node) return lookup(node) end

-- {"assign", ref, value_ast} -> stores and returns the evaluated value.
eval_handlers.assign = function(node)
  return assign(node[2], eval(node[3]))
end

-- {"list", stmt, stmt, ...} -> runs every statement, produces no value.
eval_handlers.list = function(node)
  for idx = 2, #node do
    eval(node[idx])
  end
end

-- {"if", bool, body} -> the body's value when the condition holds.
eval_handlers['if'] = function(node)
  if eval_bool(node[2]) then
    return eval(node[3])
  end
end

-- {"while", bool, body} -> re-runs body while the condition holds.
eval_handlers['while'] = function(node)
  while eval_bool(node[2]) do
    eval(node[3])
  end
end

--- Evaluate an AST node produced by the grammar.
-- Numbers evaluate to themselves; tables are dispatched on their tag.
-- 'list', 'while', an 'if' whose condition is false, and unknown tags return
-- no value at all (not even nil).
-- @param ast a number or an AST table
-- @return the node's value, if it has one
function eval(ast)
  if type(ast) == 'number' then
    return ast
  end
  local handler = eval_handlers[ast[1]]
  if handler then
    return handler(ast)
  end
end
--- Store a value at the location named by a 'ref' AST node.
-- A ref node is {"ref", name, index, ...}: the name is looked up in VARS and
-- each following index (a literal, or an AST node evaluated with eval) steps
-- one level deeper; the final name/index is the slot that receives `value`.
-- @param ref the 'ref' AST node
-- @param value the value to store
-- @return value (nothing if ref carries no name)
-- @raise a readable error when an intermediate step is not a table, e.g.
--   `d[1] = 5` with d undefined, instead of Lua's raw "attempt to index" error
function assign(ref, value)
  local current = VARS
  local last = #ref
  for i = 2, last do
    local key = ref[i]
    if type(key) == 'table' then
      key = eval(key)  -- computed index such as c[4/2]
    end
    if type(current) ~= 'table' then
      error(("cannot index a %s value with key '%s'"):format(
        type(current), tostring(key)), 0)
    end
    if i == last then  -- final step: this is the slot to write
      current[key] = value
      return value
    end
    current = current[key]  -- intermediate step: follow the chain
  end
end
--- Read the value named by a 'ref' AST node.
-- A ref node is {"ref", name, index, ...}: the name is looked up in VARS and
-- each following index (a literal, or an AST node evaluated with eval) steps
-- one level deeper. An undefined plain variable reads as nil.
-- @param ref the 'ref' AST node
-- @return the referenced value (nil if the last step is missing)
-- @raise a readable error when an index is applied to something that is not
--   a table, e.g. `x[1]` with x undefined or x a number
function lookup(ref)
  local current = VARS
  for i = 2, #ref do
    local key = ref[i]
    if type(key) == 'table' then
      key = eval(key)  -- computed index such as c[4/2]
    end
    if type(current) ~= 'table' then
      error(("cannot index a %s value with key '%s'"):format(
        type(current), tostring(key)), 0)
    end
    current = current[key]
  end
  return current
end
-- Lexical patterns. The token patterns (number, brackets, operators, name,
-- boolean) also consume the whitespace that follows them (`* spc`), so the
-- grammar rules never skip spaces themselves.
spc = S(" \t\n")^0
digit = R('09')
-- Optional leading minus, then either digits with an optional fractional
-- part ("5", "5.", "5.25") or a bare fraction (".5"). A lone "-" or "." is
-- not a number; accepting "-" here used to yield tonumber("-") == nil.
number = C( P("-")^-1 *
            ( digit^1 * ( P('.') * digit^0 )^-1 + P('.') * digit^1 ) ) / tonumber * spc
lparen = "(" * spc
rparen = ")" * spc
lbrack = "[" * spc
rbrack = "]" * spc
lcurly = "{" * spc
rcurly = "}" * spc
comma = "," * spc
expr_op = C( S('+-') ) * spc
term_op = C( S('*/') ) * spc
letter = R('AZ','az')
name = C( letter * (digit+letter+"_")^0 ) * spc
-- A keyword only counts as a whole word, so identifiers that merely start
-- with one ("iffy", "whileCount") remain valid names.
keywords = (P("if")+P("while")) * -(letter + digit + "_")
name = name - keywords
-- Two-character operators must come before S("<>"): with ordered choice the
-- single "<" would otherwise match the start of "<=" and leave a stray "=".
boolean = C( P("<=") + ">=" + "!=" + "==" + S("<>") ) * spc
-- Statement grammar, built as an LPeg grammar table. Every rule that builds
-- an AST node wraps its captures in Ct and starts with a tag, so nodes look
-- like {"expr", ...}, {"ref", ...}, {"if", ...}; eval dispatches on that tag.
-- Leading whitespace is skipped once up front. The grammar has no
-- end-of-input check, so matching stops after the first complete LIST.
do
  -- Non-terminal references, named after the rules they point to.
  local LIST, STMT, EXPR, TERM = V"LIST", V"STMT", V"EXPR", V"TERM"
  local FACT, REF, VAL, VAL_LIST = V"FACT", V"REF", V"VAL", V"VAL_LIST"
  local ARRAY, BOOL, IF, WHILE = V"ARRAY", V"BOOL", V"IF", V"WHILE"

  local rules = { "LIST" }  -- [1] names the start rule

  -- A single statement, or a brace-enclosed ";"-separated block of them.
  rules.LIST = STMT
    + Ct(Cc"list" * lcurly * STMT * (";" * spc * STMT)^0 * rcurly)

  -- Assignment is tried first: otherwise "a = 1" would match as the bare
  -- expression "a" and the rest would be left unparsed.
  rules.STMT = Ct(Cc"assign" * REF * "=" * spc * VAL) + EXPR + IF + WHILE

  -- Additive level over multiplicative level; operators stay in the node.
  rules.EXPR = Ct(Cc"expr" * TERM * (expr_op * TERM)^0)
  rules.TERM = Ct(Cc"term" * FACT * (term_op * FACT)^0)
  rules.FACT = number + lparen * EXPR * rparen + REF

  -- A variable name followed by any number of [index] suffixes.
  rules.REF = Ct(Cc"ref" * name * (lbrack * EXPR * rbrack)^0)

  -- Array literals keep their elements in a nested table: {"array", {...}}.
  rules.VAL = EXPR + ARRAY
  rules.VAL_LIST = VAL * (comma * VAL)^0
  rules.ARRAY = Ct(Cc"array" * lbrack * Ct(VAL_LIST^-1) * rbrack)

  -- Conditions and control flow; the keyword itself is captured as the tag.
  rules.BOOL = Ct(Cc"bool" * EXPR * boolean * EXPR)
  rules.IF = Ct(C"if" * spc * lparen * BOOL * rparen * LIST)
  rules.WHILE = Ct(C"while" * spc * lparen * BOOL * rparen * LIST)

  stmt = spc * P(rules)
end
--- Self-test for the calculator.
-- Runs a fixed series of programs through `stmt / eval`, using the statement
-- pattern passed in (e.g. the global `stmt`), and asserts on the returned
-- values and on VARS. Stops with an assertion error at the first mismatch.
-- The cases share VARS, so their order matters; each group removes the
-- variables it created.
-- @param stmt the statement pattern to test
function test(stmt)
  local run = stmt / eval

  -- Pure arithmetic: precedence, associativity, parentheses, whitespace.
  local arithmetic = {
    { " 1 + 2 ", 3 },
    { "1+2+3+4+5", 15 },
    { "2*3*4 + 5*6*7", 234 },
    { " 1 * 2 + 3", 5 },
    { "( 2 +2) *6", 24 },
  }
  for _, case in ipairs(arithmetic) do
    assert(run:match(case[1]) == case[2])
  end

  -- Scalar variables: assignment, reading, use inside an expression.
  run:match("a=3")
  assert(VARS.a == 3)
  assert(run:match("a") == 3)
  assert(run:match("a * 5") == 15)
  VARS.a = nil

  -- Array literals, including the empty one.
  run:match("a = [ 4, 5, 6 ]");
  for idx, want in ipairs({ 4, 5, 6 }) do
    assert(VARS.a[idx] == want)
  end
  VARS.a = nil
  run:match("b = [ ]");
  assert(VARS.b[1] == nil)
  VARS.b = nil

  -- Nested arrays: construction, computed indexing, element assignment.
  run:match("c = [[1,2], [3,4]]")
  for row, values in ipairs({ { 1, 2 }, { 3, 4 } }) do
    for col, want in ipairs(values) do
      assert(VARS.c[row][col] == want)
    end
  end
  assert(run:match("c[4/2][1]") == 3)
  run:match("c[3] = 5")
  assert(VARS.c[3] == 5)
  VARS.c = nil

  -- Control flow: a false `if` skips its body; `while` repeats a block.
  run:match("if(1 < 0) b = 5")
  assert(VARS.b ~= 5)
  VARS.n = 0
  VARS.x = 1
  run:match("while(n < 8) { x = x * 2; n = n + 1 }")
  assert(VARS.x == 256)
  VARS.n = nil
  VARS.x = nil
end
--- Read-eval-print loop: parse and evaluate each input line, printing the
-- values it produces.
-- Lines that do not parse print nil (LPeg's match result on failure). An
-- error raised while evaluating a line is reported on stderr and the loop
-- moves on to the next line instead of aborting.
-- @param file optional file handle to read lines from; defaults to io.input()
function repl(file)
  file = file or io.input()
  -- Attach eval: the bare `stmt` pattern only yields the AST, so printing its
  -- match showed "table: 0x..." instead of a result. Kept local so no
  -- `parser` global leaks out.
  local parser = stmt / eval
  for line in file:lines() do
    local results = table.pack(pcall(function() return parser:match(line) end))
    if results[1] then
      print(table.unpack(results, 2, results.n))
    else
      io.stderr:write("error: ", tostring(results[2]), "\n")
    end
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment