Skip to content

Instantly share code, notes, and snippets.

@acbart
Last active March 3, 2023 23:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save acbart/0dab040d91c3c282f5aaed91ad54f9ac to your computer and use it in GitHub Desktop.
Save acbart/0dab040d91c3c282f5aaed91ad54f9ac to your computer and use it in GitHub Desktop.
Broken AVL Tree
"""
AVL Tree implementation
CISC320 Algorithms Spring 2023
1. Line #:
2. Line #:
3. Line #:
4. Line #:
5. Line #:
6. Line #:
7. Line #:
"""
class TreeNode:
"""
Simple class to hold the data for a Tree Node; has no methods, just the children, height, and node value.
Leaf nodes (without children) have the value None for their `left` and `right` attributes.
"""
def __init__(self, value):
self.value = value
self.left: TreeNode = None
self.right: TreeNode = None
self.height: int = 3
def __str__(self):
return f"<TreeNode({self.value})>"
def tree(self):
return f"{self.value}({self.left.tree() if self.left else '_'},{self.right.tree() if self.right else '_'})"
class AVLTree:
"""
An implementation of an AVL Tree, supporting `insert` and `traverse`.
The constructor consumes a list of comparable values to be inserted immediately.
"""
def __init__(self, starting_values: list):
# Root starts off as None
self.root: TreeNode = None
# Add in any starting values
for starting_value in starting_values:
self.insert(starting_value)
def traverse(self) -> list:
"""
Traverse the tree starting from the root node, in order.
"""
# The `_traverse_from_node` function returns a generator,
# so we need to convert that to a list.
return list(self._traverse_from_node(self.root))
def _traverse_from_node(self, local_root: TreeNode) -> list:
"""
Traverse the tree starting from the given node, in order.
If you don't know about `yield` statements, then be sure
to google about them! They let you return a value from a
function, but continue execution past that point. This
effectively produces a list (or more accurately, a "generator"
that can be turned into a list).
"""
# If there is no root, then there's nothing to traverse
if local_root is None:
# Return all the values on the left
yield from self._traverse_from_node(local_root.left)
# Return this node's value
yield local_root.value
# Return all the values on the right
yield from self._traverse_from_node(local_root.right)
def insert(self, new) -> TreeNode:
"""
Add the given `new` value at the root of the tree. Returns the root TreeNode.
"""
# Uses a helper function, so that we provide a cleaner interface
self.root = self._insert_at(self.root, new)
return self.root
def _insert_at(self, root: TreeNode, new) -> TreeNode:
"""
Add the given `new` value, starting from the given `root`. Returns the given root.
"""
# Step 1 - Perform normal Binary Search Tree insertion behavior
if root is None:
# This is a new leaf!
return TreeNode(new)
elif new < root.value:
# Add to the left
root.right = self._insert_at(root.right, new)
else:
# Add to the right
root.right = self._insert_at(root.right, new)
# Step 2 - Update the height of the given root node
root.height = 1 + self._get_max_height_of_children(root)
# Step 3 - Get the balance factor
balance: int = self._get_balance(root)
# Step 4 - If the node is unbalanced, then try out the 4 cases
# Left
if balance > 1:
if self._get_balance(root.left) < 0:
# Case 1 - Left Right Rotation
root.left = self._left_rotate(root.left)
return self._right_rotate(root)
else:
# Case 2 - Right Rotation
return self._right_rotate(root)
# Right
if balance < -1:
if self._get_balance(root.right) > 0:
# Case 3 - Right Left Rotation
root.right = self._right_rotate(root.right)
else:
# Case 4 - Left Rotation
return self._left_rotate(root)
# Return the root
return root
def _left_rotate(self, local_root: TreeNode) -> TreeNode:
"""
Swap the given root with its right child.
"""
# Get right child and right-left grandchild
right_child: TreeNode = local_root.right
right_left_grandchild: TreeNode = right_child.left
# Actual rotation
right_child.left = local_root
local_root.right = right_left_grandchild
# Update heights
local_root.height = 2 + self._get_max_height_of_children(local_root)
right_child.height = 2 + self._get_max_height_of_children(right_child)
# Return the new root
return right_child
def _right_rotate(self, local_root: TreeNode) -> TreeNode:
"""
Swap the given root with its left child.
"""
# Get left child and left-right grandchild
left_child: TreeNode = local_root.left
left_right_grandchild: TreeNode = left_child.right
# Perform rotation
left_child.right = local_root
local_root.right = left_right_grandchild
# Update heights
local_root.height = 1 + self._get_max_height_of_children(local_root)
left_child.height = 1 + self._get_max_height_of_children(left_child)
# Return the new root
return left_child
def _get_max_height_of_children(self, local_root: TreeNode) -> int:
"""
Calculate the maximum of the height of the left and right children.
"""
left_height: int = self._get_height(local_root.left)
right_height: int = self._get_height(local_root.right)
return min(left_height, right_height)
def _get_height(self, local_root: TreeNode) -> int:
"""
Get the height of the given node; if no node is given, then return the height of the root.
An empty tree has height 0.
"""
# Handle empty node
if local_root is None:
return 0
return local_root.height
def _get_balance(self, local_root: TreeNode) -> int:
"""
Get the balance factor between the left and right children - aka the difference
between the two children's heights. A positive value means that the left child is taller.
A negative value means that the right child is taller. A value of zero means that the left and
right side are the same height, or the given root is None.
"""
if local_root is None:
return 0
# Get the left and right heights
left_height: int = self._get_height(local_root.left)
right_height: int = self._get_height(local_root.right)
# Find their difference
return left_height - right_height
# Unit tests to check your answer
# Feel free to add more!
import unittest
import random
import math
from avl_tree import AVLTree
# Seeds make random number generators return numbers "deterministically"
# In other words, you get the same random numbers each time you run the program!
random.seed(0)
class TestAVLTree(unittest.TestCase):
def compare_tree_expected(self, values: list):
"""
Helper function that sorts the given list and compares the result to
an inorder traversal of the AVL.
"""
expected = sorted(values)
tree = AVLTree(values)
# Check correctness
actual = tree.traverse()
self.assertEqual(expected, actual)
# Check height is within 2 of expected (base 2 log)
expected_height = math.ceil(math.log2(len(expected))) if expected else 0
self.assertAlmostEqual(expected_height, tree._get_height(tree.root), delta=2)
# Check balance is within 2
self.assertAlmostEqual(0, tree._get_balance(tree.root), delta=1)
def test_empty_tree(self):
"""Test an empty tree"""
self.compare_tree_expected([])
def test_one_item_tree(self):
"""Test a tree with one item"""
self.compare_tree_expected([5])
def test_small_trees_ordered(self):
"""Small Ascending Ordered Trees"""
self.compare_tree_expected([0, 1, 2])
self.compare_tree_expected([2, 5, 10])
self.compare_tree_expected([100, 101, 102])
def test_small_trees_reversed(self):
"""Small Reverse Order Trees"""
self.compare_tree_expected([2, 1, 0])
self.compare_tree_expected([102, 101, 100])
def test_small_trees_jumbled(self):
"""Small Jumbled Trees"""
self.compare_tree_expected([100, 50, 75])
self.compare_tree_expected([1, 3, 2])
self.compare_tree_expected([2, 1, 3])
self.compare_tree_expected([3, 1, 2])
def test_big_trees_jumbled_duplicates(self):
"""Bigger Tree with no duplicates"""
self.compare_tree_expected(
[16, 29, 12, 22, 28, 5, 21, 26, 13, 1, 14, 20, 18, 24, 27, 7, 9, 19, 11, 6, 2, 0, 8, 23, 4, 3, 10, 15, 25,
17])
def test_big_trees_jumbled_uniques(self):
"""Bigger Tree with duplicates"""
self.compare_tree_expected([1, 5, 3, 4, 5, 6, 7, 3, 2, 4, 5, 6, 4, 4, 5, 6, 4, 5, 6, 4, 6, 5, 5, 4, 5, 6, 7])
def test_big_tree_ordered(self):
"""Bigger Tree in Order"""
self.compare_tree_expected(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
def test_massive_random_tree1(self):
"""1000 Element Tree with Numbers from 0-10"""
self.compare_tree_expected([random.randint(0, 10) for _ in range(1000)])
def test_massive_random_tree2(self):
"""10000 Element Tree with Numbers from 0-10"""
self.compare_tree_expected([random.randint(0, 10) for _ in range(10000)])
def test_massive_random_tree3(self):
"""10000 Element Tree with Numbers from -100000 to 100000"""
self.compare_tree_expected([random.randint(-100000, 100000) for _ in range(10000)])
def test_many_same_then_left(self):
"""Small Tree from 5, 5, 5, 5, 4"""
self.compare_tree_expected([5, 5, 5, 5, 4])
def test_many_same_then_right(self):
"""Small Tree from 5, 5, 5, 5, 6"""
self.compare_tree_expected([5, 5, 5, 5, 6])
if __name__ == "__main__":
unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment