Skip to content

Instantly share code, notes, and snippets.

@sushlala
Created September 24, 2017 04:28
Show Gist options
  • Save sushlala/8eaeedd167aa877f8f8a38a1e94a7eeb to your computer and use it in GitHub Desktop.
Save sushlala/8eaeedd167aa877f8f8a38a1e94a7eeb to your computer and use it in GitHub Desktop.
a trie implemented in Python
from collections import defaultdict
from collections import deque
from itertools import islice

try:
    from Queue import Queue  # Python 2
except ImportError:
    from queue import Queue  # Python 3
class Node(object):
    """A single node in the trie.

    ``children`` maps one key element (e.g. a character) to the child Node.
    Because it is a ``defaultdict``, *indexing* a missing element creates and
    stores a new empty child; membership tests (``in``) do not.

    Only nodes at which a stored key ends have a ``val`` attribute; its
    absence marks a pass-through (non-terminal) node.
    """
    __slots__ = ('children', 'val')

    def __init__(self):
        self.children = defaultdict(self.__class__)


class Trie(object):
    """A trie (prefix tree) that may be accessed like a dict.

    Keys are iterables of hashable elements (typically strings); each element
    selects one level of the tree. The empty key is stored on the root node.
    """

    def __init__(self):
        self.root = Node()

    def _find_node(self, key):
        """Return the node reached by following ``key``, or None if absent.

        Uses ``in`` checks rather than indexing so a lookup never creates
        nodes through the children ``defaultdict``.
        """
        node = self.root
        for element in key:
            children = node.children
            if element not in children:
                return None
            node = children[element]
        return node

    def __setitem__(self, key, val):
        """Store ``val`` at ``key``, creating intermediate nodes as needed."""
        node = self.root
        for element in key:
            # Indexing the defaultdict inserts a fresh Node() when missing.
            node = node.children[element]
        node.val = val

    def __getitem__(self, key):
        """Return the value stored at ``key``.

        Raises:
            KeyError: if ``key`` was never stored, including when ``key`` is
                only a prefix of some stored key.
        """
        node = self._find_node(key)
        if node is None or not hasattr(node, 'val'):
            raise KeyError(key)
        return node.val

    def __contains__(self, key):
        """Return True if a value was stored at exactly ``key``."""
        try:
            self[key]
        except KeyError:
            return False
        return True

    def key_starts_with(self, prefix):
        """Yield every stored key that starts with ``prefix``.

        Keys are produced breadth-first from the prefix node, so shorter keys
        come before longer ones. ``prefix`` itself is yielded when it is a
        stored key. Nothing is yielded when no stored key has this prefix.

        Keys are rebuilt by ``+``-concatenating ``prefix`` with key elements,
        so this assumes string keys.
        """
        start = self._find_node(prefix)
        if start is None:
            # A bare return ends the generator cleanly; raising StopIteration
            # here becomes RuntimeError on Python 3.7+ (PEP 479).
            return
        pending = deque([(prefix, start)])
        while pending:
            key, node = pending.popleft()
            for element, child in node.children.items():
                pending.append((key + element, child))
            if hasattr(node, 'val'):
                yield key
if __name__ == '__main__':
    # Small demo of the Trie API. Each print() takes a single pre-formatted
    # string, so the output is identical on Python 2 and Python 3.
    a = Trie()
    # let's add some items to our trie
    for k, v in [('abc', 1), ('abcdef', 2), ('abcdefgh', 3), ('zzz', 6)]:
        a[k] = v
    print('does the trie contain key `ppp`? %s' % ('ppp' in a,))
    print('does the trie contain key `abc`? %s' % ('abc' in a,))
    print('value of trie key `zzz` %s' % (a['zzz'],))
    print('keys that start with `abcde` in trie %s'
          % (list(a.key_starts_with('abcde')),))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment