Last active
September 27, 2019 01:29
-
-
Save findstr/2ddd1f434fb752e8ed65c9cb79b2f6c4 to your computer and use it in GitHub Desktop.
Trie for Lua
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Module-level state for the plain trie.
-- `tree` is the root node. A node is a table keyed by byte value whose
-- values are child nodes; a node that ends a word also has `isleaf = true`.
local tree = {}
-- Cached library function used when joining output fragments.
local concat = table.concat
--- Insert a word into the trie rooted at `tree`.
-- Every byte of `str` selects one child node, which is created when it does
-- not exist yet; the node reached by the last byte is flagged `isleaf = true`.
-- An empty string is ignored: it would flag the root node itself, which
-- tree_route never inspects, and it would put a non-numeric key next to the
-- byte keys at the root.
-- @param str string word to insert
local function tree_add(str)
	local len = #str
	if len == 0 then
		return
	end
	local p = tree
	for i = 1, len do
		local c = str:byte(i)
		local n = p[c]
		if not n then
			n = {}
			p[c] = n
		end
		p = n
	end
	p.isleaf = true
end
--- Follow the trie along `str`, starting at byte index `i`.
-- The walk stops at the first byte that has no child node or at the end of
-- the string.
-- @param str string text to scan
-- @param i   number index of the first byte to consume
-- @return index of the last byte of the longest word starting at `i`,
--         or -1 when no word starts there
local function tree_route(str, i)
	local node = tree
	local longest = -1
	for pos = i, #str do
		node = node[str:byte(pos)]
		if not node then
			break
		end
		if node.isleaf then
			-- remember it, but keep going in case a longer word follows
			longest = pos
		end
	end
	return longest
end
--- Replace every trie word found in `str`.
-- The string is scanned left to right. At each position tree_route supplies
-- the longest word starting there; when one exists, `func(str, first, last)`
-- is called and its result is spliced in instead of bytes first..last, and
-- scanning resumes right after the word. Text between matches is copied
-- through unchanged.
-- @param str  string text to scan
-- @param func function called as func(str, first, last) for each match
-- @return the rewritten string, or `str` itself when nothing matched
local function tree_match(str, func)
	local buf
	local len = #str
	local i = 1
	local last = 1 -- first byte not yet copied into buf
	while i <= len do
		local stop = tree_route(str, i)
		if stop >= i then
			buf = buf or {}
			buf[#buf + 1] = str:sub(last, i - 1)
			buf[#buf + 1] = func(str, i, stop)
			last = stop + 1
			i = stop
		end
		i = i + 1
	end
	if not buf then
		return str
	end
	buf[#buf + 1] = str:sub(last)
	return concat(buf)
end
--- Locate every trie word in `str` without rewriting it.
-- For each match found by tree_route, `ban[first] = last` is recorded, where
-- first/last are the byte indices of the matched word. Scanning resumes
-- right after a match, so recorded ranges do not overlap.
-- @param str string text to scan
-- @param ban table  filled with first-index -> last-index entries
-- @return number of matches recorded
local function tree_detect(str, ban)
	local hits = 0
	local len = #str
	local pos = 1
	while pos <= len do
		local stop = tree_route(str, pos)
		if stop >= pos then
			ban[pos] = stop
			hits = hits + 1
			pos = stop
		end
		pos = pos + 1
	end
	return hits
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Double-array storage.
-- base[p]  : offset added to a byte value to reach the child slot of state p.
-- check[s] : slot of the state that owns slot s (a first-byte state owns
--            itself); the value is negated when s ends a word.
local base, check = {}, {}
-- Metatables whose __index is pointed at scratch tables by dat_try_add.
local base_mt = { __index = false }
local check_mt = { __index = false }
--- Copy every key/value pair of `src` into `dst`.
-- Keys already present in `dst` are overwritten; `src` is left untouched.
-- @param dst table receiving the entries
-- @param src table providing the entries
local function merge(dst, src)
	for key, value in pairs(src) do
		dst[key] = value
	end
end
--- Find the first unoccupied index at or after `n`.
-- An index counts as occupied while `check[index]` is truthy; lookups go
-- through normal indexing, so an __index fallback is honoured.
-- @param check table to probe
-- @param n     number index to start from
-- @return the first index >= n whose entry is nil or false
local function find_empty(check, n)
	local slot = n
	while true do
		if not check[slot] then
			return slot
		end
		slot = slot + 1
	end
end
--- Try to place all children of state `p` at its current base offset.
-- For every byte key `ck` of `c` the target slot is `base[p] + ck`. If a
-- target slot is already taken in `check` or `base`, nothing is committed
-- and the distance to the next free slot is returned so the caller can
-- raise base[p] and retry. When every slot is free, the children are
-- committed with check[slot] = p and base[slot] = 1.
-- @param p number slot of the parent state
-- @param c table  children keyed by byte value (values are not used here)
-- @return nil on success, otherwise a positive shift to add to base[p]
local function dat_try_add(p, c)
	-- Tentative placements live in scratch tables layered under base/check
	-- through __index, so find_empty treats them as occupied.
	local base_tmp = {}
	local check_tmp = {}
	base_mt.__index = base_tmp
	check_mt.__index = check_tmp
	setmetatable(base, base_mt)
	-- the original attached check_mt to `base` here, leaving `check` without
	-- its scratch layer and making `base` fall back to check_tmp
	setmetatable(check, check_mt)
	local b = base[p]
	local shift
	for ck in pairs(c) do
		local s = b + ck
		local ss = find_empty(check, s)
		if ss == s then
			ss = find_empty(base, s)
		end
		if ss > s then
			shift = ss - s
			break
		end
		check_tmp[s] = p
		base_tmp[s] = 1
	end
	-- detach the scratch layers so abandoned placements are never visible
	-- through base/check after this call
	setmetatable(base, nil)
	setmetatable(check, nil)
	if shift then
		return shift
	end
	merge(base, base_tmp)
	merge(check, check_tmp)
end
--- Place the children of trie node `c` for state `p` in the double array.
-- base[p] is raised by whatever shift dat_try_add reports until the
-- children fit. If `c` ends a word, check[p] is negated to mark it.
-- @param p number slot of the state that corresponds to `c`
-- @param c table  trie node: byte keys -> child nodes, optional `isleaf`
-- @return table mapping each child's slot to its trie node
local function dat_add(p, c)
	-- hide the leaf flag so that only byte keys are iterated below
	local leaf = c.isleaf
	c.isleaf = nil
	local offset = base[p]
	local shift = dat_try_add(p, c)
	while shift do
		offset = offset + shift
		base[p] = offset
		shift = dat_try_add(p, c)
	end
	local placed = {}
	for byte, child in pairs(c) do
		assert(byte > 0)
		placed[offset + byte] = child
	end
	c.isleaf = leaf
	if leaf then
		local owner = check[p]
		assert(owner >= 0, owner)
		check[p] = -owner
	end
	return placed
end
--- Follow the double array along `str`, starting at byte index `off`.
-- The state for the first byte lives at the slot equal to its byte value and
-- must own itself (check[p] == p, or -p when that single byte is a word).
-- Every further byte `c` moves from state p to slot base[p] + c, which is
-- accepted only when check[slot] is p (or -p when the slot ends a word).
-- @param str string text to scan
-- @param off number index of the first byte to consume
-- @return index of the last byte of the longest word starting at `off`,
--         or -1 when no word starts there
local function dat_route(str, off)
	local greedy = -1
	local p = str:byte(off)
	local b = base[p]
	if not b then
		return greedy
	end
	local chk = check[p]
	if chk == -p then
		greedy = off
	elseif chk ~= p then
		-- the slot exists but belongs to another state
		return greedy
	end
	for i = off + 1, #str do
		local s = b + str:byte(i)
		if s == p then
			-- A state is never its own child, but a first-byte state has
			-- check[p] == +/-p, so without this guard a byte equal to
			-- p - base[p] would loop back into p and count as a match.
			break
		end
		local t = check[s]
		if t == -p then
			greedy = i
		elseif t ~= p then
			break
		end
		p = s
		b = base[p]
	end
	return greedy
end
--- Replace every dictionary word found in `str` using the double array.
-- The string is scanned left to right. At each position dat_route supplies
-- the longest word starting there; when one exists, `func(str, first, last)`
-- is called and its result is spliced in instead of bytes first..last, and
-- scanning resumes right after the word. Text between matches is copied
-- through unchanged.
-- @param str  string text to scan
-- @param func function called as func(str, first, last) for each match
-- @return the rewritten string, or `str` itself when nothing matched
local function dat_match(str, func)
	local buf
	local len = #str
	local i = 1
	local last = 1 -- first byte not yet copied into buf
	while i <= len do
		local stop = dat_route(str, i)
		if stop >= i then
			buf = buf or {}
			buf[#buf + 1] = str:sub(last, i - 1)
			buf[#buf + 1] = func(str, i, stop)
			last = stop + 1
			i = stop
		end
		i = i + 1
	end
	if not buf then
		return str
	end
	buf[#buf + 1] = str:sub(last)
	-- table.concat is called directly: this file declares no `concat` local
	return table.concat(buf)
end
--- Locate every dictionary word in `str` using the double array.
-- For each match found by dat_route, `ban[first] = last` is recorded, where
-- first/last are the byte indices of the matched word. Scanning resumes
-- right after a match, so recorded ranges do not overlap.
-- @param str string text to scan
-- @param ban table  filled with first-index -> last-index entries
-- @return number of matches recorded
local function dat_detect(str, ban)
	local hits = 0
	local len = #str
	local pos = 1
	while pos <= len do
		local stop = dat_route(str, pos)
		if stop >= pos then
			ban[pos] = stop
			hits = hits + 1
			pos = stop
		end
		pos = pos + 1
	end
	return hits
end
--- Convert a trie into the double array held in `base` / `check`.
-- First-byte states are seeded at the slot equal to their byte value with
-- base 1 and themselves as owner. The remaining nodes are then placed level
-- by level: dat_add places one node's children and returns them keyed by
-- their slot, and those become the work list of the next round.
-- @param tree table trie root: byte keys -> child nodes
local function dat_build(tree)
	local cur = {}
	local nxt = {}
	-- Seed the first-byte states. Only numeric keys are bytes; anything else
	-- at the root (e.g. an `isleaf` flag left by an empty word) is skipped
	-- instead of being turned into a state.
	for k, v in pairs(tree) do
		if type(k) == "number" then
			base[k] = 1
			check[k] = k
			cur[k] = v
		end
	end
	while true do
		for k, v in pairs(cur) do
			assert(k > 0)
			merge(nxt, dat_add(k, v))
			-- clearing the visited key is allowed during traversal and
			-- leaves `cur` empty for reuse as the next `nxt`
			cur[k] = nil
		end
		if next(nxt) == nil then
			break
		end
		cur, nxt = nxt, cur
	end
end
--- Append the entries of `tbl` to `buf` as "[key]=value," fragments.
-- Key and value are both formatted with %04d. Whenever the index just
-- written is a multiple of 8, a "\n" element is appended after it.
-- Entries are emitted in pairs() order.
-- @param buf table array of string fragments, extended in place
-- @param tbl table integer keys -> integer values
local function dump(buf, tbl)
	local fmt = string.format
	local n = #buf
	for key, value in pairs(tbl) do
		n = n + 1
		buf[n] = fmt("[%04d]=%04d,", key, value)
		if n % 8 == 0 then
			buf[n + 1] = "\n"
			n = n + 1
		end
	end
end
--- Print the double array as Lua source.
-- The printed text declares `base` and `check` as table constructors (filled
-- in by dump) and ends by returning a table M that exposes both. Fragments
-- are joined with a single space before being passed to print.
local function dat_dump()
	local buf = {
		"local base = {"
	}
	dump(buf, base)
	buf[#buf + 1] = "}\nlocal check = {"
	dump(buf, check)
	buf[#buf + 1] = [[
}
local M = {
base = base,
check = check,
}
return M
]]
	print(table.concat(buf, " "))
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment