Created
July 18, 2019 15:52
-
-
Save KoStard/3fa0b29fcbcd69dd2c76dcbb6a488259 to your computer and use it in GitHub Desktop.
Implemented AVL tree.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
The heights of the left and right subtrees of every node must differ by at most 1.
""" | |
class AVLTreeNode:
    """A node of a self-balancing (AVL) binary search tree.

    The tree is manipulated through its nodes: ``insert``, ``remove_val`` and
    ``remove`` all return the (possibly new) root of the whole tree, so the
    intended usage is ``root = root.insert(x)``.

    Each node caches the heights of its two subtrees in ``left_depth`` and
    ``right_depth`` (0 for a missing child, 1 for a leaf child, ...).  The
    caches are updated incrementally along the path an operation touches,
    so insert/remove cost O(log n) instead of re-walking the whole tree.
    """

    def __init__(self, val, parent=None):
        self.val = val
        self.left = None
        self.right = None
        self.parent = parent
        self.left_depth = 0   # height of the left subtree (0 when empty)
        self.right_depth = 0  # height of the right subtree (0 when empty)

    # -- height bookkeeping -------------------------------------------------

    def _height(self):
        """Return the height of the subtree rooted here (a leaf has height 1)."""
        return 1 + max(self.left_depth, self.right_depth)

    def _update_depths(self):
        """Recompute this node's cached child heights from its children's caches."""
        self.left_depth = self.left._height() if self.left is not None else 0
        self.right_depth = self.right._height() if self.right is not None else 0

    def _refresh_ancestors(self):
        """Propagate a height change at this node upwards.

        Stops at the first ancestor whose cached heights do not change, since
        nothing above it can be affected either.
        """
        node = self.parent
        while node is not None:
            before = (node.left_depth, node.right_depth)
            node._update_depths()
            if (node.left_depth, node.right_depth) == before:
                break
            node = node.parent

    def _replace_in_parent(self, new):
        """Put ``new`` (a node or None) where this node hangs under its parent."""
        parent = self.parent
        if parent is not None:
            if parent.left is self:
                parent.left = new
            else:
                parent.right = new
        if new is not None:
            new.parent = parent

    # -- rotations ------------------------------------------------------------

    def _rotate_left(self):
        """Rotate left around this node, updating only the two nodes involved."""
        pivot = self.right
        self._replace_in_parent(pivot)
        inner = pivot.left
        self.right = inner
        if inner is not None:
            inner.parent = self
        pivot.left = self
        self.parent = pivot
        self._update_depths()   # self is now below pivot: update it first
        pivot._update_depths()
        return pivot

    def _rotate_right(self):
        """Rotate right around this node, updating only the two nodes involved."""
        pivot = self.left
        self._replace_in_parent(pivot)
        inner = pivot.right
        self.left = inner
        if inner is not None:
            inner.parent = self
        pivot.right = self
        self.parent = pivot
        self._update_depths()   # self is now below pivot: update it first
        pivot._update_depths()
        return pivot

    def rotate_left(self):
        """Rotate the subtree rooted here to the left.

        The right child takes this node's place and this node becomes its
        left child.  Cached heights of both nodes and of any affected
        ancestors are updated.  Returns the new subtree root, or None (doing
        nothing) when there is no right child.
        """
        if self.right is None:
            return None
        pivot = self._rotate_left()
        pivot._refresh_ancestors()
        return pivot

    def rotate_right(self):
        """Rotate the subtree rooted here to the right.

        The left child takes this node's place and this node becomes its
        right child.  Cached heights of both nodes and of any affected
        ancestors are updated.  Returns the new subtree root, or None (doing
        nothing) when there is no left child.
        """
        if self.left is None:
            return None
        pivot = self._rotate_right()
        pivot._refresh_ancestors()
        return pivot

    # -- public tree operations ----------------------------------------------

    @property
    def root(self):
        """The topmost ancestor of this node (the node itself if it has no parent)."""
        r = self
        while r.parent:
            r = r.parent
        return r

    def insert(self, val):
        """Insert ``val`` below this node and rebalance.

        Duplicate values are ignored.  Returns the root of the whole tree.
        """
        node = self
        while True:
            if val > node.val:
                if node.right is None:
                    node.right = AVLTreeNode(val, node)
                    return node.fix()
                node = node.right
            elif val < node.val:
                if node.left is None:
                    node.left = AVLTreeNode(val, node)
                    return node.fix()
                node = node.left
            else:
                return node.root  # Value already added

    def remove(self):
        """Unlink this node from its tree and rebalance.

        A node with two children is replaced by its in-order successor (the
        leftmost node of its right subtree).  Otherwise its only child, or
        nothing, takes its place.  Every moved node gets a correct ``parent``
        link, and the removed node is detached.

        Returns the root of the remaining tree, or None if it is now empty.
        """
        if self.left is not None and self.right is not None:
            successor = self.right
            while successor.left is not None:
                successor = successor.left
            if successor is self.right:
                # Successor keeps its right subtree and only gains self.left.
                rebalance_from = successor
            else:
                # Lift the successor out; its right subtree takes its old spot.
                # That old parent is the deepest node whose height may shrink.
                rebalance_from = successor.parent
                rebalance_from.left = successor.right
                if successor.right is not None:
                    successor.right.parent = rebalance_from
                successor.right = self.right
                self.right.parent = successor
            successor.left = self.left
            self.left.parent = successor
            replacement = successor
        else:
            replacement = self.left if self.left is not None else self.right
            rebalance_from = self.parent
        self._replace_in_parent(replacement)
        self.left = self.right = self.parent = None
        if rebalance_from is None:
            # self was the root with at most one child; that subtree is intact.
            return replacement
        return rebalance_from.fix()

    def remove_val(self, val):
        """Remove ``val`` from the subtree rooted here, if present.

        Returns the root of the resulting tree (None if it became empty).
        When ``val`` is absent the tree is unchanged and its root is returned,
        mirroring ``insert`` for duplicates, so ``root = root.remove_val(x)``
        never loses the tree.
        """
        node = self
        while node is not None:
            if val == node.val:
                return node.remove()
            node = node.left if val < node.val else node.right
        return self.root

    def __str__(self):
        """Render the subtree as ``val(left)(right)``.

        The left part is empty when only a right child exists; the right part
        is omitted when there is no right child.
        """
        res = str(self.val)
        if self.right is not None:
            res += '({})({})'.format(self.left or "", self.right)
        elif self.left is not None:
            res += '({})'.format(self.left)
        return res

    def calculate_depth(self, current_depth=0):
        """Recompute ``left_depth``/``right_depth`` for every node in this subtree.

        This walks the whole subtree (O(n)).  The tree operations maintain the
        caches incrementally and no longer call it; it remains available to
        rebuild the caches, e.g. after editing links by hand.

        ``current_depth`` is the depth assigned to this node.  Returns the
        deepest absolute depth reached through the left and the right child.
        """
        left = right = current_depth
        if self.left is not None:
            left = max(self.left.calculate_depth(current_depth + 1))
        if self.right is not None:
            right = max(self.right.calculate_depth(current_depth + 1))
        self.left_depth = left - current_depth
        self.right_depth = right - current_depth
        return left, right

    def fix(self):
        """Restore AVL balance on the path from this node up to the root.

        Each node on the path first refreshes its cached heights from its
        children.  It is then rotated if its subtrees differ in height by
        more than one.  Returns the root of the whole tree.
        """
        node = self
        while True:
            node._update_depths()
            balance = node.left_depth - node.right_depth
            if balance > 1:
                if node.left.right_depth > node.left.left_depth:
                    node.left._rotate_left()    # left-right case
                node = node._rotate_right()
            elif balance < -1:
                if node.right.left_depth > node.right.right_depth:
                    node.right._rotate_right()  # right-left case
                node = node._rotate_left()
            if node.parent is None:
                return node
            node = node.parent

    def in_order_traversal(self, res=None):
        """Append this subtree's values to ``res`` in sorted order and return it.

        A new list is created when ``res`` is omitted.
        """
        if res is None:
            res = []
        if self.left is not None:
            self.left.in_order_traversal(res)
        res.append(self.val)
        if self.right is not None:
            self.right.in_order_traversal(res)
        return res
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Problem with calculate_depth: it re-walks the entire tree after every insertion, removal, and rotation, so building the tree takes O(N**2) time overall. The depths should instead be stored on each node and updated locally whenever a node is rotated, inserted, or removed.