Skip to content

Instantly share code, notes, and snippets.

@QuentinWach
Created June 9, 2020 08:18
Show Gist options
  • Save QuentinWach/2f02032a02744654d2f09690408b51b2 to your computer and use it in GitHub Desktop.
Save QuentinWach/2f02032a02744654d2f09690408b51b2 to your computer and use it in GitHub Desktop.
# Quentin Wach, AVL Tree
class Node:
    """One node of the AVL tree.

    ``leftChild``, ``rightChild`` and ``parent`` reference neighbouring
    ``Node`` objects, or ``None`` when there is no such neighbour.
    ``height`` is always initialised to 1 (the height of a leaf), regardless
    of any children passed in.
    """

    def __init__(self, key, leftChild=None, rightChild=None, parent=None):
        # Payload plus the three structural links, assigned in one step.
        self.key, self.leftChild, self.rightChild, self.parent = (
            key, leftChild, rightChild, parent)
        # Height of the subtree rooted here; a lone node counts as 1.
        self.height = 1
class AVLTree(object):
    """Self-balancing binary search tree (AVL tree) built from ``Node`` objects.

    Keys smaller than a node's key are stored in its left subtree; all other
    keys, including duplicates, go to its right subtree. After each insertion
    the heights on the path back to the top are refreshed and single or
    double rotations restore ``abs(height(left) - height(right)) <= 1`` at
    every node on that path.

    Once a public method returns, ``self.root`` refers to the topmost node.
    """

    def __init__(self, key):
        """Create a tree containing a single node that holds ``key``."""
        self.root = Node(key)

    def insert(self, key):
        """Insert ``key`` into the tree and rebalance it.

        Duplicate keys are allowed and are placed in the right subtree.
        If ``self.root`` is ``None`` the new node simply becomes the root.
        """
        if self.root is None:
            self.root = Node(key)
            return
        # Descend with a local cursor instead of reassigning self.root, so an
        # exception from a comparison (e.g. incomparable key types) cannot
        # leave self.root pointing at an inner node.
        parent = self.root
        while True:
            if key < parent.key:
                if parent.leftChild is None:
                    newNode = Node(key, parent=parent)
                    parent.leftChild = newNode
                    break
                parent = parent.leftChild
            else:
                if parent.rightChild is None:
                    newNode = Node(key, parent=parent)
                    parent.rightChild = newNode
                    break
                parent = parent.rightChild
        # One upward pass from the new leaf fixes every affected node and
        # yields the (possibly rotated) top of the tree.
        self.root = self._rebalanceFrom(newNode)

    def rebalance(self, key):
        """Restore heights and AVL balance from ``self.root`` upwards.

        ``self.root`` is used as the starting node; on return it refers to
        the topmost node of the tree again.

        :param key: unused; kept so existing ``rebalance(key)`` callers keep
            working. Rotation cases are chosen from the child's balance
            factor rather than by comparing ``key`` with the child's key,
            because that comparison picks the wrong (double) rotation for
            duplicate keys.
        """
        self.root = self._rebalanceFrom(self.root)

    def _rebalanceFrom(self, node):
        """Fix heights/balance from ``node`` up to the top; return the top node."""
        top = node
        while node is not None:
            self._updateHeight(node)
            factor = self._balanceOf(node)
            if factor > 1:
                # Left-heavy. A right-leaning left child (left-right case) is
                # first rotated into the left-left shape.
                if self._balanceOf(node.leftChild) < 0:
                    self.leftRot(node.leftChild)
                node = self.rightRot(node)
            elif factor < -1:
                # Right-heavy; mirror image of the branch above.
                if self._balanceOf(node.rightChild) > 0:
                    self.rightRot(node.rightChild)
                node = self.leftRot(node)
            top = node
            node = node.parent
        return top

    def getHeight(self, node):
        """Return the stored height of ``node``; an empty subtree has height 0."""
        if node is None:
            return 0
        return node.height

    def balance(self):
        """Return the balance factor (left minus right height) of ``self.root``."""
        return self._balanceOf(self.root)

    def _balanceOf(self, node):
        """Return ``height(left) - height(right)`` for ``node`` (0 for ``None``)."""
        if node is None:
            return 0
        return self.getHeight(node.leftChild) - self.getHeight(node.rightChild)

    def _updateHeight(self, node):
        """Recompute ``node.height`` from its children's stored heights."""
        node.height = 1 + max(self.getHeight(node.leftChild),
                              self.getHeight(node.rightChild))

    def _replaceInParent(self, old, new):
        """Hang ``new`` where ``old`` hung: in old's parent slot, or as the root."""
        parent = old.parent
        new.parent = parent
        if parent is None:
            self.root = new
        elif parent.leftChild is old:
            parent.leftChild = new
        elif parent.rightChild is old:
            parent.rightChild = new

    def leftRot(self, z):
        r"""Rotate the subtree rooted at ``z`` left and return its new root ``y``.

        ::

                z                  y
               / \                / \
              a   y     -->      z   x
                 / \            / \
                b   x          a   b

        ``z`` must have a right child. ``y`` takes z's place under z's parent
        (or becomes ``self.root`` if ``z`` had no parent), and the heights of
        ``z`` and ``y`` are recomputed.
        """
        y = z.rightChild
        self._replaceInParent(z, y)
        b = y.leftChild
        y.leftChild = z
        z.parent = y
        z.rightChild = b
        if b is not None:
            b.parent = z
        # z is now below y, so its height must be refreshed first.
        self._updateHeight(z)
        self._updateHeight(y)
        return y

    def rightRot(self, z):
        r"""Rotate the subtree rooted at ``z`` right and return its new root ``y``.

        ::

                z              y
               / \            / \
              y   a   -->    x   z
             / \                / \
            x   b              b   a

        ``z`` must have a left child. ``y`` takes z's place under z's parent
        (or becomes ``self.root`` if ``z`` had no parent), and the heights of
        ``z`` and ``y`` are recomputed.
        """
        y = z.leftChild
        self._replaceInParent(z, y)
        b = y.rightChild
        y.rightChild = z
        z.parent = y
        z.leftChild = b
        if b is not None:
            b.parent = z
        # z is now below y, so its height must be refreshed first.
        self._updateHeight(z)
        self._updateHeight(y)
        return y
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment