Skip to content

Instantly share code, notes, and snippets.

@KoStard
Created July 18, 2019 15:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save KoStard/3fa0b29fcbcd69dd2c76dcbb6a488259 to your computer and use it in GitHub Desktop.
Save KoStard/3fa0b29fcbcd69dd2c76dcbb6a488259 to your computer and use it in GitHub Desktop.
Implemented AVL tree.
"""
AVL tree: for every node, the heights of its left and right subtrees may differ by at most 1.
"""
class AVLTreeNode:
    """A node of an AVL (height-balanced) binary search tree.

    Every mutating operation (``insert``, ``remove``, ``remove_val``,
    ``fix``) returns the root of the tree after the operation, because
    rebalancing may change which node is the root. Callers should always
    rebind their root reference: ``root = root.insert(value)``.

    Attributes:
        val: The value stored in this node (must support ``<`` and ``>``).
        left, right: Child nodes, or ``None``.
        parent: Parent node, or ``None`` for the root.
        left_depth, right_depth: Height of the left/right subtree counted
            in nodes (0 when the child is missing, 1 for a leaf child).
    """

    def __init__(self, val, parent=None):
        self.val = val
        self.left = None
        self.right = None
        self.parent = parent
        self.left_depth = 0
        self.right_depth = 0

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _height(node):
        """Return the cached height of the subtree rooted at ``node`` (0 for None)."""
        if node is None:
            return 0
        return 1 + max(node.left_depth, node.right_depth)

    def _refresh_depths(self):
        """Recompute this node's cached depths from its children's caches."""
        self.left_depth = self._height(self.left)
        self.right_depth = self._height(self.right)

    def _refresh_upward(self):
        """Recompute cached depths for this node and every ancestor."""
        node = self
        while node is not None:
            node._refresh_depths()
            node = node.parent

    def _replace_in_parent(self, new_child):
        """Hang ``new_child`` (may be None) where this node hangs under its parent."""
        parent = self.parent
        if parent is not None:
            if parent.left is self:
                parent.left = new_child
            else:
                parent.right = new_child
        if new_child is not None:
            new_child.parent = parent

    def _detach(self):
        """Turn this node into a stand-alone leaf with no links into any tree."""
        self.left = self.right = self.parent = None
        self.left_depth = self.right_depth = 0

    # ------------------------------------------------------------------
    # Rotations
    # ------------------------------------------------------------------
    def rotate_left(self):
        """Rotate this subtree to the left and return the new subtree root.

        The right child becomes the subtree root and this node becomes its
        left child. If there is no right child nothing changes and ``self``
        is returned. Cached depths are refreshed from this node up to the
        root.
        """
        pivot = self.right
        if pivot is None:
            return self
        self._replace_in_parent(pivot)
        inner = pivot.left  # subtree that moves from the pivot to this node
        self.right = inner
        if inner is not None:
            inner.parent = self
        pivot.left = self
        self.parent = pivot
        # This node is now below the pivot, so it has to be refreshed first.
        self._refresh_depths()
        pivot._refresh_upward()
        return pivot

    def rotate_right(self):
        """Rotate this subtree to the right and return the new subtree root.

        Mirror image of :meth:`rotate_left`: the left child becomes the
        subtree root. If there is no left child ``self`` is returned.
        """
        pivot = self.left
        if pivot is None:
            return self
        self._replace_in_parent(pivot)
        inner = pivot.right  # subtree that moves from the pivot to this node
        self.left = inner
        if inner is not None:
            inner.parent = self
        pivot.right = self
        self.parent = pivot
        self._refresh_depths()
        pivot._refresh_upward()
        return pivot

    @property
    def root(self):
        """The root of the tree this node belongs to."""
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    # ------------------------------------------------------------------
    # Insertion / removal
    # ------------------------------------------------------------------
    def insert(self, val):
        """Insert ``val`` into the subtree rooted here and return the tree root.

        Duplicates are ignored (the tree is left unchanged).
        """
        node = self
        while True:
            if val > node.val:
                if node.right is None:
                    node.right = AVLTreeNode(val, node)
                    break
                node = node.right
            elif val < node.val:
                if node.left is None:
                    node.left = AVLTreeNode(val, node)
                    break
                node = node.left
            else:
                return self.root  # Value already added
        # Only the ancestors of the new leaf can have changed height.
        node._refresh_upward()
        return node.fix()

    def remove(self):
        """Remove this node from its tree and return the new tree root.

        A node with a right subtree is replaced by its in-order successor;
        otherwise its left child (possibly ``None``) takes its place.
        Returns ``None`` when the removed node was the only node. After the
        call this node is fully detached from the tree.
        """
        if self.right is not None:
            successor = self.right
            while successor.left is not None:
                successor = successor.left
            if successor is self.right:
                # The successor keeps its own right subtree and just moves up.
                start = successor
            else:
                # Unhook the successor from deeper in the right subtree;
                # its old parent is the lowest node whose height changed.
                start = successor.parent
                start.left = successor.right
                if successor.right is not None:
                    successor.right.parent = start
                successor.right = self.right
                self.right.parent = successor
            successor.left = self.left
            if self.left is not None:
                self.left.parent = successor
            self._replace_in_parent(successor)
        else:
            start = self.parent
            child = self.left
            self._replace_in_parent(child)
            if start is None:
                # Removed the root: its (already balanced) left child, if
                # any, becomes the whole tree.
                self._detach()
                return child
        self._detach()
        start._refresh_upward()
        return start.fix()

    def remove_val(self, val):
        """Remove ``val`` from the subtree rooted here and return the tree root.

        If ``val`` is not present the tree is unchanged and its current root
        is returned. ``None`` is returned only when the tree becomes empty.
        """
        node = self
        while node is not None and node.val != val:
            node = node.left if val < node.val else node.right
        if node is None:
            return self.root
        return node.remove()

    def __str__(self):
        """Render as ``val(left)(right)``; empty parentheses mark a missing left child."""
        text = str(self.val)
        if self.right is not None:
            text += '({})({})'.format(
                self.left if self.left is not None else "", self.right)
        elif self.left is not None:
            text += '({})'.format(self.left)
        return text

    def calculate_depth(self, current_depth=0):
        """Recompute the cached depths of the whole subtree from scratch.

        The tree operations keep the caches up to date incrementally, so
        this full O(n) pass is only needed after manual pointer surgery.

        Args:
            current_depth: Absolute depth of this node (0 for the root).

        Returns:
            A ``(left, right)`` tuple with the absolute depth of the deepest
            node in the left and in the right subtree (``current_depth``
            when the corresponding child is missing).
        """
        left = right = current_depth
        if self.left is not None:
            left = max(self.left.calculate_depth(current_depth + 1))
        if self.right is not None:
            right = max(self.right.calculate_depth(current_depth + 1))
        self.left_depth = left - current_depth
        self.right_depth = right - current_depth
        return left, right

    def fix(self):
        """Rebalance from this node up to the root and return the root.

        Expects the cached depths on the path to the root to be current.
        """
        node = self
        while True:
            lean = node.left_depth - node.right_depth
            if lean > 1:
                heavy = node.left
                if heavy.right_depth > heavy.left_depth:
                    # Left-right case: straighten the child first, then
                    # re-examine this node on the next iteration.
                    heavy.rotate_left()
                else:
                    node = node.rotate_right()
            elif lean < -1:
                heavy = node.right
                if heavy.left_depth > heavy.right_depth:
                    # Right-left case, mirror of the above.
                    heavy.rotate_right()
                else:
                    node = node.rotate_left()
            elif node.parent is None:
                return node
            else:
                node = node.parent

    def in_order_traversal(self, res=None):
        """Append the subtree's values to ``res`` in ascending order and return it.

        Args:
            res: List to extend; a new list is created when omitted.
        """
        if res is None:
            res = []
        if self.left is not None:
            self.left.in_order_traversal(res)
        res.append(self.val)
        if self.right is not None:
            self.right.in_order_traversal(res)
        return res
@KoStard
Copy link
Author

KoStard commented Jul 19, 2019

There is a problem with calculate_depth: it recomputes the depths of the whole tree after every operation, so the algorithm runs in O(N**2) overall. The depths should instead be stored and updated incrementally on every rotation, insertion and removal.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment