Last active
January 18, 2020 14:22
-
-
Save tjkendev/99d7330fe5642004b68268b31ba38ad4 to your computer and use it in GitHub Desktop.
pythonで実装したSA-IS (線形のSuffix Array構築アルゴリズム) — SA-IS (a linear-time suffix array construction algorithm) implemented in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# encoding: utf-8
from collections import Counter


def _bucket_bounds(counts, num):
    """Return ``(starts, ends)`` bucket boundaries for characters ``0..num``.

    ``starts[c]`` is the first suffix-array index of bucket ``c`` and
    ``ends[c]`` is one past its last index.  ``counts[c]`` is the number of
    occurrences of ``c``; a missing key counts as zero (``Counter`` semantics).
    """
    starts = [0] * (num + 1)
    ends = [0] * (num + 1)
    total = 0
    for c in range(num + 1):
        starts[c] = total
        total += counts[c]
        ends[c] = total
    return starts, ends


def _induced_sort(s, t, lms_seed, starts, ends):
    """Seed LMS positions into their buckets and induce all other suffixes.

    ``lms_seed`` yields LMS positions.  Each one is written just below the
    previously seeded entry of its bucket, counting down from the bucket end,
    so within one bucket the first position yielded ends up last.
    Returns a new list of length ``len(s)``.
    """
    sa = [None] * len(s)
    tail = list(ends)
    for p in lms_seed:
        c = s[p]
        tail[c] -= 1
        sa[tail[c]] = p
    # L-type pass (left to right): when suffix p-1 is L-type, put it in the
    # next free slot from the front of its bucket.  Writing into the list
    # while iterating over it is intentional: the list iterator reads each
    # slot live, so entries written ahead of the scan are visited as well.
    head = list(starts)
    for p in sa:
        if p is not None and p > 0 and not t[p - 1]:
            c = s[p - 1]
            sa[head[c]] = p - 1
            head[c] += 1
    # S-type pass (right to left): when suffix p-1 is S-type, put it in the
    # next free slot from the back of its bucket.  The tail pointers restart
    # at the bucket ends, so the seeded LMS entries are overwritten rather
    # than treated as occupied.
    tail = list(ends)
    for p in reversed(sa):
        if p is not None and p > 0 and t[p - 1]:
            c = s[p - 1]
            tail[c] -= 1
            sa[tail[c]] = p - 1
    return sa


def SAIS(lst, num):
    """Return the suffix array of ``lst`` built with SA-IS.

    ``lst`` is a sequence of integers in ``1..num`` (for example the codes
    returned by ``chr_compression``); ``0`` is reserved for a sentinel that
    is appended internally.  The result is the suffix array of
    ``list(lst) + [0]``: it has ``len(lst) + 1`` entries and its first entry
    is ``len(lst)``, the sentinel suffix.  ``lst`` itself is not modified.
    """
    n = len(lst)
    if n < 2:
        # Nothing to sort: the sentinel suffix comes first, followed by the
        # single real suffix if there is one.  These are positions, not
        # character values.
        return list(range(n, -1, -1))
    s = list(lst) + [0]
    n += 1

    # Classify suffixes: t[i] == 1 if suffix i is S-type (smaller than
    # suffix i+1) and 0 if L-type (larger).  Equal neighbours take the type
    # of the right-hand one; the sentinel suffix is S-type.
    t = [1] * n
    for i in range(n - 2, -1, -1):
        t[i] = 1 if s[i] < s[i + 1] or (s[i] == s[i + 1] and t[i + 1]) else 0

    # LMS ("leftmost S") positions: S-type with an L-type left neighbour.
    # Position 0 has no left neighbour and is never LMS.
    is_lms = [False] + [t[i - 1] < t[i] for i in range(1, n)]
    lms = [i for i in range(1, n) if is_lms[i]]

    starts, ends = _bucket_bounds(Counter(s), num)

    # Pass 1: seed the LMS positions in text order and induce.  This sorts
    # the LMS-substrings (each runs from one LMS position to the next), but
    # not yet the full suffixes, because the seeding order was arbitrary.
    sa = _induced_sort(s, t, reversed(lms), starts, ends)

    # Name the LMS-substrings in sorted order; equal substrings share a name.
    # Each comparison stops by the end of e's LMS-substring, so the total
    # character comparisons stay linear in n.
    name = 0
    prev = -1
    lms_name = {}
    for e in sa:
        if not is_lms[e]:
            continue
        for i in range(n):
            if prev == -1 or s[e + i] != s[prev + i]:
                name += 1
                prev = e
                break
            if i and (is_lms[e + i] or is_lms[prev + i]):
                break
        lms_name[e] = name - 1

    if name < len(lms):
        # Some LMS-substrings coincide, so their order depends on what follows
        # them.  Recursively suffix-sort the string of names taken in text
        # order.  The sentinel (the last LMS) is left out because the
        # recursive call appends its own; that reduced position maps back to
        # lms[-1].
        reduced = [lms_name[e] for e in lms if e < n - 1]
        reduced_sa = SAIS(reduced, name - 1)
        # _induced_sort fills each bucket from the end, so seed largest first.
        lms_seed = [lms[i] for i in reversed(reduced_sa)]
    else:
        # All names are distinct, so pass 1 already ordered the LMS suffixes;
        # list them largest first.
        lms_seed = [e for e in reversed(sa) if is_lms[e]]

    # Pass 2: seed the correctly ordered LMS suffixes and induce the rest.
    return _induced_sort(s, t, lms_seed, starts, ends)
# Compress symbols to a dense range so the bucket tables stay small.
def chr_compression(s):
    """Compress the symbols of ``s`` to consecutive ranks ``1..k``.

    A dense alphabet keeps the SAIS bucket tables at size ``k + 1``, and
    rank 0 stays free for the SAIS sentinel.  Returns ``(codes, k)``, where
    ``codes`` is a list (so ``len()`` works on it) and ``k`` is the number of
    distinct symbols.  The pair matches the ``SAIS(lst, num)`` argument order.
    """
    rank = {ch: i for i, ch in enumerate(sorted(set(s)), 1)}
    return [rank[ch] for ch in s], len(rank)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# encoding: utf-8
# Implementation check test: compare SAIS against a naive sort.


def check_solver(s):
    """Return the suffix array of ``s + "$"`` by brute-force sorting.

    This is the reference for testing SAIS.  It sorts the suffix slices
    directly, so it is simple but much slower than SA-IS.  It matches the
    SAIS convention (sentinel suffix first) only when every character of
    ``s`` sorts after ``"$"``, as uppercase ASCII letters do.
    """
    s0 = s + "$"
    return sorted(range(len(s0)), key=lambda i: s0[i:])
import random
import sys

# NOTE(review): SAIS and chr_compression are not imported in this file; the
# script assumes they are already defined in the same namespace (e.g. the
# two gist files concatenated). Confirm how it is meant to be run.


def main(trials=10000, length=100, seed=None):
    """Cross-check SAIS against check_solver on random A-Z strings.

    Runs ``trials`` cases of ``length`` uppercase letters each and prints
    every case.  On the first mismatch it prints diagnostics and exits with
    status 1; otherwise it prints "result: OK".  ``seed`` is passed to
    ``random.seed``; the default ``None`` seeds from the system.
    """
    random.seed(seed)
    for case in range(trials):
        s = "".join(chr(random.randint(ord('A'), ord('Z')))
                    for _ in range(length))
        print("Case %d : %s" % (case + 1, s))
        res = SAIS(*chr_compression(s))
        ans = check_solver(s)
        if res != ans:
            print("result: Wrong")
            print("%s %s" % (res, ans))
            print("\n".join(str(i) + " " + (s + "$")[i:] for i in res))
            # %s rather than %d so an unfilled (None) slot cannot crash the report.
            mismatches = ["%s(pos %d)" % (res[i], i)
                          for i in range(len(s) + 1) if res[i] != ans[i]]
            print("Wrong  %s" % (mismatches,))
            # Non-zero status so shells and CI see the failure.
            sys.exit(1)
    print("result: OK")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment