Skip to content

Instantly share code, notes, and snippets.

@tyler
Created May 8, 2011 22:43
Show Gist options
Save tyler/961773 to your computer and use it in GitHub Desktop.
Serializing trie
from struct import *
def _node_size(child_count):
return child_count * 5 + 6
class TrieKeyError(KeyError, IndexError):
    """Raised by ``Node.__getitem__`` when a key is not in the trie.

    Inherits from ``KeyError`` (the conventional missing-key error for
    mapping-style ``[]`` access) and from ``IndexError`` (the error this
    lookup raised previously), so both ``except KeyError`` and
    ``except IndexError`` handlers catch it.
    """


class Node(object):
    """One node of a byte-wise trie that maps string keys to int values.

    Each level of the trie consumes one byte of the key.  ``str``/``bytes``
    and ``bytearray`` keys are used byte for byte; text (``unicode``) keys
    are encoded as UTF-8 first.  Insertion and lookup share the same key
    normalization, so any key that can be stored can also be read back.

    :meth:`serialize` flattens the trie into one byte string and
    :meth:`deserialize` rebuilds it.  Every serialized node is laid out as::

        '=HI' header   node size in bytes (uint16), value (uint32)
        '=BI' entries  per child: key byte (uint8), absolute offset of
                       that child's serialized node (uint32)

    with the children's subtrees following the entries, depth first.
    A serialized value of 0 means "no value", so a key mapped to 0 does
    not survive a serialize/deserialize round trip, and values outside
    0..2**32-1 cannot be packed.
    """

    _HEADER_FORMAT = '=HI'
    _ENTRY_FORMAT = '=BI'
    _HEADER_SIZE = calcsize(_HEADER_FORMAT)  # 6 bytes
    _ENTRY_SIZE = calcsize(_ENTRY_FORMAT)    # 5 bytes

    def __init__(self):
        self.children = {}     # key byte (int) -> child Node
        self.terminal = False  # True when some key ends at this node
        self.value = None      # value of that key; None when not terminal

    @staticmethod
    def deserialize(serialized_trie):
        """Rebuild a trie from bytes produced by :meth:`serialize`; return its root."""
        return Node._deserialize_node(serialized_trie, 0)

    @staticmethod
    def _deserialize_node(data, offset):
        """Decode the node stored at ``offset`` in ``data``, with its whole subtree."""
        node = Node()
        header_end = offset + Node._HEADER_SIZE
        size, value = unpack(Node._HEADER_FORMAT, data[offset:header_end])
        if value != 0:  # 0 is the serialized "no value" marker
            node.terminal = True
            node.value = value
        # Floor division keeps the count an int (plain / yields a float on Python 3).
        child_count = (size - Node._HEADER_SIZE) // Node._ENTRY_SIZE
        for child_idx in range(child_count):
            entry_start = header_end + child_idx * Node._ENTRY_SIZE
            entry = data[entry_start:entry_start + Node._ENTRY_SIZE]
            key_byte, child_offset = unpack(Node._ENTRY_FORMAT, entry)
            # Child offsets are absolute, so recurse into the same buffer.
            node.children[key_byte] = Node._deserialize_node(data, child_offset)
        return node

    def serialize(self):
        """Return the trie rooted at this node as a byte string."""
        _, output = self._serialize_node(0)
        return output

    def _serialize_node(self, starting_offset):
        """Serialize this subtree as if it starts at ``starting_offset``.

        Returns ``(end_offset, data)`` where ``end_offset`` is the absolute
        offset just past this subtree's last byte.
        """
        node_size = self._HEADER_SIZE + len(self.children) * self._ENTRY_SIZE
        parts = [pack(self._HEADER_FORMAT, node_size, self.value or 0)]
        subtrees = []
        # Child subtrees are laid out back to back right after this node.
        next_child_offset = starting_offset + node_size
        for key_byte, child in self.children.items():
            parts.append(pack(self._ENTRY_FORMAT, key_byte, next_child_offset))
            next_child_offset, child_data = child._serialize_node(next_child_offset)
            subtrees.append(child_data)
        return next_child_offset, b''.join(parts + subtrees)

    def __setitem__(self, key, value):
        """Map ``key`` to the int ``value``.

        Raises:
            TypeError: ``value`` is not exactly ``int`` (so ``bool`` is
                rejected), or ``key`` is not a byte/text string or bytearray.
        """
        if type(value) is not int:
            raise TypeError('Expected int as value')
        self._add(self._key_bytes(key), value)

    def __getitem__(self, key):
        """Return the value stored for ``key``.

        Raises:
            TrieKeyError: ``key`` is not in the trie.
            TypeError: ``key`` is not a byte/text string or bytearray.
        """
        value = self._retrieve(key)
        if value is None:
            raise TrieKeyError(key)
        return value

    def __contains__(self, key):
        """Return True if ``key`` has a value in the trie."""
        return self._retrieve(key) is not None

    @staticmethod
    def _key_bytes(key):
        """Return ``key`` as the bytearray the trie is indexed by.

        Raises TypeError for anything other than bytes/str, text or bytearray.
        """
        if isinstance(key, bytearray):
            return key
        if isinstance(key, bytes):  # also ``str`` on Python 2
            return bytearray(key)
        if isinstance(key, type(u'')):  # ``unicode`` on Python 2, ``str`` on 3
            return bytearray(key, 'UTF-8')
        raise TypeError('Expected string, unicode, or bytearray as key')

    def _add(self, key, value):
        """Store ``value`` under the bytearray ``key``, creating nodes as needed."""
        node = self
        for byte in key:
            child = node.children.get(byte)
            if child is None:
                child = node.children[byte] = Node()
            node = child
        node.terminal = True
        node.value = value

    def _retrieve(self, key):
        """Return the value stored for ``key``, or None if it has none."""
        node = self
        for byte in self._key_bytes(key):
            node = node.children.get(byte)
            if node is None:
                return None
        return node.value if node.terminal else None
if __name__ == '__main__':
    # Demo: build a small trie and print each sample's value.  Then serialize
    # the trie, rebuild it from the bytes, and print the same lookups again.
    # print(x) with a single argument behaves the same on Python 2 and 3.
    samples = [('monkey', 1), ('monk', 2), ('monkeys', 3), ('foobar', 4)]

    trie = Node()
    for word, count in samples:
        trie[word] = count
    for word, _ in samples:
        print(trie[word])

    print("-- Serializing...")
    serialized_trie = trie.serialize()
    print("-- Deserializing...")
    trie = Node.deserialize(serialized_trie)
    for word, _ in samples:
        print(trie[word])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment