Skip to content

Instantly share code, notes, and snippets.

@shashankgroovy
Forked from prakhar1989/trees.py
Created July 29, 2016 05:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shashankgroovy/7dcaf47950cb0d307ed438c96bccdbd6 to your computer and use it in GitHub Desktop.
Save shashankgroovy/7dcaf47950cb0d307ed438c96bccdbd6 to your computer and use it in GitHub Desktop.
Fun with binary trees
#!/usr/bin/python
import unittest
## binary tree problems
### A NODE OBJECT
class Node(object):
    """A binary tree node.

    Holds a ``data`` payload plus ``left``/``right`` child links and a
    ``parent`` back-link, all ``None`` until set by the caller (or by
    ``Tree.insert``). Nodes compare by their ``data`` value, so
    ``Node(1) < Node(2)`` and ``Node(3) == Node(3)``.
    """

    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        self.parent = None

    def __repr__(self):
        return str(self.data)

    # Python 2 three-way comparison. Python 3 ignores __cmp__ entirely, which
    # is why the rich comparison methods below are also defined.
    def __cmp__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        if self.data < other.data:
            return -1
        elif self.data > other.data:
            return 1
        return 0

    def __eq__(self, other):
        if not isinstance(other, Node):
            # Let Python fall back to its default (identity) comparison,
            # e.g. ``node == None`` is simply False instead of raising.
            return NotImplemented
        return self.data == other.data

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __lt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.data < other.data

    def __le__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.data <= other.data

    def __gt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.data > other.data

    def __ge__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.data >= other.data

    # Defining __eq__ removes the inherited __hash__ in Python 3; hash by
    # data so that nodes which compare equal also hash equally.
    def __hash__(self):
        return hash(self.data)
### A BINARY SEARCH TREE
class Tree(object):
    """An unbalanced binary search tree of ``Node`` objects.

    Values smaller than a node's data go to its left subtree; values greater
    than or equal to it go right, so duplicates are kept. ``len(tree)`` is
    the number of values inserted.
    """

    def __init__(self, iterable=None):
        """Create a tree, inserting every value of *iterable* in order (if given)."""
        self.root = None
        self.size = 0
        if iterable is not None:
            for x in iterable:
                self.insert(x)

    def __len__(self):
        return self.size

    def insert(self, data):
        """Insert *data* as a new leaf and return the created ``Node``.

        Runs in O(height). There is no rebalancing, so sorted input builds
        a list-shaped tree.
        """
        node = Node(data)
        if self.root is None:
            self.root = node
        else:
            # Compare raw data values directly so insertion does not depend
            # on Node providing Python-3-compatible comparison operators.
            current, parent = self.root, None
            while current is not None:
                parent = current
                current = current.left if data < current.data else current.right
            # The loop ends with `parent` as the leaf to attach under.
            if data < parent.data:
                parent.left = node
            else:
                parent.right = node
            node.parent = parent
        self.size += 1
        return node
### Functions
def inorder(node):
    """Inorder traversal of a binary tree.

    @param node Pointer to the root node of a binary tree (may be None)
    @returns a generator that yields the nodes in inorder
             (left subtree, node, right subtree); nothing for an empty tree

    Iterative with an explicit stack. Trees deeper than the interpreter's
    recursion limit (e.g. a Tree built from sorted input) are therefore
    handled. Each node is also yielded directly, rather than being passed up
    through one nested generator per level.
    """
    stack = []
    current = node
    while stack or current is not None:
        # Walk as far left as possible, remembering the path back up.
        while current is not None:
            stack.append(current)
            current = current.left
        current = stack.pop()
        yield current
        # Then traverse this node's right subtree.
        current = current.right
def getHeight(node):
    """Height of a binary tree.

    @param node Pointer to the root node of a binary tree (may be None)
    @returns the number of levels in the tree: 0 for an empty tree,
             1 for a single node

    Counts levels with a breadth-first sweep instead of recursing, so very
    deep (e.g. list-shaped) trees do not hit the recursion limit.
    """
    height = 0
    level = [node] if node is not None else []
    while level:
        height += 1
        level = [child
                 for parent in level
                 for child in (parent.left, parent.right)
                 if child is not None]
    return height
def isBalanced(node):
    """Check whether a binary tree is height-balanced.

    @param node Pointer to the root node of a binary tree (may be None)
    @returns boolean - True if, at every node, the heights of the left and
             right subtrees differ by at most one

    Computes each subtree height once in a single post-order pass (O(n)),
    instead of re-measuring heights from scratch at every node. Once an
    imbalance is found, the remaining sibling subtrees are not visited.
    """
    def _height(current):
        # Height of the subtree rooted at `current`, or -1 if any node in
        # it is unbalanced.
        if current is None:
            return 0
        left = _height(current.left)
        if left < 0:
            return -1
        right = _height(current.right)
        if right < 0 or abs(left - right) > 1:
            return -1
        return max(left, right) + 1

    return _height(node) >= 0
def getLevel(node, level):
    """Nodes on one level of a binary tree.

    @param node Pointer to the root node of a binary tree (may be None)
    @param level The level that needs to be returned (top level is 1 not 0)
    @returns a generator that yields the nodes on that level, left to right.
             It yields nothing for an empty tree, for a level below 1, or
             for a level deeper than the tree.
    """
    if node is None or level < 1:
        # A bare `return` ends the generator without yielding anything
        # (previously a spurious None was yielded here).
        return
    if level == 1:
        yield node
        return
    for child in (node.left, node.right):
        for elem in getLevel(child, level - 1):
            yield elem
def constructTreeFromArray(nums):
    """Build a minimal-height binary search tree from sorted values.

    @param nums array of sorted numbers (an empty or None input gives an
                empty tree)
    @returns a Tree of minimal height containing every value of nums

    The middle element of each range is inserted before its two halves, so
    the tree splits evenly at every level. Ranges are tracked by index
    rather than by slicing, which avoids copying the array at each step. The
    midpoint uses floor division so it stays an int index on Python 3.
    """
    t = Tree()

    def _loop(lo, hi):
        # Insert nums[lo:hi], middle element first.
        if lo >= hi:
            return
        mid = lo + (hi - lo) // 2
        t.insert(nums[mid])
        _loop(lo, mid)
        _loop(mid + 1, hi)

    if nums:
        _loop(0, len(nums))
    return t
def isBST(node):
    """Check whether a binary tree satisfies the binary-search-tree ordering.

    @param node pointer to a tree root (may be None)
    @returns True if, for every node, all values in its left subtree are
             <= its data and all values in its right subtree are >= its data

    Comparing a node only with its direct children is not enough, because
    a value deep in a left subtree can still exceed an ancestor. Each node
    is therefore checked against the (low, high) range inherited from all of
    its ancestors. Equal values are accepted on either side. The traversal
    uses an explicit stack, so deep trees do not hit the recursion limit.
    """
    # Each entry is (node, low, high); None means "unbounded".
    stack = [(node, None, None)]
    while stack:
        current, low, high = stack.pop()
        if current is None:
            continue
        data = current.data
        if low is not None and data < low:
            return False
        if high is not None and data > high:
            return False
        # Left descendants must not exceed this node; right descendants
        # must not fall below it.
        stack.append((current.left, low, data))
        stack.append((current.right, data, high))
    return True
def getSuccessor(t, node):
    """Inorder successor of a node in a binary search tree.

    @param t pointer to the root of the tree (unused: the nodes' parent
             links are sufficient; kept for interface compatibility)
    @param node - the node for which the successor is required; the tree's
                  nodes must carry `parent` links, as set by Tree.insert
    @returns the node that follows `node` in an inorder traversal, or None
             if `node` is the last one

    There are two cases:
    * `node` has a right subtree: the successor is that subtree's leftmost
      node.
    * Otherwise: climb the parent links until we arrive at a parent from its
      left child; that parent is the successor. This handles leaves, nodes
      with only a left child, and a lone root uniformly.
    """
    if node.right is not None:
        successor = node.right
        while successor.left is not None:
            successor = successor.left
        return successor
    # Use identity checks (`is`), not `==`: Node equality compares data
    # values, so a different node holding an equal value could be mistaken
    # for this child.
    child, parent = node, node.parent
    while parent is not None and child is parent.right:
        child, parent = parent, parent.parent
    return parent
def LCAinBST(node, data1, data2):
    """Lowest common ancestor of two values in a binary search tree.

    @param node - root node of the tree (may be None)
    @param data1 - data stored in node1
    @param data2 - data stored in node2
    @returns the first node on the descent from the root whose data is not
             strictly greater than both values nor strictly less than both.
             That node is their lowest common ancestor when both values are
             present. Returns None if the descent runs off the tree.
             Runs in O(height), i.e. O(log n) for a balanced tree.

    Whether data1/data2 actually occur in the tree is not verified.
    """
    current = node
    while current is not None:
        if current.data > data1 and current.data > data2:
            current = current.left
        elif current.data < data1 and current.data < data2:
            current = current.right
        else:
            # The two values split here (or one equals this node).
            return current
    return None
def getPath(root, data):
    """Path from the root to the first node holding `data`.

    @param root - root node of the binary tree
    @param data - the data which is to be found
    @returns path (as a list of nodes) from the root to the element,
             both inclusive; an empty list if the data is not in the tree.
             The search is preorder (node, then left, then right), so the
             first such match wins when the value occurs more than once.
    """
    def _search(node):
        # Returns the path from `node` down to the match, collected
        # bottom-up (match first, `node` last), or None on a dead end.
        if node is None:
            return None
        if node.data == data:
            return [node]
        for child in (node.left, node.right):
            trail = _search(child)
            if trail is not None:
                trail.append(node)
                return trail
        return None

    trail = _search(root)
    if trail is None:
        return []
    trail.reverse()
    return trail
def getSum(root, total):
    """Find every downward path whose values add up to `total`.

    @param root - the pointer to the root of the tree
    @param total - the total sum
    @returns - a list of paths that sum to the total in the binary tree.
               Each path is a list of data values running from some node
               down through its descendants (a single node counts). Paths
               are grouped by their bottom node in preorder, and shorter
               paths come first within a group.
    """
    found_paths = []
    # Data values from the root down to the node currently being visited.
    trail = []

    def _visit(node):
        if node is None:
            return
        trail.append(node.data)
        # Sum upwards from the current node, checking every path that
        # ends here.
        running = 0
        for start in range(len(trail) - 1, -1, -1):
            running += trail[start]
            if running == total:
                found_paths.append(trail[start:])
        _visit(node.left)
        _visit(node.right)
        # Leaving this node: drop it from the trail.
        trail.pop()

    _visit(root)
    return found_paths
def LCAinBinaryTree(root, data1, data2):
    """Lowest common ancestor of two values in an arbitrary binary tree.

    @param root - root node of the tree
    @param data1 - data stored in node1
    @param data2 - data stored in node2
    @returns the lowest common ancestor node, or None if either value is
             not found. Runs in O(n), since each root-to-node path is found
             with getPath.

    If a value occurs more than once, the node getPath finds first is used.
    """
    path1, path2 = getPath(root, data1), getPath(root, data2)
    if not path1 or not path2:
        return None
    # Both paths start at the root. The LCA is the last node of their
    # shared prefix. Compare identity (`is`) rather than `==`: Node
    # equality compares data, so two distinct nodes holding equal values
    # would wrongly count as shared.
    lca = None
    for node1, node2 in zip(path1, path2):
        if node1 is not node2:
            break
        lca = node1
    return lca
def checkSubtree(root1, root2):
    """Check whether tree1 appears as a part of tree2.

    @param root1 - pointer to the root node of tree1 (the pattern)
    @param root2 - pointer to the root node of tree2
    @returns boolean to indicate whether tree1 is a part of tree2. This is
             True if some node of tree2 matches root1 in data and,
             recursively, every child root1 has is matched by the
             same-side child at that spot in tree2. Children missing from
             tree1 match anything, so tree1 need only be the top part of a
             subtree of tree2. An empty tree1 is always part of tree2.
    """
    def _matchesAt(pattern, target):
        # Anchored comparison: `pattern` must line up with `target` exactly
        # at this position, with no searching deeper in tree2.
        if pattern is None:
            return True
        if target is None or pattern.data != target.data:
            return False
        return (_matchesAt(pattern.left, target.left) and
                _matchesAt(pattern.right, target.right))

    if root1 is None:
        return True
    # Try every node of tree2 as the anchor. A failed match at one node
    # must not stop the search, since another node with equal data may
    # still match.
    stack = [root2]
    while stack:
        candidate = stack.pop()
        if candidate is None:
            continue
        if _matchesAt(root1, candidate):
            return True
        stack.append(candidate.right)
        stack.append(candidate.left)
    return False
### UNITTESTS
class TestTreeMethods(unittest.TestCase):
    """Unit tests for Tree and the module-level binary tree functions."""

    def setUp(self):
        # Complete BST of height 3: 15 -> (10 -> 6, 12), (20 -> 17, 23).
        self.balancedBst = Tree([15, 10, 20, 6, 12, 17, 23])
        # Sorted insertion degenerates into a right-leaning chain.
        self.unbalancedBst = Tree([6, 10, 12, 15, 17])
        # Hand-built binary tree that violates BST ordering (19 left of 16).
        self.binaryTree = Node(16)
        self.binaryTree.left = Node(19)
        self.binaryTree.right = Node(8)

    # helper function to return data from an instance of Node class
    def getData(self, x):
        return x.data

    def testInorder(self):
        order = list(inorder(self.balancedBst.root))
        self.assertEqual([6, 10, 12, 15, 17, 20, 23],
                         list(map(self.getData, order)))

    def testgetHeight(self):
        self.assertEqual(getHeight(self.balancedBst.root), 3)

    def testBalanced(self):
        self.assertTrue(isBalanced(self.balancedBst.root))
        self.assertFalse(isBalanced(self.unbalancedBst.root))

    def testBSTConstruction(self):
        array = [7, 9, 10, 15, 20, 24, 25]
        generatedBst = constructTreeFromArray(array)
        order = list(inorder(generatedBst.root))
        self.assertEqual(3, getHeight(generatedBst.root))
        self.assertEqual(array, list(map(self.getData, order)))

    def testGetLevel(self):
        levels = list(getLevel(self.balancedBst.root, 3))
        self.assertEqual(list(map(self.getData, levels)), [6, 12, 17, 23])
        levels = list(getLevel(self.unbalancedBst.root, 3))
        self.assertEqual(list(map(self.getData, levels)), [12])

    def testIsBst(self):
        self.assertFalse(isBST(self.binaryTree))
        self.assertTrue(isBST(self.balancedBst.root))

    def testGetPath(self):
        existing_node = 23
        absent_node = 1000
        node_path = getPath(self.balancedBst.root, existing_node)
        self.assertEqual(list(map(self.getData, node_path)), [15, 20, 23])
        # an absent value gives a blank list
        self.assertEqual(len(getPath(self.balancedBst.root, absent_node)), 0)

    def testLCAinBST(self):
        self.assertEqual(LCAinBST(self.balancedBst.root, 6, 12).data, 10)
        self.assertEqual(LCAinBST(self.balancedBst.root, 17, 23).data, 20)
        self.assertEqual(LCAinBST(self.balancedBst.root, 12, 17).data, 15)

    def testLCAinBinaryTree(self):
        self.assertEqual(LCAinBinaryTree(self.balancedBst.root, 6, 12).data, 10)
        self.assertEqual(LCAinBinaryTree(self.balancedBst.root, 17, 23).data, 20)
        self.assertEqual(LCAinBinaryTree(self.balancedBst.root, 12, 17).data, 15)

    def testSuccessor(self):
        t1 = Tree([15, 10, 21, 7, 12, 20, 25, 8, 11, 13, 14, 23, 28])
        # inorder: 7 8 10 11 12 13 14 15 20 21 23 25 28 (indices 0..12)
        nodes = list(inorder(t1.root))
        ### testing leaf nodes that are left children
        self.assertEqual(getSuccessor(t1.root, nodes[3]).data, 12)
        self.assertEqual(getSuccessor(t1.root, nodes[8]).data, 21)
        self.assertEqual(getSuccessor(t1.root, nodes[10]).data, 25)
        ### testing leaf nodes that are right children
        self.assertEqual(getSuccessor(t1.root, nodes[1]).data, 10)
        self.assertEqual(getSuccessor(t1.root, nodes[6]).data, 15)
        self.assertEqual(getSuccessor(t1.root, nodes[-1]), None)
        ### testing non-leaf nodes
        self.assertEqual(getSuccessor(t1.root, nodes[7]).data, 20)
        self.assertEqual(getSuccessor(t1.root, nodes[9]).data, 23)
        self.assertEqual(getSuccessor(t1.root, nodes[4]).data, 13)

    # Previously also named testGetPath, which silently replaced the
    # getPath test above so that it never ran.
    def testGetSum(self):
        tree = Tree([15, 8, 20, 6, 12, 17, 23])
        pathfor35 = getSum(tree.root, 35)
        pathfor20 = getSum(tree.root, 20)
        self.assertIn([8, 12], pathfor20)
        self.assertIn([20], pathfor20)
        self.assertIn([15, 20], pathfor35)
        self.assertIn([15, 8, 12], pathfor35)

    def testCheckSubtree(self):
        tree2 = Tree([15, 9, 21, 7, 10, 20, 25, 23])
        self.assertFalse(checkSubtree(Tree([8, 7, 10]).root, tree2.root))
        self.assertTrue(checkSubtree(Tree([9, 7, 10]).root, tree2.root))
        self.assertTrue(checkSubtree(Tree([21, 20]).root, tree2.root))
        self.assertTrue(checkSubtree(Tree([20]).root, tree2.root))
        self.assertFalse(checkSubtree(Tree([14]).root, tree2.root))


if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment