-- metalua/treequery/walk.mlua
-- Load the metalua `match' syntax extension; the traversal functions below
-- are written with `match ... with'.
-{ extension "match" }

-- Module table.
--  * traverse: functions that go through the children of a node;
--  * tags:     sets of known statement / expression tags;
--  * debug:    when true, the traversers print a trace of visited nodes.
local M = {
   traverse = { },
   tags     = { },
   debug    = false,
}
-------------------------------------------------------------------------------- | |
-- Standard tags: can be used to guess the type of an AST, or to check | |
-- that the type of an AST is respected. | |
-------------------------------------------------------------------------------- | |
-- Tags of statement nodes. `Call' and `Invoke' are listed both here and in
-- the expression tags below.
local stat_tags = {
   'Do', 'Set', 'While', 'Repeat', 'Local', 'Localrec', 'Return',
   'Fornum', 'Forin', 'If', 'Break', 'Goto', 'Label',
   'Call', 'Invoke',
}

-- Tags of expression nodes.
local expr_tags = {
   'Paren', 'Call', 'Invoke', 'Index', 'Op', 'Function', 'Stat',
   'Table', 'Nil', 'Dots', 'True', 'False', 'Number', 'String', 'Id',
}

-- The lists are passed through table.transpose so that the rest of this
-- file can test membership with `M.tags.stat[tag]' / `M.tags.expr[tag]'.
M.tags.stat = table.transpose (stat_tags)
M.tags.expr = table.transpose (expr_tags)
-------------------------------------------------------------------------------- | |
-- These [M.traverse.xxx()] functions are in charge of actually going through | |
-- ASTs. At each node, they make sure to call the appropriate walker. | |
-------------------------------------------------------------------------------- | |
-- Go through the children of statement node `x', calling the appropriate
-- walker (M.block, M.expr, M.expr_list, M.binder_list) on each of them.
--  * cfg: walker configuration; cfg.scope is saved/restored around every
--         construct that opens a new lexical scope.
--  * x:   the statement node, or an untagged list of statements.
--  * ...: the ancestors of x. Children are walked with x prepended to
--         this list.
-- Nodes with a known statement tag but an unexpected shape are reported
-- through M.malformed; anything else through M.unknown.
function M.traverse.stat (cfg, x, ...)
   if M.debug then printf("traverse stat %s", table.tostring(x)) end
   local ancestors = {...}
   local B  = |y| M.block       (cfg, y, x, unpack(ancestors)) -- Block
   local E  = |y| M.expr        (cfg, y, x, unpack(ancestors)) -- Expression
   local EL = |y| M.expr_list   (cfg, y, x, unpack(ancestors)) -- Expression List
   local IL = |y| M.binder_list (cfg, y, x, unpack(ancestors)) -- Id binders List
   local function OS() cfg.scope :save()    end                -- Open scope
   local function CS() cfg.scope :restore() end                -- Close scope
   local function BS(y) OS(); B(y); CS() end                   -- Block with Scope
   match x with
   | {...} if x.tag == nil -> for y in ivalues(x) do M.stat(cfg, y, ...) end
                          -- no tag --> node not inserted in the history ancestors
   | `Do{...}                    -> BS(x)
   | `Set{ lhs, rhs }            -> EL(lhs); EL(rhs)
   | `While{ cond, body }        -> E(cond); BS(body)
   -- The `until' condition is walked before the scope is closed, so it is
   -- resolved against the bindings made in the body.
   | `Repeat{ body, cond }       -> OS(); B(body); E(cond); CS()
   | `Local{ lhs }               -> IL(lhs)
   -- Right-hand side first: it is walked before the new binders are set.
   | `Local{ lhs, rhs }          -> EL(rhs); IL(lhs)
   -- Binders first: the right-hand side is walked with them already set.
   | `Localrec{ lhs, rhs }       -> IL(lhs); EL(rhs)
   -- Loop variables are local to the loop body: the bounds / iterator
   -- expressions are walked in the enclosing scope, then the variables are
   -- bound inside the scope opened for the body, so that they are
   -- discarded when the loop ends.
   | `Fornum{ i, a, b, body }    -> E(a); E(b); OS(); IL{i}; B(body); CS()
   | `Fornum{ i, a, b, c, body } -> E(a); E(b); E(c); OS(); IL{i}; B(body); CS()
   | `Forin{ i, rhs, body }      -> EL(rhs); OS(); IL(i); B(body); CS()
   -- Children come as (condition, block) pairs; an odd child count means
   -- the last child is a block without a condition.
   | `If{...}                    -> for i=1, #x-1, 2 do E(x[i]); BS(x[i+1]) end
                                    if #x%2 == 1 then BS(x[#x]) end
   | `Call{...}|`Invoke{...}|`Return{...} -> EL(x)
   | `Break | `Goto{ _ } | `Label{ _ } -> -- nothing
   | { tag=tag, ...} if M.tags.stat[tag]->
      M.malformed (cfg, x, unpack (ancestors))
   | _ ->
      M.unknown (cfg, x, unpack (ancestors))
   end
end
-- Go through the children of expression node `x', calling the appropriate
-- walker (M.block, M.expr, M.expr_list, M.binder_list) on each of them.
--  * cfg: walker configuration; cfg.scope is saved/restored around
--         function bodies and `Stat{ } blocks.
--  * x:   the expression node.
--  * ...: the ancestors of x. Children are walked with x prepended to
--         this list.
-- `Id{ } nodes are reported through M.occurrence. Nodes with a known
-- expression tag but an unexpected shape are reported through M.malformed;
-- anything else through M.unknown.
function M.traverse.expr (cfg, x, ...)
   if M.debug then printf("traverse expr %s", table.tostring(x)) end
   local ancestors = {...}
   local B  = |y| M.block       (cfg, y, x, unpack(ancestors)) -- Block
   local E  = |y| M.expr        (cfg, y, x, unpack(ancestors)) -- Expression
   local EL = |y| M.expr_list   (cfg, y, x, unpack(ancestors)) -- Expression List
   local IL = |y| M.binder_list (cfg, y, x, unpack(ancestors)) -- Id binders list
   local function OS() cfg.scope :save()    end                -- Open scope
   local function CS() cfg.scope :restore() end                -- Close scope
   match x with
   | `Paren{ e }               -> E(e)
   | `Call{...} | `Invoke{...} -> EL(x)
   | `Index{ a, b }            -> E(a); E(b)
   -- x[1] is the operator name; x[3] is only walked when #x == 3.
   | `Op{ opid, ... }          -> E(x[2]); if #x==3 then E(x[3]) end
   -- Parameters are bound inside the scope opened for the body.
   | `Function{ params, body } -> OS(); IL(params); B(body); CS()
   -- The result expression is walked before the scope is closed, so it is
   -- resolved against the bindings made in the block.
   | `Stat{ b, e }             -> OS(); B(b); E(e); CS()
   | `Id{ name }               -> M.occurrence(cfg, x, unpack(ancestors))
   | `Table{ ... }             ->
      for i = 1, #x do match x[i] with
         | `Pair{ k, v } -> E(k); E(v)
         | v             -> E(v)
      end end
   | `Nil|`Dots|`True|`False|`Number{_}|`String{_} -> -- terminal node
   | { tag=tag, ...} if M.tags.expr[tag]-> M.malformed (cfg, x, unpack (ancestors))
   | _ -> M.unknown (cfg, x, unpack (ancestors))
   end
end
-- Walk every statement of block `x' with M.stat. The block itself is
-- prepended to the ancestors (`...') handed to each statement.
function M.traverse.block (cfg, x, ...)
   local is_table = type(x) == 'table'
   assert (is_table, "traverse.block() expects a table")
   for stat_node in ivalues(x) do
      M.stat (cfg, stat_node, x, ...)
   end
end
-- Walk every expression of list `x' with M.expr. Unlike blocks, the list
-- is not added to the ancestors: each expression receives `...' unchanged.
function M.traverse.expr_list (cfg, x, ...)
   local is_table = type(x) == 'table'
   assert (is_table, "traverse.expr_list() expects a table")
   for expr_node in ivalues(x) do
      M.expr (cfg, expr_node, ...)
   end
end
-- Raise an error naming the tag of node `x' (or '(nil)' if it has none).
-- `cfg' and the ancestors are accepted but not used.
function M.malformed(cfg, x, ...)
   local tag_name = x.tag or '(nil)'
   error ("Malformed node of tag " .. tag_name)
end
-- Raise an error naming the tag of node `x' (or '(nil)' if it has none).
-- `cfg' and the ancestors are accepted but not used.
function M.unknown(cfg, x, ...)
   local tag_name = x.tag or '(nil)'
   error ("Unknown node tag " .. tag_name)
end
-- Report an identifier occurrence. If cfg.occurrence is set, it is called
-- with what the current scope holds for the identifier's name (x[1]),
-- followed by the node and its ancestors. Does nothing otherwise.
function M.occurrence(cfg, x, ...)
   local on_occurrence = cfg.occurrence
   if not on_occurrence then return end
   on_occurrence (cfg.scope :get(x[1]), x, ...)
end
-- Register every identifier of `id_list' as a binder in the current scope.
-- For each `Id node, the scope entry for its name is set to a list made of
-- the node followed by the ancestors (`...'), then cfg.binder, if any, is
-- called with the same values. The only other node accepted is a `Dots
-- node in last position; anything else fails the assertion.
function M.binder_list (cfg, id_list, ...)
   local on_binder = cfg.binder
   for pos, binder in ipairs(id_list) do
      if binder.tag ~= 'Id' then
         assert (pos==#id_list and binder.tag=='Dots', "Invalid binders list")
      else
         cfg.scope :set (binder[1], { binder, ... })
         if on_binder then on_binder (binder, ...) end
      end
   end
end
----------------------------------------------------------------------
-- Generic walker generator.
-- `traverse' is the name of an entry of M.traverse. The walker returned:
--  * gives `cfg' a fresh scope (M.newscope) if it doesn't have one yet;
--  * calls cfg.down, if any, with the node and its ancestors;
--  * unless cfg.down returned 'break', calls M.traverse[traverse];
--  * calls cfg.up, if any, with the node and its ancestors.
-- M.traverse[traverse] is looked up at each call, not when the walker
-- is built.
----------------------------------------------------------------------
local function walker_builder (traverse)
   return function (cfg, ...)
      if not cfg.scope then cfg.scope = M.newscope() end
      local on_down, on_up = cfg.down, cfg.up
      local skip_children = false
      if on_down then
         skip_children = (on_down(...) == 'break')
      end
      if not skip_children then
         M.traverse[traverse] (cfg, ...)
      end
      if on_up then on_up(...) end
   end
end
---------------------------------------------------------------------- | |
-- Declare [M.stat], [M.expr], [M.block] and [M.expr_list] | |
---------------------------------------------------------------------- | |
-- Only the kinds that have an entry in M.traverse get a generated walker.
-- "malformed" and "unknown" have no such entry: wrapping them would
-- replace the error-reporting functions M.malformed / M.unknown with
-- walkers that try to call a nil traverser, hiding the real error message.
for w in values{ "stat", "expr", "block" } do
   M[w] = walker_builder (w)
end
-- Don't call up/down callbacks on expr lists
M.expr_list = M.traverse.expr_list
---------------------------------------------------------------------- | |
-- Try to guess the type of the AST, then choose the right walker.
---------------------------------------------------------------------- | |
-- Pick a walker from the tag of `x' and return its result:
--  * a tag listed in M.tags.expr -> M.expr (tried first, so tags that are
--    in both sets are walked as expressions);
--  * a tag listed in M.tags.stat -> M.stat;
--  * no tag                      -> M.block.
-- Any other tag raises an error; so does a non-table `x'.
function M.guess (cfg, x, ...)
   assert(type(x)=='table', "arg #2 in a walker must be an AST")
   local tag = x.tag
   if M.tags.expr[tag] then
      return M.expr(cfg, x, ...)
   elseif M.tags.stat[tag] then
      return M.stat(cfg, x, ...)
   elseif not tag then
      return M.block(cfg, x, ...)
   end
   error ("Can't guess the AST type from tag "..(tag or '<none>'))
end
-- Scope objects: `current' maps variable names to whatever was stored with
-- :set(); `stack' holds the snapshots pushed by :save().
local S = { }
S.__index = S

-- Create a scope with an empty `current' table. The stack initially
-- contains that same table.
function M.newscope()
   local bindings = { }
   local instance = { current = bindings, stack = { bindings } }
   return setmetatable (instance, S)
end

-- Push a shallow copy of the current bindings on the stack.
-- NOTE(review): when called with arguments, this forwards them to
-- self:add(), but no `add' method is defined in this part of the file;
-- confirm where it is supposed to come from.
function S.save (self, ...)
   local snapshot = table.shallow_copy (self.current)
   table.insert (self.stack, snapshot)
   if ... then return self :add(...) end
end

-- Pop the last snapshot from the stack and make it the current bindings.
function S.restore (self)
   self.current = table.remove (self.stack)
end

-- Return what is currently stored under `var_name' (nil if nothing).
function S.get (self, var_name)
   return self.current[var_name]
end

-- Store `val' under `key' in the current bindings.
function S.set (self, key, val)
   self.current[key] = val
end

return M
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment