Skip to content

Instantly share code, notes, and snippets.

@brad-anton
Created December 14, 2018 16:04
Show Gist options
  • Save brad-anton/0255e06ba004f800dc710d72d9442f5e to your computer and use it in GitHub Desktop.
Save brad-anton/0255e06ba004f800dc710d72d9442f5e to your computer and use it in GitHub Desktop.
Trie Class
"""
trie.py
@brad_anton
Go back to college, kid!
"""
class Node(object):
    """One node of a Trie: a value, a link to its parent and a list of children."""

    def __init__(self, value, parent=None):
        self.value = value      # label stored at this node
        self.parent = parent    # enclosing Node, or None for a root
        self.children = []      # child Nodes, kept in insertion order

    def __repr__(self):
        # Show only the parent's value (not its full repr) to keep output short.
        parent_value = self.parent.value if self.parent is not None else None
        template = "<Node value='{}', parent='{}', children_count='{}'>"
        return template.format(self.value, parent_value, len(self.children))
class Trie(object):
    """Trie of Node objects rooted at ``self.head``, with a string serialization.

    Words are stored one character per node (``add_word``); domains are stored
    one reversed label per node (``add_domain`` / ``get_parts``).
    """

    # Written after the value of a node that has children.
    INT = ':'
    # Written after a leaf's value, after a node's last child, and once more
    # after every top-level branch.
    LEAF = ','

    # Every control character the decoder stops on.
    _CONTROL_CHARS = frozenset('&?!:,')
    # Control characters whose node the decoder reports as a domain.
    _ENTRY_CHARS = frozenset('?!:,')
    # Control characters that mean "no children follow" / "end of children".
    _LEAF_CHARS = frozenset('?,')

    def __init__(self, verbose=False):
        self.head = Node('head')
        self.verbose = verbose

    def add_single(self, single, head):
        """Return the child of ``head`` whose value is ``single``, creating it if absent."""
        child = self.get_single(single, head)
        if child is not None:
            if self.verbose:
                print('Found existing node for {}, no need to create a new one'.format(single))
            return child
        if self.verbose:
            print('Creating Node: {}'.format(single))
        child = Node(single, parent=head)
        head.children.append(child)
        return child

    def add_word(self, word):
        """Insert ``word`` one character per node and mark its last node as a word end."""
        if self.verbose:
            print('Adding word: {}'.format(word))
        node = self.head
        for char in word:
            node = self.add_single(char, node)
        # Without this flag find_word could not tell a stored word from a
        # prefix of a longer stored word.
        node.is_word = True

    def get_single(self, value, head):
        """Return the first child of ``head`` whose value equals ``value``, or None."""
        return next((child for child in head.children if child.value == value), None)

    def _walk(self, parts):
        """Follow ``parts`` down from the head; return the final node, or None if the path breaks."""
        node = self.head
        for part in parts:
            node = self.get_single(part, node)
            if node is None:
                return None
        return node

    def find_word(self, word):
        """Return True only if ``word`` itself was added with ``add_word``.

        A prefix of a stored word does not match unless it was added too.
        """
        node = self._walk(word)
        return node is not None and getattr(node, 'is_word', False)

    def get_parts(self, domain):
        """Split ``domain`` into reversed labels, rightmost label first.

        Every label except the last keeps a trailing '.', e.g.
        'test.com' -> ['moc.', 'tset']. The dot means a label that continues
        into further labels never shares a node with a label that ends a
        domain.
        """
        # Need to do this because Guava includes dots in
        # its serialize output :(
        labels = domain[::-1].split('.')  # split() always yields >= 1 item
        return [label + '.' for label in labels[:-1]] + [labels[-1]]

    def add_domain(self, domain):
        """Insert ``domain`` as a path of reversed labels (see get_parts)."""
        if self.verbose:
            print('Adding domain: {}'.format(domain))
        node = self.head
        for part in self.get_parts(domain):
            node = self.add_single(part, node)

    def add_domains(self, domains):
        """Insert every domain in ``domains``."""
        for domain in domains:
            self.add_domain(domain)

    def find_domain(self, domain):
        """Return True if the label path for ``domain`` exists in the trie."""
        return self._walk(self.get_parts(domain)) is not None

    def recurse(self, branch, node):
        """Append the encoding of ``node``'s subtree to ``branch`` and return it.

        A leaf is written as value + LEAF. A node with children is written as
        value + INT, then each child's encoding, then LEAF.
        """
        branch += node.value
        if not node.children:
            return branch + self.LEAF
        branch += self.INT
        for child in node.children:
            branch = self.recurse(branch, child)
        return branch + self.LEAF

    def serialize(self):
        """Encode the whole trie (excluding the head node) as one string."""
        # One extra LEAF follows every top-level branch; process_serialized
        # decodes it as an empty node and reports nothing for it.
        return ''.join(self.recurse('', child) + self.LEAF
                       for child in self.head.children)

    @staticmethod
    def process_serialized(stack, encoded, verbose=False):
        """Decode one node, and its subtree, from the start of ``encoded``.

        ``stack`` holds the values of the ancestor nodes. It is restored to its
        original contents before returning. Every node ending in an entry
        control character is printed as 'Adding domain: <domain>'.

        Returns the number of characters consumed. This can be len(encoded) + 1
        when the input ends without a control character.
        """
        encodedLen = len(encoded)
        if verbose:
            print('encodedLen: {}'.format(encodedLen))
            print('Stack: {}'.format(stack))
        c = '\0'
        idx = 0
        # Read all chars up until we encounter a control character. A while
        # loop (not range) leaves idx == encodedLen when none is found, so the
        # final character is not dropped.
        while idx < encodedLen:
            c = encoded[idx]
            if c in Trie._CONTROL_CHARS:
                if verbose:
                    print('Got control char: {}'.format(c))
                break
            idx += 1
        # Add all characters up to the control character onto the stack
        stack.append(encoded[0:idx])
        if verbose:
            print('Stack (after append): {}'.format(stack))
        if c in Trie._ENTRY_CHARS:
            domain = ''.join(stack)
            if verbose:
                print('Candidate: {}'.format(domain))
            if len(domain) > 0:
                # Node values are stored reversed, so reverse the joined path.
                print('Adding domain: {}'.format(domain[::-1]))
        if verbose:
            print('Incrementing idx (1): {}, encodedLen: {}, encoded: {}'.format(idx, encodedLen, encoded[idx:]))
        # Continue past control character
        idx += 1
        # Process interior nodes
        if c not in Trie._LEAF_CHARS:
            while idx < encodedLen:
                idx += Trie.process_serialized(stack, encoded[idx:], verbose)
                # The child may have consumed the rest of the input, so check
                # bounds before peeking at the next character.
                if idx < encodedLen and encoded[idx] in Trie._LEAF_CHARS:
                    # End of this node's children
                    if verbose:
                        print('Incrementing idx (2): {}, encodedLen: {}, encoded: {}'.format(idx, encodedLen, encoded[idx:]))
                    idx += 1
                    break
        stack.pop()
        if verbose:
            print('Stack (end): {}'.format(stack))
        return idx

    @staticmethod
    def deserialize(encoded, verbose=False):
        """Decode every top-level branch of ``encoded``, printing the domains found."""
        encodedLen = len(encoded)
        idx = 0
        while idx < encodedLen:
            idx += Trie.process_serialized([], encoded[idx:], verbose)
        # TODO: Return the Trie :)
        return None
if __name__ == '__main__':
    # Demo: build a verbose trie from sample domains, print the input and its
    # serialization, then decode the serialization again.
    sample_domains = ['test.com', 'test2.com', 'test.org', 'test.test.org']
    trie = Trie(verbose=True)
    trie.add_domains(sample_domains)
    encoded = trie.serialize()
    print(sample_domains)
    print(encoded)
    Trie.deserialize(encoded)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment