Skip to content

Instantly share code, notes, and snippets.

@TheDataLeek
Created February 1, 2013 04:28
Show Gist options
  • Save TheDataLeek/4689217 to your computer and use it in GitHub Desktop.
Save TheDataLeek/4689217 to your computer and use it in GitHub Desktop.
Binary Search Tree implementation and test harness
#!/usr/bin/env
#
# binary_search_tree.py
#
import sys
class bt_node:
    """A single binary-search-tree node.

    Attributes:
        data: the value stored in the node (defaults to 0).
        left: the left child node, or None.
        right: the right child node, or None.
    """

    def __init__(self, data=0, left=None, right=None):
        # Per-instance attributes instead of class-level defaults, so every
        # node owns its data/left/right from construction on.
        self.data = data
        self.left = left
        self.right = right

    def __repr__(self):
        return "bt_node(%r)" % (self.data,)
class BinarySearchTree:
    """Unbalanced binary search tree built from bt_node objects.

    Values smaller than a node go into its left subtree; values greater than
    or equal to it (including duplicates) go into its right subtree.

    Attributes:
        tree: the root node, or None when the tree is empty.
        tree_list: scratch list of node data, refilled by to_array().
        tree_nodes: scratch list of node objects, refilled by to_array()
            in the same pre-order as tree_list.
    """

    def __init__(self):
        self.tree = None
        self.tree_list = []
        self.tree_nodes = []

    def init_node(self, data):
        """
        Create and return a bt_node object that has been initialized
        with the given data and two None children.
        """
        new_node = bt_node()
        new_node.data = data
        # Set the links explicitly rather than relying on bt_node defaults.
        new_node.left = None
        new_node.right = None
        return new_node

    def insert(self, new_node):
        """
        Insert the new_node into the tree at the correct location.

        Smaller values descend left; equal or larger values descend right,
        so duplicates land in the right subtree. new_node is attached as-is,
        together with any children it already carries.
        """
        if self.tree is None:
            self.tree = new_node
            return
        cursor = self.tree
        while True:
            if new_node.data < cursor.data:
                if cursor.left is None:
                    cursor.left = new_node
                    return
                cursor = cursor.left
            else:
                if cursor.right is None:
                    cursor.right = new_node
                    return
                cursor = cursor.right

    def insert_data(self, data):
        """
        Insert a new node that contains the given data into the tree
        at the correct location.
        """
        self.insert(self.init_node(data))

    def remove(self, data):
        """
        Remove one node whose data value equals the argument.

        Does nothing when no such node exists. The first match found on the
        search path from the root is the one removed.

        - A node with two children stays in place and takes the value of its
          in-order successor (when the tree currently holds an even number of
          nodes) or in-order predecessor (odd); that successor/predecessor
          node is unlinked instead. The test suite's expected orderings
          depend on this parity rule.
        - A root with zero or one child is replaced by that child (the tree
          becomes empty when the root was the only node).
        - Any other node with at most one child is spliced out of its parent.
        """
        cursor = self.tree
        parent = None
        while cursor is not None:
            if cursor.data == data:
                break
            parent = cursor
            if data < cursor.data:
                cursor = cursor.left
            else:
                cursor = cursor.right
        if cursor is None:
            return  # value not present

        if cursor.left is not None and cursor.right is not None:
            if self.size() % 2 == 0:
                replacement, replacement_parent = self.find_successor(cursor)
            else:
                replacement, replacement_parent = self.find_predecessor(cursor)
            # find_* report no parent when the replacement is a direct child
            # of cursor; in that case cursor itself is the parent to unlink from.
            if replacement_parent is None:
                replacement_parent = cursor
            cursor.data = replacement.data
            self.delete_node(replacement, replacement_parent)
        elif parent is None:
            # Root with at most one child: promote that child.
            if cursor.left is not None:
                self.tree = cursor.left
            else:
                self.tree = cursor.right
        else:
            self.delete_node(cursor, parent)

    def delete_node(self, node, parent):
        '''
        Unlink node from parent, splicing node's only child (if any) into
        its place. Internal helper for remove(); should not be called
        directly.

        Raises ValueError if node has two children (both subtrees could not
        be kept) or if node is not a direct child of parent (including when
        parent is None).
        '''
        if node.left is not None and node.right is not None:
            raise ValueError("delete_node requires a node with at most one child")
        child = node.left if node.left is not None else node.right
        if parent is not None and parent.left is node:
            parent.left = child
        elif parent is not None and parent.right is node:
            parent.right = child
        else:
            raise ValueError("node is not a child of parent")

    def find_successor(self, node):
        '''
        Finds the in-order successor of the given node: the leftmost node
        of its right subtree.

        Returns (successor, successor_parent); successor_parent is None when
        the successor is node.right itself. node.right must not be None,
        otherwise AttributeError is raised.
        '''
        successor_node = node.right
        successor_top = None
        while successor_node.left is not None:
            successor_top = successor_node
            successor_node = successor_node.left
        return successor_node, successor_top

    def find_predecessor(self, node):
        '''
        Finds the in-order predecessor of the given node: the rightmost node
        of its left subtree.

        Returns (predecessor, predecessor_parent); predecessor_parent is None
        when the predecessor is node.left itself. node.left must not be None,
        otherwise AttributeError is raised.
        '''
        predecessor_node = node.left
        predecessor_top = None
        while predecessor_node.right is not None:
            predecessor_top = predecessor_node
            predecessor_node = predecessor_node.right
        return predecessor_node, predecessor_top

    def contains(self, data):
        """
        Return True or False depending on if this tree contains a node
        with the supplied data.
        """
        return self.get_node(data) is not None

    def get_node(self, data):
        """
        If the tree contains a node with the supplied data, return
        it (the first match on the search path from the root). Otherwise
        return None.
        """
        cursor = self.tree
        while cursor is not None:
            if cursor.data == data:
                return cursor
            if data < cursor.data:
                cursor = cursor.left
            else:
                cursor = cursor.right
        return None

    def size(self):
        """
        Return the number of nodes in this tree; 0 when it is empty.

        Computed by a full traversal, which also refreshes tree_list and
        tree_nodes.
        """
        if self.tree is None:
            return 0
        return len(self.to_array())

    def to_array(self):
        """
        Create and fill a list with the data contained in this tree, in
        PRE-ORDER: each node, then its left subtree, then its right subtree.
        The result is generally not sorted; e.g. inserting 5, 2, 8 yields
        [5, 2, 8].

        Also refills self.tree_nodes with the matching node objects.
        Returns self.tree_list itself (not a copy).
        """
        self.tree_list = []
        self.tree_nodes = []
        self.search(self.tree)
        return self.tree_list

    def search(self, node):
        """
        Append the data and node objects of the subtree rooted at node to
        self.tree_list and self.tree_nodes, in pre-order. Does nothing when
        node is None.

        Uses an explicit stack instead of recursion so that degenerate
        (list-shaped) trees deeper than the interpreter's recursion limit
        can still be traversed.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current is None:
                continue
            self.tree_list.append(current.data)
            self.tree_nodes.append(current)
            # Push right first so the left subtree is popped (visited) first.
            pending.append(current.right)
            pending.append(current.left)
#!/usr/bin/env python
import unittest
from binary_search_tree import BinarySearchTree
class TestBinarySearchTree(unittest.TestCase):
    """Unit tests for BinarySearchTree.

    Uses unittest's assert methods rather than bare ``assert`` statements,
    which are stripped when Python runs with -O (making every test pass
    vacuously). Expected lists are in the pre-order produced by
    ``to_array`` (root, left subtree, right subtree), not sorted order.
    """

    def setUp(self):
        # Empty tree plus standalone nodes -5..5 for test_insert.
        self.bst = BinarySearchTree()
        self.node_5 = self.bst.init_node(-5)
        self.node_4 = self.bst.init_node(-4)
        self.node_3 = self.bst.init_node(-3)
        self.node_2 = self.bst.init_node(-2)
        self.node_1 = self.bst.init_node(-1)
        self.node0 = self.bst.init_node(0)
        self.node1 = self.bst.init_node(1)
        self.node2 = self.bst.init_node(2)
        self.node3 = self.bst.init_node(3)
        self.node4 = self.bst.init_node(4)
        self.node5 = self.bst.init_node(5)
        # Pre-built tree shared by most tests:
        #            5
        #          /   \
        #         2     8
        #        / \   / \
        #       1   3 6   9
        #      /   / \
        #     0  2.5  4
        #              \
        #              4.5
        self.mytree = BinarySearchTree()
        for value in (5, 2, 3, 1, 4, 0, 2.5, 8, 9, 6, 4.5):
            self.mytree.insert_data(value)

    def test_init_node(self):
        node1 = self.bst.init_node(5)
        node2 = self.bst.init_node(4)
        self.assertEqual(node1.data, 5)
        self.assertIsNone(node1.left)
        self.assertIsNone(node1.right)
        self.assertEqual(node2.data, 4)
        self.assertIsNone(node2.left)
        self.assertIsNone(node2.right)

    def test_insert(self):
        self.assertIsNone(self.bst.tree)
        self.bst.insert(self.node0)
        self.assertEqual(self.bst.tree.data, self.node0.data)
        self.bst.insert(self.node3)
        self.assertEqual(self.bst.tree.right.data, self.node3.data)
        self.bst.insert(self.node5)
        self.assertEqual(self.bst.tree.right.right.data, self.node5.data)
        self.bst.insert(self.node4)
        self.assertEqual(self.bst.tree.right.right.left.data, self.node4.data)
        self.bst.insert(self.node2)
        self.assertEqual(self.bst.tree.right.left.data, self.node2.data)
        self.bst.insert(self.node1)
        self.assertEqual(self.bst.tree.right.left.left.data, self.node1.data)
        self.bst.insert(self.node_4)
        self.assertEqual(self.bst.tree.left.data, self.node_4.data)
        self.bst.insert(self.node_5)
        self.assertEqual(self.bst.tree.left.left.data, self.node_5.data)
        self.bst.insert(self.node_2)
        self.assertEqual(self.bst.tree.left.right.data, self.node_2.data)
        self.bst.insert(self.node_3)
        self.assertEqual(self.bst.tree.left.right.left.data, self.node_3.data)
        self.bst.insert(self.node_1)
        self.assertEqual(self.bst.tree.left.right.right.data, self.node_1.data)

    def test_remove(self):
        # Leaves and single-child nodes.
        self.assertTrue(self.mytree.contains(0))
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 0, 3, 2.5, 4, 4.5, 8, 6, 9])
        self.mytree.remove(0)
        self.assertFalse(self.mytree.contains(0))
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 3, 2.5, 4, 4.5, 8, 6, 9])
        self.assertTrue(self.mytree.contains(4.5))
        self.mytree.remove(4.5)
        self.assertFalse(self.mytree.contains(4.5))
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 3, 2.5, 4, 8, 6, 9])
        self.mytree.insert_data(0)
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 0, 3, 2.5, 4, 8, 6, 9])
        self.assertTrue(self.mytree.contains(1))
        self.mytree.remove(1)
        self.assertFalse(self.mytree.contains(1))
        self.assertTrue(self.mytree.contains(0))
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 0, 3, 2.5, 4, 8, 6, 9])

        # RESET: repeatedly remove the root. Where either the in-order
        # successor or predecessor is an acceptable replacement, both
        # resulting orderings are allowed.
        self.setUp()
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 0, 3, 2.5, 4, 4.5, 8, 6, 9])
        self.mytree.remove(5)
        self.assertIn(self.mytree.to_array(),
                      ([6, 2, 1, 0, 3, 2.5, 4, 4.5, 8, 9],
                       [4.5, 2, 1, 0, 3, 2.5, 4, 8, 6, 9]))
        self.mytree.remove(6)
        self.assertEqual(self.mytree.to_array(),
                         [4.5, 2, 1, 0, 3, 2.5, 4, 8, 9])
        self.mytree.remove(4.5)
        self.assertEqual(self.mytree.to_array(), [4, 2, 1, 0, 3, 2.5, 8, 9])
        self.mytree.remove(4)
        self.assertIn(self.mytree.to_array(),
                      ([3, 2, 1, 0, 2.5, 8, 9],
                       [8, 2, 1, 0, 3, 2.5, 9]))
        self.mytree.remove(3)
        self.assertIn(self.mytree.to_array(),
                      ([2.5, 2, 1, 0, 8, 9],
                       [8, 2, 1, 0, 2.5, 9]))
        self.mytree.remove(2.5)
        self.assertEqual(self.mytree.to_array(), [8, 2, 1, 0, 9])
        self.mytree.remove(8)
        self.assertIn(self.mytree.to_array(), ([9, 2, 1, 0], [2, 1, 0, 9]))
        self.mytree.remove(2)
        self.assertIn(self.mytree.to_array(), ([9, 1, 0], [1, 0, 9]))
        self.mytree.remove(1)
        self.assertEqual(self.mytree.to_array(), [9, 0])
        self.mytree.remove(0)
        self.assertEqual(self.mytree.to_array(), [9])

        # RESET: strip the left subtree first, then remove single-child roots.
        self.setUp()
        for value in (4.5, 4, 2.5, 3, 0, 1, 2, 6):
            self.mytree.remove(value)
        self.assertEqual(self.mytree.to_array(), [5, 8, 9])
        self.mytree.remove(5)
        self.assertEqual(self.mytree.to_array(), [8, 9])
        self.mytree.remove(8)
        self.assertEqual(self.mytree.to_array(), [9])

    def test_contains(self):
        for present in (5, 4, 3, 2, 1, 0, 9):
            self.assertTrue(self.mytree.contains(present), present)
        for absent in (80, -5, -10, 230, 340):
            self.assertFalse(self.mytree.contains(absent), absent)

    def test_get_node(self):
        acquired_node0 = self.mytree.get_node(2)
        acquired_node1 = self.mytree.get_node(8)
        acquired_node2 = self.mytree.get_node(4.5)
        acquired_node3 = self.mytree.get_node(45)
        self.assertEqual(acquired_node0.data, 2)
        self.assertEqual(acquired_node0.left.data, 1)
        self.assertEqual(acquired_node0.right.data, 3)
        self.assertEqual(acquired_node1.data, 8)
        self.assertEqual(acquired_node1.left.data, 6)
        self.assertEqual(acquired_node1.right.data, 9)
        self.assertEqual(acquired_node2.data, 4.5)
        self.assertIsNone(acquired_node2.left)
        self.assertIsNone(acquired_node2.right)
        self.assertIsNone(acquired_node3)

    def test_to_array(self):
        self.assertEqual(self.bst.to_array(), [])
        for value in (5, 2, 8, 6, 1, 3, 0, 4, 9):
            self.bst.insert_data(value)
        self.assertEqual(self.bst.to_array(), [5, 2, 1, 0, 3, 4, 8, 6, 9])
        self.assertEqual(self.mytree.to_array(),
                         [5, 2, 1, 0, 3, 2.5, 4, 4.5, 8, 6, 9])

    def test_size(self):
        self.assertEqual(self.mytree.size(), 11)
        self.mytree.remove(4.5)
        self.assertEqual(self.mytree.size(), 10)
        self.mytree.remove(4)
        self.assertEqual(self.mytree.size(), 9)
        self.mytree.insert_data(30)
        self.assertEqual(self.mytree.size(), 10)
# Entry point: run every TestCase in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment