Skip to content

Instantly share code, notes, and snippets.

@findstr
Last active September 27, 2019 01:29
Show Gist options
  • Save findstr/2ddd1f434fb752e8ed65c9cb79b2f6c4 to your computer and use it in GitHub Desktop.
Save findstr/2ddd1f434fb752e8ed65c9cb79b2f6c4 to your computer and use it in GitHub Desktop.
Trie for Lua
-- `concat` caches table.concat for assembling match output; `tree` is the
-- root of the nested-table trie filled by tree_add (numeric byte keys map to
-- child nodes, `isleaf = true` marks the end of an inserted pattern).
local concat, tree = table.concat, {}
--- Insert `str` into the nested-table trie rooted at `tree`.
-- Each byte of `str` becomes a child table keyed by its numeric byte value;
-- the node reached by the final byte is flagged with `isleaf = true`.
-- Empty strings are ignored: they would set `isleaf` on the root itself,
-- which tree_route never consults, and would put a non-byte key among the
-- root's children.
-- @tparam string str pattern to insert
local function tree_add(str)
	if #str == 0 then
		return
	end
	local node = tree
	for i = 1, #str do
		local byte = str:byte(i)
		local child = node[byte]
		if not child then
			child = {}
			node[byte] = child
		end
		node = child
	end
	node.isleaf = true
end
--- Find the longest pattern in `tree` that begins at byte offset `i` of `str`.
-- Walks the trie one byte at a time and remembers the last position where
-- a node flagged `isleaf` was reached.
-- @tparam string str text being scanned
-- @tparam number i starting byte offset
-- @treturn number byte offset of the last character of the longest match,
--   or -1 when no pattern starts at `i`
local function tree_route(str, i)
	local last_leaf = -1
	local node = tree
	for pos = i, #str do
		node = node[str:byte(pos)]
		if not node then
			break
		end
		if node.isleaf then
			last_leaf = pos
		end
	end
	return last_leaf
end
--- Rewrite `str`, replacing each non-overlapping trie match.
-- Scans left to right; at each position the longest pattern reported by
-- tree_route is replaced and scanning resumes right after it.
-- @tparam string str input text
-- @tparam function func called as func(str, start, stop) with the inclusive
--   byte range of a match; its result is spliced in place of the match
--   (a nil result simply drops the matched text)
-- @treturn string the rewritten text, or `str` itself when nothing matched
local function tree_match(str, func)
	local buf
	local last = 1
	local len = #str
	local i = 1
	while i <= len do
		local start = i
		local stop = tree_route(str, start)
		if stop >= start then
			buf = buf or {}
			buf[#buf + 1] = str:sub(last, start - 1)
			buf[#buf + 1] = func(str, start, stop)
			last = stop + 1
			i = stop + 1
		else
			i = start + 1
		end
	end
	if not buf then
		return str
	end
	buf[#buf + 1] = str:sub(last)
	return concat(buf)
end
--- Record every non-overlapping trie match in `str`.
-- @tparam string str input text
-- @tparam table ban receives ban[start] = stop for each match, where
--   start..stop is the inclusive byte range of the match
-- @treturn number number of matches recorded
local function tree_detect(str, ban)
	local hits = 0
	local pos = 1
	while pos <= #str do
		local stop = tree_route(str, pos)
		if stop < pos then
			pos = pos + 1
		else
			hits = hits + 1
			ban[pos] = stop
			pos = stop + 1
		end
	end
	return hits
end
-- Double-array trie storage, filled by dat_build.
-- base[s]: offset added to a byte value to reach the child states of s.
-- check[s]: parent state of s, negated when s ends a pattern (root states
--   use their own byte value as parent).
local base, check = {}, {}
-- Metatables dat_try_add uses to overlay staged placements on base/check;
-- the `false` __index is a placeholder replaced before each use.
local check_mt, base_mt = {__index = false}, {__index = false}
--- Shallow-copy every key/value pair of `src` into `dst`, overwriting
-- existing keys.
-- @tparam table dst destination table (mutated)
-- @tparam table src source table
local function merge(dst, src)
	for key, value in pairs(src) do dst[key] = value end
end
--- Return the first index >= `n` whose slot in `slots` is nil or false.
-- @tparam table slots table to probe (lookups go through its metatable)
-- @tparam number n first index to try
-- @treturn number the first free index
local function find_empty(slots, n)
	local idx = n - 1
	repeat
		idx = idx + 1
	until not slots[idx]
	return idx
end
--- Try to place every child of state `p` at base[p] + byte.
-- Candidate slots are staged in temporary tables, which stay visible
-- through base/check (via base_mt/check_mt) while placement is tested.
-- @tparam number p parent state index (base[p] must already be set)
-- @tparam table c child map keyed only by numeric byte values (an `isleaf`
--   key would break the `b + ck` arithmetic)
-- @treturn number|nil nil when every child fits, in which case the
--   placements are committed to base/check; otherwise a positive amount by
--   which base[p] should be increased before retrying (nothing is committed)
local function dat_try_add(p, c)
	local base_tmp = {}
	local check_tmp = {}
	base_mt.__index = base_tmp
	check_mt.__index = check_tmp
	setmetatable(base, base_mt)
	setmetatable(check, check_mt)
	local b = base[p]
	local shift
	for ck in pairs(c) do
		local s = b + ck
		-- A slot is usable only if neither check nor base (committed or
		-- staged) already occupies it.
		local ss = find_empty(check, s)
		if ss == s then
			ss = find_empty(base, s)
		end
		if ss > s then
			shift = ss - s
			break
		end
		check_tmp[s] = p
		base_tmp[s] = 1
	end
	-- Detach the overlays on every path so later reads of base/check only
	-- see committed entries and no stale temp table stays reachable.
	setmetatable(base, nil)
	setmetatable(check, nil)
	if shift then
		return shift
	end
	merge(base, base_tmp)
	merge(check, check_tmp)
end
--- Place all children of trie node `c` (state `p`) into base/check.
-- base[p] is bumped by the offsets dat_try_add reports until every child
-- fits. If `c` ends a pattern, check[p] is negated to mark state `p` as a leaf.
-- @tparam number p state index of the node
-- @tparam table c nested-table trie node (byte keys plus optional `isleaf`)
-- @treturn table map from each child's new state index to its trie node
local function dat_add(p, c)
	-- Hide the leaf flag so dat_try_add only sees byte-keyed children.
	local isleaf = c.isleaf
	c.isleaf = nil
	local b = base[p]
	local off = dat_try_add(p, c)
	while off do
		b = b + off
		base[p] = b
		off = dat_try_add(p, c)
	end
	local children = {}
	for byte, node in pairs(c) do
		assert(byte > 0)
		children[b + byte] = node
	end
	c.isleaf = isleaf
	if isleaf then
		local parent = check[p]
		assert(parent >= 0, parent)
		check[p] = -parent
	end
	return children
end
--- Find the longest pattern in the double-array trie that begins at byte
-- offset `off` of `str`.
-- The first byte is itself the root state index. Each following byte moves
-- to base[state] + byte, which is valid only if check[] of that slot names
-- the current state. A negated check value marks a state that ends a pattern.
-- @tparam string str text being scanned
-- @tparam number off starting byte offset
-- @treturn number byte offset of the last character of the longest match,
--   or -1 when no pattern starts at `off`
local function dat_route(str, off)
	local state = str:byte(off)
	local b = base[state]
	if not b then
		return -1
	end
	local best = -1
	local owner = check[state]
	if owner == -state then
		best = off
	elseif owner ~= state then
		return best
	end
	for pos = off + 1, #str do
		local nxt = b + str:byte(pos)
		local parent = check[nxt]
		if parent == -state then
			best = pos
		elseif parent ~= state then
			break
		end
		state = nxt
		b = base[state]
	end
	return best
end
--- Rewrite `str`, replacing each non-overlapping match of the
-- double-array trie (base/check).
-- Scans left to right; at each position the longest pattern reported by
-- dat_route is replaced and scanning resumes right after it.
-- @tparam string str input text
-- @tparam function func called as func(str, start, stop) with the inclusive
--   byte range of a match; its result is spliced in place of the match
--   (a nil result simply drops the matched text)
-- @treturn string the rewritten text, or `str` itself when nothing matched
local function dat_match(str, func)
	local buf
	local last = 1
	local len = #str
	local i = 1
	while i <= len do
		local start = i
		local stop = dat_route(str, start)
		if stop >= start then
			buf = buf or {}
			buf[#buf + 1] = str:sub(last, start - 1)
			buf[#buf + 1] = func(str, start, stop)
			last = stop + 1
			i = stop + 1
		else
			i = start + 1
		end
	end
	if not buf then
		return str
	end
	buf[#buf + 1] = str:sub(last)
	return concat(buf)
end
--- Record every non-overlapping double-array trie match in `str`.
-- @tparam string str input text
-- @tparam table ban receives ban[start] = stop for each match, where
--   start..stop is the inclusive byte range of the match
-- @treturn number number of matches recorded
local function dat_detect(str, ban)
	local hits = 0
	local pos = 1
	while pos <= #str do
		local stop = dat_route(str, pos)
		if stop < pos then
			pos = pos + 1
		else
			hits = hits + 1
			ban[pos] = stop
			pos = stop + 1
		end
	end
	return hits
end
--- Compile a nested-table trie (as produced by tree_add) into base/check.
-- Any previous contents of base/check are discarded first. Each root
-- child sits at the state index equal to its byte value. Deeper nodes are
-- placed level by level by dat_add, whose returned child states form the
-- next level.
-- Non-numeric root keys (e.g. an `isleaf` flag on the root from an empty
-- pattern) are skipped because they are not bytes.
-- @tparam table tree root of the nested-table trie
local function dat_build(tree)
	-- Clear in place: base/check are shared with the other dat_* functions,
	-- so they must not be rebound to new tables.
	for k in pairs(base) do
		base[k] = nil
	end
	for k in pairs(check) do
		check[k] = nil
	end
	-- Root states: index == byte value, parent == itself.
	local cur = {}
	for k, v in pairs(tree) do
		if type(k) == "number" then
			base[k] = 1
			check[k] = k
			cur[k] = v
		end
	end
	local nxt = {}
	while next(cur) ~= nil do
		for k, v in pairs(cur) do
			assert(k > 0)
			merge(nxt, dat_add(k, v))
			cur[k] = nil
		end
		-- `cur` is now empty; reuse it as the next scratch table.
		cur, nxt = nxt, cur
	end
end
--- Append `[key]=value,` entries for every pair of `tbl` to `buf`.
-- Keys are emitted in ascending order so the generated source is
-- deterministic. A "\n" entry follows every 8th entry; the count covers
-- entries only, not whatever `buf` already held.
-- @tparam table buf output buffer (array of strings, appended to)
-- @tparam table tbl table with integer keys and integer values
local function dump(buf, tbl)
	local format = string.format
	local keys = {}
	for k in pairs(tbl) do
		keys[#keys + 1] = k
	end
	table.sort(keys)
	local i = #buf
	for n, k in ipairs(keys) do
		i = i + 1
		buf[i] = format("[%04d]=%04d,", k, tbl[k])
		if n % 8 == 0 then
			i = i + 1
			buf[i] = "\n"
		end
	end
end
--- Emit base/check as a Lua chunk that returns `{base = ..., check = ...}`.
-- The chunk is printed to stdout as before. It is also returned, so callers
-- can write it to a file instead of capturing stdout.
-- @treturn string the generated Lua source
local function dat_dump()
	local buf = {"local base = {"}
	dump(buf, base)
	buf[#buf + 1] = "}\nlocal check = {"
	dump(buf, check)
	buf[#buf + 1] = [[
}
local M = {
base = base,
check = check,
}
return M
]]
	local src = concat(buf, " ")
	print(src)
	return src
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment