Skip to content

Instantly share code, notes, and snippets.

@jsutch
Last active May 4, 2020 19:19
Show Gist options
  • Save jsutch/c68f5c2d87fd83a6f1104e9cae60a2b8 to your computer and use it in GitHub Desktop.
Save jsutch/c68f5c2d87fd83a6f1104e9cae60a2b8 to your computer and use it in GitHub Desktop.
BST tools in python
import random
class BSTNode(object):
def __init__(self, key):
self.val = key
self.left = None
self.right = None
# BST Methods
def insert(root, node):
"""
insert the given key into the BST
"""
# case where root is empty
if root is None:
root = node
else:
# case where tree exists and value goes right
if root.val < node.val:
if root.right is None:
root.right = node
else:
insert(root.right, node)
else:
if root.left is None:
root.left = node
else:
insert(root.left, node)
def search(root, key):
"""
Does key exist in the BST?
"""
if root.val is None:
return False
root.val == key:
return root.val
if root.val > key:
return search(root.left, key)
return search(root.right, key)
# Traversals
# Depth First traversals
def inorder(root):
"""
DFT returns nodes in non-decreasing order
print left -> root -> right
"""
if root:
inorder(root.left)
print(root.val)
inorder(root.right)
def preorder(root):
"""
DFT used to create a copy of the tree
print root -> left -> right
"""
if root:
print(root.val)
preorder(root.left)
preorder(root.right)
def postorder(root):
"""
DFT - used to delete a tree
print left -> right -> root
"""
if root:
postorder(root.left)
postorder(root.right)
print(root.val)
# Breadth first Traversals
def height(node):
"""
Find height of node. Needed for levelorder traversal
"""
if node is None:
return 0
else:
lh = height(node.left)
rh = height(node.right)
#
if lh > rh:
return lh + 1
else:
return rh + 1
def givenlevel(root, level):
"""
find the givenlevel of the level. Needed for levelorder traversal
"""
if root is None:
return
if level == 1:
print(root.val)
elif level > 1:
givenlevel(root.left, level -1)
givenlevel(root.right, level -1)
def levelorder(root):
"""
a BFT that prints all nodes per level
"""
h = height(root)
for i in range(1, h + 1):
givenlevel(root, i)
def printlevel(root):
[(print("=== Level", x, "==="), givenlevel(root, x)) for x in range(1,height(root))]
# Balancing a tree
# is a tree balanced?
def isbalanced(root):
if root is None:
return True
return is_balanced(root.right) and is_balanced(root.left) and abs(height(root.left) - height(root.right)) <= 1
def store(root, nodearr):
"""
recreate the bst as an inorder array.
needed by balance
"""
if not root:
return
store(root.left)
nodearr.append(root.val)
store(root.right)
def build(nodearr):
"""
rebuild the bst from the inorder array
needed by balance
"""
if not noderarr:
return None
mid = len(nodearr) // 2
node = BSTNode(nodearr[mid])
node.left = newBST(nodearr[:mid])
node.right = newBST(nodearr[mid + 1:])
return node
def balance(root):
"""
- get inorder traversal of existing BST as an array
- create a new BST object sorted by the midpoint
- return the object
"""
nodes = []
store(root, nodes)
return newBST(nodes)
# Helper code
# create and populate a tree
foo = BSTNode(500)
[insert(foo, BSTNode(random.randint(1,1024))) for x in range(20)]
Examples:
In [132]: foo = BSTNode(500)
In [133]: [insert(foo, BSTNode(random.randint(1,1024))) for x in range(20)]
Out[133]: [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
In [134]: printlevel(foo)
=== Level 1 ===
500
=== Level 2 ===
262
540
=== Level 3 ===
155
288
987
=== Level 4 ===
29
228
277
485
831
=== Level 5 ===
71
224
547
872
=== Level 6 ===
71
154
554
841
Out[134]: [(None, None), (None, None), (None, None), (None, None), (None, None), (None, None)]
In [135]: inorder(foo)
99
167
255
270
333
390
413
421
436
500
519
568
574
630
656
671
867
922
943
985
997
In [136]: preorder(foo)
500
167
99
413
333
270
255
390
436
421
867
630
568
519
574
671
656
997
985
922
943
In [137]: postorder(foo)
99
255
270
390
333
421
436
413
167
519
574
568
656
671
630
943
922
985
997
867
500
In [138]: levelorder(foo)
500
167
867
99
413
630
997
333
436
568
671
985
270
390
421
519
574
656
922
255
943
In [139]: isbalanced(foo)
Out[139]: False
In [140]: foo = balance(foo)
In [141]: isbalanced(foo)
Out[141]: True
In [142]: inorder(foo)
99
167
255
270
333
390
413
421
436
500
519
568
574
630
656
671
867
922
943
985
997
In [143]: printlevel(foo)
=== Level 1 ===
500
=== Level 2 ===
262
540
=== Level 3 ===
155
288
987
=== Level 4 ===
29
228
277
485
831
=== Level 5 ===
71
224
547
872
=== Level 6 ===
71
154
554
841
Out[143]: [(None, None), (None, None), (None, None), (None, None), (None, None), (None, None)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment