Skip to content

Instantly share code, notes, and snippets.

@andlima
Last active August 29, 2015 14:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save andlima/e823ba48f9a03743f36c to your computer and use it in GitHub Desktop.
Save andlima/e823ba48f9a03743f36c to your computer and use it in GitHub Desktop.
from __future__ import print_function
class Trie(object):
    '''Implementation of a trie (prefix tree) storing a set of strings.

    Each node records whether a stored string ends at it (``ends``) and maps
    the next character to a child ``Trie`` (``children``).  Every operation
    walks the tree iteratively, so arbitrarily long strings do not run into
    Python's recursion limit.  Removing a string prunes any branch that no
    longer leads to a stored string, so the tree does not accumulate dead
    nodes over add/remove cycles.
    '''

    def __init__(self, collection=None):
        '''Create a trie, optionally filled with the strings in
        ``collection`` (any iterable of strings; duplicates are ignored).'''
        self.ends = False  # True when a stored string ends at this node.
        self.children = {}  # Next character -> child Trie.
        if collection is not None:
            for s in collection:
                self.add(s)

    def add(self, s):
        '''Insert ``s``.  Return True if it was newly added, False if it was
        already present.'''
        return self._add_last_chars(s, 0)

    def remove(self, s):
        '''Delete ``s``.  Return True if it was present, False otherwise.'''
        return self._remove_last_chars(s, 0)

    def clear(self):
        '''Remove every stored string.'''
        self.ends = False
        self.children = {}

    def __contains__(self, s):
        '''Return True if ``s`` is stored in the trie.'''
        return self._contains_last_chars(s, 0)

    def __iter__(self):
        '''Yield stored strings in lexicographic (code point) order.'''
        return self._iter_with_prefix('')

    def __len__(self):
        '''Return the number of stored strings (walks every node).'''
        count = 0
        stack = [self]
        while stack:
            node = stack.pop()
            if node.ends:
                count += 1
            stack.extend(node.children.values())
        return count

    def _find(self, s, i):
        '''Return the node reached by following ``s[i:]`` from this node,
        or None if that path does not exist.'''
        node = self
        for c in s[i:]:
            node = node.children.get(c)
            if node is None:
                return None
        return node

    def _add_last_chars(self, s, i):
        '''Insert the suffix ``s[i:]`` below this node; see ``add``.'''
        node = self
        for c in s[i:]:
            child = node.children.get(c)
            if child is None:
                child = Trie()
                node.children[c] = child
            node = child
        if node.ends:
            return False
        node.ends = True
        return True

    def _remove_last_chars(self, s, i):
        '''Delete the suffix ``s[i:]`` below this node; see ``remove``.

        Nodes left storing nothing and leading nowhere are unlinked from
        their parents.  This node itself is never unlinked.'''
        node = self
        path = []  # (parent, char) pairs along the walk, used for pruning.
        for c in s[i:]:
            child = node.children.get(c)
            if child is None:
                return False
            path.append((node, c))
            node = child
        if not node.ends:
            return False
        node.ends = False
        # Walk back towards this node, dropping now-empty leaves.
        while path and not node.ends and not node.children:
            parent, c = path.pop()
            del parent.children[c]
            node = parent
        return True

    def _contains_last_chars(self, s, i):
        '''Return True if the suffix ``s[i:]`` is stored below this node.'''
        node = self._find(s, i)
        return node is not None and node.ends

    def _iter_with_prefix(self, prefix):
        '''Yield ``prefix + t`` for every string ``t`` stored below this
        node, in lexicographic order (pre-order, children sorted).'''
        stack = [(self, prefix)]
        while stack:
            node, p = stack.pop()
            if node.ends:
                yield p
            # Push in reverse so the smallest character is popped first,
            # and its whole subtree is emitted before the next sibling.
            for c, child in sorted(node.children.items(), reverse=True):
                stack.append((child, p + c))
if __name__ == '__main__':
    # Self-test: grow a trie one string at a time and, after every step,
    # compare its size and membership answers with a plain Python set.
    trie = Trie()
    stored = set()
    probes = ('', 'a', 'b', 'abc')

    def check_contents():
        assert len(trie) == len(stored)
        for probe in probes:
            assert (probe in trie) == (probe in stored)

    check_contents()
    for word in ('a', '', 'b', 'abc'):
        assert trie.add(word)
        stored.add(word)
        check_contents()
    # Proper prefixes of a stored string are not themselves members.
    assert 'aa' not in trie
    assert 'ab' not in trie

    # Re-adding existing strings reports "already present" and is a no-op.
    for word in ('', 'abc', 'b', 'a'):
        assert not trie.add(word)
    assert len(trie) == 4
    assert list(trie) == ['', 'a', 'abc', 'b']

    # Removing a string works once; removing it again changes nothing.
    after_removal = ['', 'abc', 'b']
    assert trie.remove('a')
    assert len(trie) == 3
    assert list(trie) == after_removal
    assert not trie.remove('a')
    assert len(trie) == 3
    assert list(trie) == after_removal

    trie.clear()
    assert len(trie) == 0
    assert list(trie) == []

    # The constructor accepts an initial collection; iteration is sorted.
    trie2 = Trie(['x', 'pq', 'xa'])
    assert list(trie2) == ['pq', 'x', 'xa']
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment