Skip to content

Instantly share code, notes, and snippets.

@akostylev0
Created November 14, 2014 11:16
Show Gist options
  • Save akostylev0/91ef447cbbaca2a69d4b to your computer and use it in GitHub Desktop.
Save akostylev0/91ef447cbbaca2a69d4b to your computer and use it in GitHub Desktop.
# coding=utf-8
from collections import defaultdict
from collections import deque
import random
import time
class AhoCorasick(object):
    """Aho-Corasick automaton: find every occurrence of many patterns at once.

    Build it from an iterable of pattern strings, then call ``match`` on a
    query string to get the position of every pattern occurrence in a single
    left-to-right pass over the query.
    """

    class Node(object):
        """One state of the automaton (a node of the pattern trie)."""

        def __init__(self):
            # Goto transitions: character -> child Node.  ``next[c]`` creates
            # a missing child (used when inserting patterns); read-only
            # lookups use ``in`` / ``.get`` so they never grow the trie.
            self.next = defaultdict(AhoCorasick.Node)
            # Patterns reported while the automaton is in this state: the
            # patterns ending exactly here, followed by those inherited from
            # the failure chain (filled in by make_failure_links).
            self.terms = []
            # Patterns ending exactly at this node.  Kept separately so that
            # make_failure_links can rebuild ``terms`` without duplicates.
            self.own_terms = []
            # State for the longest proper suffix of this node's path that is
            # also a trie path; None until make_failure_links has run.
            self.failure = None

    def __init__(self, terms=()):
        """Build the automaton from ``terms`` (an iterable of strings)."""
        self.root = AhoCorasick.Node()
        for term in terms:
            self.add_string(term)
        self.make_failure_links()

    def add_string(self, term=""):
        """Insert ``term`` into the trie.

        After adding strings to an already-built automaton, call
        make_failure_links again before ``match``: newly created nodes have
        no failure link until then.
        """
        node = self.root
        for char in term:
            node = node.next[char]
        node.own_terms.append(term)
        node.terms.append(term)

    def make_failure_links(self):
        """Compute failure links and per-state output lists.

        Nodes are visited breadth-first.  A node's failure target is always
        shallower than the node, so BFS guarantees that the target's own
        failure link and output list are final before they are used.  Output
        lists are rebuilt from ``own_terms``, so calling this again (for
        example after more add_string calls) does not duplicate matches.
        """
        root = self.root
        root.failure = root
        queue = deque()
        for child in root.next.values():
            # Depth-1 states can only fall back to the root.
            child.failure = root
            child.terms = list(child.own_terms)
            queue.append(child)
        while queue:
            parent = queue.popleft()
            for char, child in parent.next.items():
                # Walk the parent's failure chain until we reach the longest
                # suffix state that has a ``char`` transition (or the root).
                fallback = parent.failure
                while fallback is not root and char not in fallback.next:
                    fallback = fallback.failure
                child.failure = fallback.next.get(char, root)
                child.terms = child.own_terms + child.failure.terms
                queue.append(child)

    def match(self, query):
        """Find all occurrences of all patterns in ``query``.

        :param query: the string to scan.
        :returns: a list of ``(start, length)`` tuples, in order of the
            index where each occurrence ends.  Occurrences ending at the
            same index are listed longest first.
        """
        root = self.root
        node = root
        results = []
        for end, char in enumerate(query):
            # Follow failure links until some state can consume ``char``,
            # or we are back at the root.
            while node is not root and char not in node.next:
                node = node.failure
            node = node.next.get(char, root)
            for term in node.terms:
                results.append((end + 1 - len(term), len(term)))
        return results

    def __repr__(self):
        """Render the trie as an indented tree, one node per line (debug aid)."""
        lines = []
        # An explicit stack instead of recursion, so very deep tries cannot
        # hit the recursion limit.  Children are pushed in reverse so they
        # are popped in their original order (pre-order traversal).
        stack = [('', self.root, 0)]
        while stack:
            char, node, depth = stack.pop()
            lines.append('%s[%s]%s' % (' ' * depth, char, node.terms))
            for key, child in reversed(list(node.next.items())):
                stack.append((key, child, depth + 1))
        return '\n'.join(lines)
def gen_random_string(
        N,
        alphabet=(u"ЙЦУКЕНГШЩЗХЪФЫВАПРОЛДЖЭЯЧСМИТЬБЮ"
                  u"QWERTYUIOPASDFGHJKLZXCVBNM1234567890")):
    """Return a random string of ``N`` characters drawn from ``alphabet``.

    Each character is chosen independently and uniformly, with replacement,
    using the ``random`` module.  This is for test data only, not for
    security-sensitive values.  The default alphabet is uppercase Cyrillic,
    uppercase Latin letters and digits.
    """
    return ''.join(random.choice(alphabet) for _ in range(N))
# Benchmark: build an automaton from 399,999 random 4-character patterns and
# time one match over a 10,000-character random text.
# NOTE(review): this runs at import time; consider an
# ``if __name__ == "__main__":`` guard if the module is ever imported.
r = [gen_random_string(4) for _ in range(1, 400000)]
patterns = r
s = gen_random_string(10000)
print('generated')
root = AhoCorasick(patterns)  # the automaton, not a trie node
print('1')
t = time.time()
# Time only the matching itself; printing the large result list is excluded.
matches = root.match(s)
elapsed = time.time() - t
print(matches)
print(elapsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment