-
-
Save shashankgroovy/7dcaf47950cb0d307ed438c96bccdbd6 to your computer and use it in GitHub Desktop.
Fun with binary trees
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import unittest | |
## binary tree problems | |
### A NODE OBJECT | |
class Node(object):
    """A binary-tree node holding a payload and links to its relatives.

    Nodes are compared and ordered by their ``data`` payload, so two distinct
    nodes that carry equal data compare equal.

    @param data The payload stored in the node
    """

    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        self.parent = None

    def __repr__(self):
        return str(self.data)

    def __cmp__(self, other):
        """Three-way comparison on ``data``; returns -1, 0 or 1."""
        if self.data < other.data:
            return -1
        elif self.data > other.data:
            return 1
        return 0

    # Python 3 never calls __cmp__, so the rich comparison protocol is
    # spelled out explicitly. Each method defers to __cmp__ so that both
    # interpreters agree on the result.
    def __eq__(self, other):
        if not isinstance(other, Node):
            # lets ``node == None`` evaluate to False instead of raising
            return NotImplemented
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly
        if not isinstance(other, Node):
            return NotImplemented
        return self.__cmp__(other) != 0

    def __lt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.__cmp__(other) < 0

    def __le__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.__cmp__(other) <= 0

    def __gt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.__cmp__(other) > 0

    def __ge__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.__cmp__(other) >= 0

    def __hash__(self):
        # defining __eq__ disables the default hash on Python 3; hash the
        # payload so that nodes which compare equal also hash equal
        return hash(self.data)
### A BINARY SEARCH TREE | |
class Tree(object):
    """An unbalanced binary search tree built out of ``Node`` objects.

    Values smaller than a node go to its left subtree; everything else
    (including duplicates) goes to its right subtree.

    @param iterable Optional iterable of values inserted in iteration order
    """

    def __init__(self, iterable=None):
        self.root = None
        self.size = 0
        if iterable:
            for item in iterable:
                self.insert(item)

    def __len__(self):
        return self.size

    def insert(self, data):
        """Insert ``data`` as a new leaf.

        @param data The value to store
        @returns the freshly created node
        """
        node = Node(data)
        if self.root is None:
            self.root = node
        else:
            # Compare the payloads directly rather than the Node objects:
            # Node ordering is only defined through __cmp__, which Python 3
            # ignores.
            parent, cursor = None, self.root
            while cursor is not None:
                parent = cursor
                cursor = cursor.left if data < cursor.data else cursor.right
            # the loop ends with ``parent`` being the node to hang the leaf on
            if data < parent.data:
                parent.left = node
            else:
                parent.right = node
            node.parent = parent
        self.size += 1
        return node
### Functions | |
""" | |
Inorder traversal of a binary tree | |
@param node Pointer to the root node of a binary tree | |
@returns a generator that generates nodes inorder | |
""" | |
def inorder(node):
    """Generate the nodes of a binary tree in inorder (left, node, right).

    The walk is iterative, so an empty tree (``None``) simply yields nothing
    and very deep, degenerate trees do not exhaust the recursion limit.

    @param node Pointer to the root node of a binary tree (may be None)
    @returns a generator that generates nodes inorder
    """
    pending = []  # ancestors whose left subtree is currently being visited
    current = node
    while pending or current:
        # dive as far left as possible, remembering the way back
        while current:
            pending.append(current)
            current = current.left
        current = pending.pop()
        yield current
        current = current.right
""" | |
@param node Pointer to the root node of a binary tree | |
@returns height of a tree | |
""" | |
def getHeight(node):
    """Count the levels of the tree rooted at ``node``.

    @param node Pointer to the root node of a binary tree
    @returns 0 for an empty tree, otherwise one more than the height of the
             taller of the two subtrees
    """
    if not node:
        return 0
    left_height = getHeight(node.left)
    right_height = getHeight(node.right)
    if left_height >= right_height:
        return left_height + 1
    return right_height + 1
""" | |
@param node Pointer to the root node of a binary tree | |
@returns boolean - to indicate whether the tree is balanced or not | |
""" | |
def isBalanced(node):
    """Tell whether every node's subtrees differ in height by at most one.

    The check is a single post-order pass: each subtree reports its height,
    or -1 once an imbalance has been found, so no height is computed twice.

    @param node Pointer to the root node of a binary tree
    @returns boolean - to indicate whether the tree is balanced or not
    """
    def _height_or_flag(subtree):
        # returns the height of ``subtree``, or -1 if it is unbalanced
        if not subtree:
            return 0
        left = _height_or_flag(subtree.left)
        if left < 0:
            return -1
        right = _height_or_flag(subtree.right)
        if right < 0 or abs(left - right) > 1:
            return -1
        return max(left, right) + 1

    return _height_or_flag(node) >= 0
""" | |
@param node Pointer to the root node of a binary tree | |
@param level The level that needs to be returned (top level is 1 not 0) | |
@returns a generator that generates the nodes in the level i of a binary tree | |
""" | |
def getLevel(node, level):
    """Generate the nodes found on one level of a binary tree, left to right.

    @param node Pointer to the root node of a binary tree
    @param level The level that needs to be returned (top level is 1 not 0)
    @returns a generator over the nodes on that level; it is empty when the
             tree is empty, the level is below 1, or the level is deeper than
             the tree
    """
    if not node or level < 1:
        # an empty subtree contributes nothing (and must not yield None)
        return
    if level == 1:
        yield node
        return
    for child in (node.left, node.right):
        for elem in getLevel(child, level - 1):
            yield elem
""" | |
@param nums array of sorted numbers | |
@returns a binary search tree with minimal height constructed from nums | |
""" | |
def constructTreeFromArray(nums):
    """Build a minimal-height binary search tree from sorted values.

    The middle element of every range is inserted before the elements on
    either side of it, which keeps the resulting tree as shallow as possible.

    @param nums array of sorted numbers
    @returns a binary search tree with minimal height constructed from nums
    """
    tree = Tree()
    if not nums:
        return tree

    def _insert_range(lo, hi):
        # inserts nums[lo:hi], midpoint first; index bounds avoid copying
        # a fresh slice at every recursion step
        if lo >= hi:
            return
        # floor division: a plain "/" produces a float on Python 3, which
        # cannot be used as an index
        mid = lo + (hi - lo) // 2
        tree.insert(nums[mid])
        _insert_range(lo, mid)
        _insert_range(mid + 1, hi)

    _insert_range(0, len(nums))
    return tree
""" | |
@param node pointer to a tree root | |
@returns whether the tree is a binary search tree or not | |
""" | |
def isBST(node):
    """Check the binary-search-tree ordering of a whole tree.

    Every value must lie within the bounds imposed by ALL of its ancestors,
    not just its direct parent. Equal values are accepted on either side.

    @param node pointer to a tree root
    @returns whether the tree is a binary search tree or not
    """
    def _within(subtree, low, high):
        # ``None`` for a bound means "unbounded on that side"
        if not subtree:
            return True
        value = subtree.data
        if low is not None and value < low:
            return False
        if high is not None and value > high:
            return False
        return (_within(subtree.left, low, value) and
                _within(subtree.right, value, high))

    return _within(node, None, None)
""" | |
@param t pointer to the root of the tree | |
@param node - the node for which the successor is required | |
@returns the successor of the node passed in the argument | |
""" | |
def getSuccessor(t, node):
    """Find the inorder successor of ``node`` using its parent links.

    @param t pointer to the root of the tree (unused; kept so existing
             callers keep working)
    @param node - the node for which the successor is required
    @returns the successor of the node passed in the argument, or None when
             the node holds the largest value (or is None itself)
    """
    if not node:
        return None
    if node.right:
        # the successor is the leftmost node of the right subtree
        successor = node.right
        while successor.left:
            successor = successor.left
        return successor
    # No right subtree: climb until we leave a LEFT subtree; that ancestor is
    # the successor. Identity (``is``) is used so that nodes carrying equal
    # data are never confused with one another.
    child, ancestor = node, node.parent
    while ancestor and child is ancestor.right:
        child, ancestor = ancestor, ancestor.parent
    return ancestor
""" | |
@param node - root node of the tree | |
@param data1 - data stored in node1 | |
@param data2 - data stored in node2 | |
@returns lowest common ancestor in a binary search tree. Runs in O(h) for a tree of height h, i.e. O(log n) when the tree is balanced | |
""" | |
def LCAinBST(node, data1, data2):
    """Find the lowest common ancestor of two values in a binary search tree.

    Descends from the root while both values lie on the same side of the
    current node; the first node that separates them (or equals one of them)
    is the answer. The values are not checked for actual membership.

    @param node - root node of the tree
    @param data1 - data stored in node1
    @param data2 - data stored in node2
    @returns the lowest common ancestor node, or None when the tree is empty
             or the search runs off the tree
    """
    current = node
    while current:
        if current.data > data1 and current.data > data2:
            current = current.left
        elif current.data < data1 and current.data < data2:
            current = current.right
        else:
            return current
    # fell off the tree: both values are outside the range it covers
    return None
""" | |
@param root - root node of the binary tree | |
@param data - the data which is to be found | |
@returns path (as a list) from the root to the element (inclusive) in a binary tree | |
""" | |
def getPath(root, data):
    """Collect the nodes leading from the root down to a given value.

    The tree is searched depth-first, left subtree before right subtree.

    @param root - root node of the binary tree
    @param data - the data which is to be found
    @returns list of nodes from the root to the matching node (inclusive);
             an empty list when the value is not present
    """
    trail = []

    def _descend(current):
        # returns True once ``trail`` ends at a node holding ``data``
        if not current:
            return False
        trail.append(current)
        if current.data == data:
            return True
        for child in (current.left, current.right):
            if child and _descend(child):
                return True
        # neither subtree holds the value: back out of this dead end
        trail.pop()
        return False

    _descend(root)
    return trail
""" | |
@param root - the pointer to the root of the tree | |
@param total - the total sum | |
@returns - a list of paths that sum to the total in the binary tree | |
""" | |
def getSum(root, total):
    """Find every downward run of node values that adds up to ``total``.

    A run may start at any node and continues towards a descendant; it does
    not have to begin at the root or end at a leaf.

    @param root - the pointer to the root of the tree
    @param total - the total sum
    @returns - a list of paths (lists of values) that sum to the total
    """
    matches = []
    # one slot per tree level: slot k holds the value found at depth k on
    # the trail from the root to the node currently being visited
    trail = [0] * getHeight(root)

    def _walk(current, depth):
        if not current:
            return
        trail[depth] = current.data
        # accumulate upwards from the current node towards the root and
        # record every suffix of the trail whose sum hits the target
        running = 0
        start = depth
        while start >= 0:
            running += trail[start]
            if running == total:
                matches.append(trail[start:depth + 1])
            start -= 1
        _walk(current.left, depth + 1)
        _walk(current.right, depth + 1)
        # clear this level's slot on the way back up
        trail[depth] = 0

    _walk(root, 0)
    return matches
""" | |
@param node - root node of the tree | |
@param data1 - data stored in node1 | |
@param data2 - data stored in node2 | |
@returns lowest common ancestor in a binary tree. Runs in O(n) | |
""" | |
def LCAinBinaryTree(root, data1, data2):
    """Find the lowest common ancestor of two values in any binary tree.

    Both root-to-value paths are computed with ``getPath``; the last node
    of their common prefix is the answer.

    @param root - root node of the tree
    @param data1 - data stored in node1
    @param data2 - data stored in node2
    @returns the lowest common ancestor node, or None when either value
             cannot be found
    """
    first = getPath(root, data1)
    second = getPath(root, data2)
    if not first or not second:
        return None

    # length of the prefix on which the two paths agree
    matched = 0
    for ours, theirs in zip(first, second):
        if not (ours == theirs):
            break
        matched += 1

    shorter = first if len(first) <= len(second) else second
    return shorter[matched - 1]
""" | |
@param root1 - pointer to the root node of tree1 | |
@param root2 - pointer to the root node of tree2 | |
@returns boolean to indicate whether tree1 is a part of tree2 | |
""" | |
def checkSubtree(root1, root2):
    """Tell whether tree1 occurs somewhere inside tree2.

    tree1 is treated as a pattern: wherever it has no child, tree2 may have
    anything. Once the pattern's root is anchored on a node of tree2, every
    node the pattern does have must line up with the corresponding node of
    tree2 at the same position, with no skipping of levels. Every node of
    tree2 is tried as an anchor.

    @param root1 - pointer to the root node of tree1
    @param root2 - pointer to the root node of tree2
    @returns boolean to indicate whether tree1 is a part of tree2
    """
    def _matches_at(pattern, anchor):
        # anchored comparison: no searching further down once started
        if not pattern:
            return True
        if not anchor:
            return False
        return (pattern.data == anchor.data and
                _matches_at(pattern.left, anchor.left) and
                _matches_at(pattern.right, anchor.right))

    if not root1:
        # an empty pattern is part of every tree
        return True
    if not root2:
        return False
    return (_matches_at(root1, root2) or
            checkSubtree(root1, root2.left) or
            checkSubtree(root1, root2.right))
### UNITTESTS | |
class TestTreeMethods(unittest.TestCase):
    """Unit tests for the binary tree helper functions in this file."""

    def setUp(self):
        self.balancedBst = Tree([15, 10, 20, 6, 12, 17, 23])
        self.unbalancedBst = Tree([6, 10, 12, 15, 17])
        # a hand-built tree that deliberately violates the BST ordering
        self.binaryTree = Node(16)
        self.binaryTree.left = Node(19)
        self.binaryTree.right = Node(8)

    # helper function to return data from an instance of Node class
    def getData(self, x):
        return x.data

    # helper: payloads of a sequence of nodes as a real list, so that the
    # comparison with a list literal also works where map() is lazy
    def dataOf(self, nodes):
        return [self.getData(node) for node in nodes]

    def testInorder(self):
        order = list(inorder(self.balancedBst.root))
        self.assertEqual([6, 10, 12, 15, 17, 20, 23], self.dataOf(order))

    def testgetHeight(self):
        self.assertEqual(getHeight(self.balancedBst.root), 3)

    def testBalanced(self):
        self.assertTrue(isBalanced(self.balancedBst.root))
        self.assertFalse(isBalanced(self.unbalancedBst.root))

    def testBSTConstruction(self):
        array = [7, 9, 10, 15, 20, 24, 25]
        generatedBst = constructTreeFromArray(array)
        order = list(inorder(generatedBst.root))
        self.assertEqual(3, getHeight(generatedBst.root))
        self.assertEqual(array, self.dataOf(order))

    def testGetLevel(self):
        levels = list(getLevel(self.balancedBst.root, 3))
        self.assertEqual(self.dataOf(levels), [6, 12, 17, 23])
        levels = list(getLevel(self.unbalancedBst.root, 3))
        self.assertEqual(self.dataOf(levels), [12])

    def testIsBst(self):
        self.assertFalse(isBST(self.binaryTree))
        self.assertTrue(isBST(self.balancedBst.root))

    def testGetPath(self):
        existing_node = 23
        absent_node = 1000
        node_path = getPath(self.balancedBst.root, existing_node)
        self.assertEqual(self.dataOf(node_path), [15, 20, 23])
        # gives a blank list
        self.assertEqual(len(getPath(self.balancedBst.root, absent_node)), 0)

    def testLCAinBST(self):
        self.assertEqual(LCAinBST(self.balancedBst.root, 6, 12).data, 10)
        self.assertEqual(LCAinBST(self.balancedBst.root, 17, 23).data, 20)
        self.assertEqual(LCAinBST(self.balancedBst.root, 12, 17).data, 15)

    def testLCAinBinaryTree(self):
        self.assertEqual(
            LCAinBinaryTree(self.balancedBst.root, 6, 12).data, 10)
        self.assertEqual(
            LCAinBinaryTree(self.balancedBst.root, 17, 23).data, 20)
        self.assertEqual(
            LCAinBinaryTree(self.balancedBst.root, 12, 17).data, 15)

    def testSuccessor(self):
        t1 = Tree([15, 10, 21, 7, 12, 20, 25, 8, 11, 13, 14, 23, 28])
        nodes = list(inorder(t1.root))
        ### testing leaf nodes that are left children
        self.assertEqual(getSuccessor(t1.root, nodes[3]).data, 12)
        self.assertEqual(getSuccessor(t1.root, nodes[8]).data, 21)
        self.assertEqual(getSuccessor(t1.root, nodes[10]).data, 25)
        ### testing leaf nodes that are right children
        self.assertEqual(getSuccessor(t1.root, nodes[1]).data, 10)
        self.assertEqual(getSuccessor(t1.root, nodes[6]).data, 15)
        self.assertEqual(getSuccessor(t1.root, nodes[-1]), None)
        ### testing non-leaf nodes
        self.assertEqual(getSuccessor(t1.root, nodes[7]).data, 20)
        self.assertEqual(getSuccessor(t1.root, nodes[9]).data, 23)
        self.assertEqual(getSuccessor(t1.root, nodes[4]).data, 13)

    # This test was originally also named testGetPath, which silently
    # replaced the getPath test defined earlier in the class so it never ran.
    def testGetSum(self):
        tree = Tree([15, 8, 20, 6, 12, 17, 23])
        pathfor35 = getSum(tree.root, 35)
        pathfor20 = getSum(tree.root, 20)
        self.assertTrue([8, 12] in pathfor20)
        self.assertTrue([20] in pathfor20)
        self.assertTrue([15, 20] in pathfor35)
        self.assertTrue([15, 8, 12] in pathfor35)

    def testCheckSubtree(self):
        tree2 = Tree([15, 9, 21, 7, 10, 20, 25, 23])
        self.assertFalse(checkSubtree(Tree([8, 7, 10]).root, tree2.root))
        self.assertTrue(checkSubtree(Tree([9, 7, 10]).root, tree2.root))
        self.assertTrue(checkSubtree(Tree([21, 20]).root, tree2.root))
        self.assertTrue(checkSubtree(Tree([20]).root, tree2.root))
        self.assertFalse(checkSubtree(Tree([14]).root, tree2.root))
# Run the unit tests above when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment