Last active
August 29, 2015 14:08
-
-
Save zsrinivas/96e84408ce00bc49245c to your computer and use it in GitHub Desktop.
binary search tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BST(object):
    """
    Unbalanced binary search tree.

    Nodes are BSTnode instances linked through their left, right and
    parent attributes. Keys smaller than a node's key live in its left
    subtree; keys greater than or equal to it live in its right subtree
    (duplicates are therefore inserted to the right).
    """

    def __init__(self):
        """
        Create an empty tree.
        """
        self.root = None

    def insert(self, t):
        """
        Insert key t into this BST, modifying it in-place.

        Returns the newly created node.
        """
        new = BSTnode(t)
        if self.root is None:
            self.root = new
            return new
        node = self.root
        while True:
            if t < node.key:
                # Go left
                if node.left is None:
                    node.left = new
                    new.parent = node
                    break
                node = node.left
            else:
                # Go right (equal keys also go right)
                if node.right is None:
                    node.right = new
                    new.parent = node
                    break
                node = node.right
        return new

    def find(self, t):
        """
        Return the node for key t if it is in the tree, or None otherwise.
        """
        node = self.root
        while node is not None:
            if t == node.key:
                return node
            elif t < node.key:
                node = node.left
            else:
                node = node.right
        return None

    def _splice_out(self, node, child):
        """
        Unlink node, which has at most one child, promoting child
        (possibly None) into its place.

        Updates the tree root when node has no parent, disconnects node,
        and returns node's former parent.
        """
        parent = node.parent
        if parent is None:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        if child is not None:
            child.parent = parent
        node.disconnect()
        return parent

    def delete(self, t):
        """
        Delete the first node found with key t.

        Returns a pair (node, parent):
          * (None, None) when t is not in the tree;
          * (node, None) when the deleted node was a leaf;
          * (node, former parent of node) otherwise, where the former
            parent is None if node was the root.
        The returned node is disconnected from the tree.
        """
        node = self.find(t)
        if node is None:
            return None, None
        # Pick the in-order neighbour that will take node's place: the
        # largest key of the left subtree, or failing that the smallest
        # key of the right subtree.
        if node.left is not None:
            replacement, _ = self.delete_max(node.left)
        elif node.right is not None:
            replacement, _ = self.delete_min(node.right)
        else:
            # Leaf: simply unlink it.
            self._splice_out(node, None)
            return node, None
        # The replacement was disconnected by delete_max/delete_min; it
        # now adopts node's (possibly updated) links.
        replacement.parent = node.parent
        replacement.left = node.left
        replacement.right = node.right
        if replacement.left is not None:
            replacement.left.parent = replacement
        if replacement.right is not None:
            replacement.right.parent = replacement
        if replacement.parent is not None:
            if replacement.parent.left is node:
                replacement.parent.left = replacement
            else:
                replacement.parent.right = replacement
        if node is self.root:
            self.root = replacement
        parent = node.parent
        node.disconnect()
        return node, parent

    def minimum(self):
        """
        Return the node holding the smallest key, or None if the tree
        is empty.
        """
        if self.root is None:
            return None
        node = self.root
        while node.left is not None:
            node = node.left
        return node

    def maximum(self):
        """
        Return the node holding the largest key, or None if the tree
        is empty.
        """
        if self.root is None:
            return None
        node = self.root
        while node.right is not None:
            node = node.right
        return node

    def delete_max(self, root=None):
        """
        Delete the maximum key of the subtree rooted at root (the whole
        tree when root is None).

        Returns a pair (node, parent) with the removed, disconnected
        node and its former parent, or (None, None) if there is nothing
        to delete.
        """
        if root is None:
            root = self.root
        if root is None:
            return None, None
        # Walk to rightmost node.
        node = root
        while node.right is not None:
            node = node.right
        # The rightmost node has no right child, so its left subtree is
        # the only thing that can be promoted into its place.
        parent = self._splice_out(node, node.left)
        return node, parent

    def delete_min(self, root=None):
        """
        Delete the minimum key of the subtree rooted at root (the whole
        tree when root is None).

        Returns a pair (node, parent) with the removed, disconnected
        node and its former parent, or (None, None) if there is nothing
        to delete.
        """
        if root is None:
            root = self.root
        if root is None:
            return None, None
        # Walk to leftmost node.
        node = root
        while node.left is not None:
            node = node.left
        # Remove that node and promote its right subtree.
        parent = self._splice_out(node, node.right)
        return node, parent

    def __str__(self):
        """
        Return a multi-line ASCII-art drawing of the tree, or
        '<empty tree>' when there are no nodes.
        """
        if self.root is None:
            return '<empty tree>'

        def recurse(node):
            # Returns (lines, pos, width) for the subtree at node: the
            # drawing as a list of strings, a column offset used by the
            # caller to position its own label and branches, and the
            # width value used for padding.
            if node is None:
                return [], 0, 0
            label = str(node.key)
            left_lines, left_pos, left_width = recurse(node.left)
            right_lines, right_pos, right_width = recurse(node.right)
            middle = max(right_pos + left_width - left_pos + 1, len(label), 2)
            pos = left_pos + middle // 2
            width = left_pos + middle + right_width - right_pos
            # Pad the shorter side so both halves have the same height.
            while len(left_lines) < len(right_lines):
                left_lines.append(' ' * left_width)
            while len(right_lines) < len(left_lines):
                right_lines.append(' ' * right_width)
            if (middle - len(label)) % 2 == 1 and node.parent is not None and \
                    node is node.parent.left and len(label) < middle:
                label += '.'
            label = label.center(middle, '.')
            if label[0] == '.':
                label = ' ' + label[1:]
            if label[-1] == '.':
                label = label[:-1] + ' '
            lines = [
                ' ' * left_pos + label + ' ' * (right_width - right_pos),
                ' ' * left_pos + '/' + ' ' * (middle - 2) +
                '\\' + ' ' * (right_width - right_pos)] + \
                [
                    left_line + ' ' *
                    (width - left_width - right_width) +
                    right_line
                    for left_line, right_line in zip(left_lines, right_lines)
                ]
            return lines, pos, width
        return '\n'.join(recurse(self.root)[0])

    def serialize(self):
        """
        Return the keys of the tree joined with '|'.

        Keys are emitted parent-first (each node, then its right
        subtree, then its left subtree), so every key appears after all
        of its ancestors. An empty tree yields an empty list, as it
        always has; deserialize accepts that value.
        """
        if self.root is None:
            return []
        stack = [self.root]
        keys = []
        while stack:
            node = stack.pop()
            keys.append(str(node.key))
            # Only existing children are pushed; pushing a missing child
            # would make the next iteration dereference None.
            if node.left is not None:
                stack.append(node.left)
            if node.right is not None:
                stack.append(node.right)
        return '|'.join(keys)

    def deserialize(self, data):
        """
        Insert every integer key found in data, a '|'-separated string
        as produced by serialize, into this tree.

        Empty input ('' or the [] that serialize returns for an empty
        tree) inserts nothing. Raises ValueError if a field is not an
        integer.
        """
        if not data:
            return
        for field in data.split('|'):
            self.insert(int(field))
class BSTnode(object):
    """
    A single node of a binary search tree.

    Stores a key together with links to its left child, right child
    and parent; any of the links may be None.
    """

    def __init__(self, t):
        """
        Build a detached node (no children, no parent) holding key t.
        """
        self.key = t
        self.disconnect()

    def disconnect(self):
        """
        Reset the left, right and parent links to None.
        """
        self.left = self.right = self.parent = None

    def has_left(self):
        """
        Return the left child (None when there is none).
        """
        left_child = self.left
        return left_child

    def has_right(self):
        """
        Return the right child (None when there is none).
        """
        right_child = self.right
        return right_child

    def is_left(self):
        """
        Tell whether this node is its parent's left child.

        Returns the falsy parent value itself when there is no parent.
        """
        parent = self.parent
        if not parent:
            return parent
        return parent.left is self

    def is_right(self):
        """
        Tell whether this node is its parent's right child.

        Returns the falsy parent value itself when there is no parent.
        """
        parent = self.parent
        if not parent:
            return parent
        return parent.right is self

    def is_root(self):
        """
        Return True when this node has no parent, False otherwise.
        """
        if self.parent:
            return False
        return True

    def has_anychild(self):
        """
        Return the left child if present, otherwise the right child
        (which may be None).
        """
        left_child = self.left
        if left_child:
            return left_child
        return self.right

    def has_allchild(self):
        """
        Return the right child when a left child exists, otherwise the
        (falsy) left value; truthy only when both children are present.
        """
        left_child = self.left
        if not left_child:
            return left_child
        return self.right
def test(args=None, BSTtype=BST):
    """
    Command-line demo: build a tree, then tear it down, printing the
    tree after every step.

    args is a list of strings; when empty or None, sys.argv[1:] is
    used instead. A single argument N means "insert N random integers
    below 100"; several arguments are taken as the integer keys
    themselves. With no arguments at all a usage line is printed and
    the process exits. BSTtype is the tree class to instantiate; it
    must provide insert, delete and a string representation.
    """
    import random
    import sys
    if not args:
        args = sys.argv[1:]
    if not args:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('usage: %s <number-of-random-items | item item item ...>' %
              sys.argv[0])
        sys.exit()
    elif len(args) == 1:
        items = [random.randrange(100) for _ in range(int(args[0]))]
    else:
        items = [int(i) for i in args]
    tree = BSTtype()
    print(tree)
    for item in items:
        tree.insert(item)
        print(tree)
    print('-------------------------------------------')
    # Remove the keys in the same order they were inserted.
    for item in items:
        tree.delete(item)
        print(tree)
# Run the command-line demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment