Skip to content

Instantly share code, notes, and snippets.

@IvanaGyro
Created June 8, 2019 09:30
Show Gist options
  • Save IvanaGyro/e4f44f43a54ad64ae25e488fc3e34c9b to your computer and use it in GitHub Desktop.
Save IvanaGyro/e4f44f43a54ad64ae25e488fc3e34c9b to your computer and use it in GitHub Desktop.
Implementation of Ukkonen’s algorithm for building suffix trees in Python
from collections import defaultdict
class Node:
    """A suffix-tree node whose incoming edge label is ``s[beg:end]``.

    Children are kept in ``self.node``, a ``defaultdict`` keyed by the first
    character of each child's edge label.  Looking up a missing key through
    ``node[key]`` creates and stores a fresh child ``Node`` over the same text.
    ``end`` is ``None`` for an open (leaf) edge, which :meth:`range` resolves
    against the caller-supplied ``tail``.  ``link`` stores the suffix link.
    """

    # Deliberately not derived from the collections.abc mixins (for
    # efficiency): plain dunder methods plus __slots__ keep instances small
    # and attribute access fast.
    __slots__ = ('beg', 'end', 'link', 's', 'node')

    def __init__(self, s):
        self.s = s
        self.beg = None
        self.end = None
        self.link = None
        # A missing key materialises a new child sharing the same text.
        self.node = defaultdict(lambda: Node(s))

    def __getitem__(self, key):
        children = self.node
        return children[key]  # may create the child (defaultdict semantics)

    def __setitem__(self, key, val):
        children = self.node
        children[key] = val

    def __delitem__(self, key):
        children = self.node
        del children[key]

    def __contains__(self, key):
        children = self.node
        return key in children

    def __len__(self):
        children = self.node
        return len(children)

    def __iter__(self):
        children = self.node
        return iter(children)

    def __repr__(self):
        """Render the subtree as nested ``{'label': subtree}``; leaves show ``'$'``."""
        parts = []
        for child in self.node.values():
            label = child.s[child.beg:child.end]
            parts.append(f"'{label}': {child!r}")
        if not parts:
            return "'$'"
        return '{' + ', '.join(parts) + '}'

    def range(self, tail):
        """Return the edge length, treating an open (``None``) end as ``tail``."""
        if self.end is None:
            return tail - self.beg
        return self.end - self.beg
def build_suffix_tree(s):
    """Build a suffix tree of ``s + '$'`` with Ukkonen's online algorithm.

    Args:
        s: Input text.  A ``'$'`` terminator is appended so suffixes end at
            leaves.  NOTE(review): if ``s`` itself contains ``'$'``, a suffix
            can be a prefix of another one and will then not get its own leaf.

    Returns:
        The root ``Node`` (empty label: ``beg == end == 0``).  Every node's
        edge label is ``node.s[node.beg:node.end]``, where ``node.s`` is the
        terminated text and ``end`` is ``None`` on leaves (edge runs to the end).
    """
    s += '$'
    root = Node(s)
    root.beg, root.end = 0, 0
    # Active point: (node, first character of the edge being walked, number
    # of characters walked along that edge).  '\0' is only a placeholder while
    # active_len == 0; it is never used as a lookup key in that state.
    active_node, active_edge, active_len = root, '\0', 0
    remainder = 0  # suffixes not yet inserted explicitly
    for i in range(len(s)):
        prev = None      # node from this phase still waiting for its suffix link
        tail = i + 1     # current end of every open (leaf) edge
        remainder += 1
        cur = s[i]
        while remainder:
            # Walk down: skip whole edges until active_len < edge length.
            while active_len:
                edge = active_node[active_edge]
                edge_len = edge.range(tail)
                if active_len >= edge_len:
                    active_len -= edge_len
                    active_node = edge
                    active_edge = s[i - active_len] if active_len else '\0'
                else:
                    break
            if not active_len:
                # The active point is exactly on active_node.
                edge = active_node
                if cur in edge:
                    # cur already follows this node: advance the active point
                    # and end the phase; `remainder` carries over.
                    active_edge = cur
                    active_len += 1
                    if prev is not None:
                        prev.link = edge
                    prev = edge
                    break
                else:
                    # Hang a new leaf for cur directly under active_node.
                    edge[cur].beg = i
                    if prev is not None:
                        prev.link = edge
                    prev = edge
                    remainder -= 1
                    if active_node is root:
                        assert remainder == 0
                    else:
                        # NOTE(review): relies on every non-root internal node
                        # already carrying a suffix link (assigned through
                        # `prev.link` in the extension after it was created),
                        # as Ukkonen's algorithm guarantees; not asserted here.
                        active_node = active_node.link
            else:
                edge = active_node[active_edge]
                c = s[edge.beg + active_len]
                if c == cur:
                    # cur already continues this edge: advance and end the phase.
                    active_len += 1
                    if prev is not None:
                        prev.link = edge
                    prev = edge
                    break
                else:
                    # Mismatch inside the edge: split it after active_len chars.
                    if len(edge.node):
                        # Target has children: insert a fresh node above it.
                        # The right-hand side is evaluated first, so the new
                        # node takes the old beg/end and adopts the old node
                        # as its child under `c`.
                        edge, edge.beg, edge.end, edge[c] = Node(s), edge.beg, edge.end, edge
                        active_node[active_edge] = edge
                    # (A leaf target is reused in place as the internal node;
                    # `edge[c]` then creates a fresh leaf that takes over the
                    # rest of its label.)
                    edge[c].beg, edge[c].end = edge.beg + active_len, edge.end
                    edge[cur].beg = i
                    edge.end = edge.beg + active_len
                    remainder -= 1
                    if prev is not None:
                        prev.link = edge
                    prev = edge
                    if active_node is root:
                        # From the root the next suffix is one character shorter.
                        active_len -= 1
                        active_edge = s[i - active_len] if active_len else '\0'
                    else:
                        active_node = active_node.link
    return root
def build_suffix_tree_dict(s):
    """Build a suffix tree of ``s + '$'`` using plain ``defaultdict`` nodes.

    Same algorithm as ``build_suffix_tree``; this variant sacrifices
    readability for performance.  Every node is a ``defaultdict`` whose
    single-character keys map to child nodes (a missing key creates a new
    child) and whose ``'attr'`` key holds ``[beg, end, link]``:

    * ``beg``  -- start index of the incoming edge label in the terminated text;
    * ``end``  -- a one-element list holding the exclusive end index.  All
      leaves share the same list (``tail``), so updating ``tail[0]`` once per
      character extends every leaf edge at once;
    * ``link`` -- the suffix link (``None`` until assigned).

    Args:
        s: Input text; ``'$'`` is appended as terminator.  NOTE(review): if
            ``s`` already contains ``'$'``, some suffixes may not end at a leaf.

    Returns:
        The root node, whose label is empty (``beg == 0``, ``end == [0]``).
    """
    s += '$'
    tail = [0]

    def node():
        n = defaultdict(node)
        n['attr'] = [None, tail, None]  # beg, end, link
        return n

    root = node()
    attr = root['attr']
    attr[0], attr[1] = 0, [0]  # root: empty label with its own (fixed) end
    # Active point: (node, first char of the edge being walked, chars walked).
    # '\0' is only a placeholder while active_len == 0.
    active_node, active_edge, active_len = root, '\0', 0
    remainder = 0  # suffixes not yet inserted explicitly
    # Scratch sink so `prev[2] = ...` can run unconditionally (no None check);
    # writes into it are simply discarded.
    prev_tmplt = [0] * 3
    for i in range(len(s)):
        prev = prev_tmplt  # attr of the node awaiting a suffix link this phase
        tail[0] = i + 1    # extends every leaf edge in O(1)
        remainder += 1
        cur = s[i]
        while remainder:
            # Walk down: skip whole edges until active_len < edge length.
            while active_len:
                edge = active_node[active_edge]
                beg, end, _ = edge['attr']
                edge_len = end[0] - beg
                if active_len >= edge_len:
                    active_len -= edge_len
                    active_node = edge
                    active_edge = s[i - active_len] if active_len else '\0'
                else:
                    break
            if not active_len:
                # The active point is exactly on active_node.
                edge = active_node
                edge_attr = edge['attr']
                if cur in edge:
                    # cur already follows this node: advance and end the phase.
                    active_edge = cur
                    active_len += 1
                    prev[2] = edge  # create link
                    prev = edge_attr
                    break
                else:
                    # Hang a new leaf for cur directly under active_node.
                    edge[cur]['attr'][0] = i
                    remainder -= 1
                    prev[2] = edge  # create link
                    prev = edge_attr
                    # Identity, not `==`: comparing dict nodes by value would
                    # recurse into their children and attribute lists.
                    if active_node is root:
                        assert remainder == 0
                    else:
                        # NOTE(review): relies on every non-root internal node
                        # already having its suffix link assigned.
                        active_node = edge_attr[2]
            else:
                edge = active_node[active_edge]
                edge_attr = edge['attr']
                c = s[edge_attr[0] + active_len]
                if c == cur:
                    # cur already continues this edge: advance and end the phase.
                    active_len += 1
                    prev[2] = edge  # create link
                    prev = edge_attr
                    break
                else:
                    # Split the edge: a new internal node takes a (shallow) copy
                    # of the old attrs and adopts the old node under `c`.
                    edge, edge['attr'], edge[c] = node(), edge_attr.copy(), edge
                    active_node[active_edge], edge_attr = edge, edge['attr']
                    edge[c]['attr'][0], edge[c]['attr'][1] = edge_attr[0] + active_len, edge_attr[1]
                    edge[cur]['attr'][0] = i
                    # The internal node gets its own fixed end (not the shared tail).
                    edge_attr[1] = [edge_attr[0] + active_len]
                    remainder -= 1
                    prev[2] = edge  # create link
                    prev = edge_attr
                    if active_node is root:
                        # From the root the next suffix is one character shorter.
                        active_len -= 1
                        active_edge = s[i - active_len] if active_len else '\0'
                    else:
                        active_node = active_node['attr'][2]
    return root
# Fixed inputs for the import-time self-checks and the benchmarks below:
# short strings with runs of one character and overlapping repeats.
cases = [
    'a',
    'aa',
    'ab',
    'abcabx',
    'abcabdeabdx',
    'abcabdab',
    'cdddcdc',
    'cdddcddc',
    'cdddcdddc',
    'aaaaaaaa',
    'aabbaabbaa',
    'aabaacaad',
    'aabaacaadaae',
    'bbabbdabbba',
    'bbabbdabbbaaa',
]

import random
import string

# Append 10 random strings of 500-1500 characters over a 3-letter alphabet,
# so each one contains many repeated substrings.
_RANDOM_ALPHABET = string.ascii_lowercase[:3]
for _ in range(10):
    _length = random.randint(500, 1500)
    # ''.join is linear; growing a str with `+=` in a loop can be quadratic.
    cases.append(''.join(random.choice(_RANDOM_ALPHABET) for _ in range(_length)))
def traversal(root):
    """Return every root-to-leaf path label of a ``Node`` tree, shortest first.

    For a tree returned by ``build_suffix_tree(x)`` these labels are the
    suffixes of ``x + '$'``.  An explicit stack replaces recursion so deep
    trees (e.g. a long run of one repeated character yields a chain of
    internal nodes) do not hit Python's recursion limit.  Children are visited
    in the same order a recursive pre-order walk would use.
    """
    res = []
    stack = [('', root)]
    while stack:
        path, node = stack.pop()
        # A leaf's end is None, which slices through to the end of the text.
        label = path + node.s[node.beg:node.end]
        if not len(node):
            res.append(label)
        else:
            # Push in reverse so children pop in their original order.
            stack.extend((label, node[key]) for key in reversed(list(node)))
    return sorted(res, key=len)
def traversal_dict(root, s):
    """Return every root-to-leaf path label of a dict-based tree, shortest first.

    Args:
        root: Root node as built by ``build_suffix_tree_dict``: a dict whose
            ``'attr'`` entry is ``[beg, end_list, link]`` and whose other
            (single-character) keys map to child nodes.
        s: The terminated text the tree was built from (including ``'$'``);
            edge labels are ``s[beg:end_list[0]]``.

    An explicit stack replaces recursion so deep trees do not hit Python's
    recursion limit.  Children are visited in the same order a recursive
    pre-order walk would use.
    """
    res = []
    stack = [('', root)]
    while stack:
        path, node = stack.pop()
        attr = node['attr']
        label = path + s[attr[0]:attr[1][0]]
        if len(node) == 1:  # only the 'attr' entry: a leaf
            res.append(label)
        else:
            children = [key for key in node if key != 'attr']
            # Push in reverse so children pop in their original order.
            stack.extend((label, node[key]) for key in reversed(children))
    return sorted(res, key=len)
# Import-time self-check: for every case, both builders (class-based first,
# then dict-based) must yield exactly the suffixes of case + '$', shortest first.
for use_dict in (False, True):
    for case in cases:
        text = case + '$'
        if use_dict:
            root = build_suffix_tree_dict(case)
            res = traversal_dict(root, text)
        else:
            root = build_suffix_tree(case)
            res = traversal(root)
        ans = [text[start:] for start in reversed(range(len(text)))]
        assert res == ans, (root, res)
def test():
    """Benchmark body: build a class-based suffix tree for every entry of ``cases``."""
    for case in cases:
        build_suffix_tree(case)
def test_dict():
    """Benchmark body: build a dict-based suffix tree for every entry of ``cases``."""
    for case in cases:
        build_suffix_tree_dict(case)
if __name__ == '__main__':
    from timeit import timeit

    # Time the functions defined in this very module.  Importing them via
    # `from suffix_tree import ...` only works when the file is saved as
    # suffix_tree.py.  It also re-executes the whole file as a second module,
    # with fresh random cases and a second run of the import-time self-checks.
    # Two identical rounds (dict-based first, then class-based), separated by
    # a dashed line.
    for round_no in range(2):
        if round_no:
            print('-------------')
        print(timeit(test_dict, number=10))
        print(timeit(test, number=10))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment