Skip to content

Instantly share code, notes, and snippets.

@codefever
Last active August 24, 2019 09:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save codefever/4442fd7130700c0b05fb1ca1ac2f3f4c to your computer and use it in GitHub Desktop.
Save codefever/4442fd7130700c0b05fb1ca1ac2f3f4c to your computer and use it in GitHub Desktop.
Red-black tree (RBT) in Python
#!/usr/bin/env python
import sys
# RBT != 仙踪林
# https://en.wikipedia.org/wiki/Red%E2%80%93black_tree
class RBNode(object):
    """A single node of a red-black tree.

    Holds a value, a colour flag (``is_red``; nodes are black by default) and
    links to the left/right children and the parent. The balancing logic lives
    in ``RBTree``; this class provides colour helpers, family lookups,
    rotations and debugging aids. A missing (``None``) child counts as black.
    """

    def __init__(self, val, is_red=False, left=None, right=None, parent=None):
        self.val = val
        self.is_red = is_red
        self.left = left
        self.right = right
        self.parent = parent

    def red(self):
        """Return True if this node is red."""
        return self.is_red

    def black(self):
        """Return True if this node is black."""
        return not self.is_red

    def left_black(self):
        """Return True if the left child is black or missing."""
        return self.left is None or self.left.black()

    def right_black(self):
        """Return True if the right child is black or missing."""
        return self.right is None or self.right.black()

    def set_red(self):
        """Colour this node red."""
        self.is_red = True

    def set_black(self):
        """Colour this node black."""
        self.is_red = False

    @property
    def grandparent(self):
        """The parent's parent, or None if either link is missing."""
        return self.parent.parent if self.parent is not None else None

    @property
    def uncle(self):
        """The parent's sibling, or None when there is no grandparent."""
        gp = self.grandparent
        if gp is None:
            return None
        return gp.right if self.parent is gp.left else gp.left

    @property
    def sibling(self):
        """The other child of this node's parent, or None at the root."""
        p = self.parent
        if p is None:
            return None
        return p.left if self is p.right else p.right

    @staticmethod
    def _replace_child(parent, old, new):
        """Redirect ``parent``'s link to ``old`` so it points at ``new``.

        Does nothing when ``parent`` is None (``old`` was a subtree root).
        """
        if parent is None:
            return
        if parent.left is old:
            parent.left = new
        else:
            parent.right = new

    def rotate_left(self):
        """Rotate left around this node and return the node that replaced it.

        The right child ``n`` takes this node's place under the old parent,
        this node becomes ``n.left`` and ``n``'s former left subtree becomes
        this node's right subtree. Colours are not touched. If the returned
        node has no parent, the caller must update its own root reference.
        """
        assert self.right is not None
        p = self.parent
        n = self.right
        # n's inner (left) subtree moves across to become self's right subtree.
        self.right, n.left = n.left, self
        n.parent = p
        self.parent = n
        if self.right is not None:
            self.right.parent = self
        self._replace_child(p, self, n)
        return n

    def rotate_right(self):
        """Rotate right around this node and return the node that replaced it.

        Mirror image of ``rotate_left``: the left child ``n`` moves up and
        this node becomes ``n.right``.
        """
        assert self.left is not None
        p = self.parent
        n = self.left
        # n's inner (right) subtree moves across to become self's left subtree.
        self.left, n.right = n.right, self
        n.parent = p
        self.parent = n
        if self.left is not None:
            self.left.parent = self
        self._replace_child(p, self, n)
        return n

    def __repr__(self):
        return str(self.val) + ('(R)' if self.is_red else '(B)')

    def check_me(self, lo=None, hi=None):
        """Assert the red-black invariants of this subtree; return its black height.

        Checked with ``assert`` (debug aid; skipped under ``python -O``):
          * a red node has no red children;
          * both subtrees have the same black height;
          * every value lies within the inclusive bounds ``[lo, hi]`` inherited
            from its ancestors, so the whole subtree is ordered, not merely
            each parent/child pair.

        Args:
            lo: inclusive lower bound for values here, or None for unbounded.
            hi: inclusive upper bound for values here, or None for unbounded.

        Returns:
            The number of black nodes on any downward path starting at this
            node, counting this node itself. Missing children count as 0.
        """
        if lo is not None:
            assert lo <= self.val
        if hi is not None:
            assert self.val <= hi
        if self.is_red:
            assert self.left_black() and self.right_black()
        bl = self.left.check_me(lo, self.val) if self.left is not None else 0
        br = self.right.check_me(self.val, hi) if self.right is not None else 0
        assert bl == br
        return bl + (0 if self.is_red else 1)

    def print_me(self):
        """Print this subtree level by level, one list per level.

        Each level is printed as ``<level> [<node reprs>]``, with '*' for an
        empty position so the entries line up with a complete binary tree.
        Printing stops after the first level whose nodes have no children.
        Child-to-parent back-links are asserted along the way.
        """
        q = [self]
        level = 0
        while True:
            new_q = []
            cnt = 0
            for x in q:
                if x is None:
                    new_q.extend((None, None))
                    continue
                for child in (x.left, x.right):
                    if child is not None:
                        cnt += 1
                        assert child.parent is x
                new_q.extend((x.left, x.right))
            print(level, [str(e) if e is not None else '*' for e in q])
            if cnt == 0:
                break
            level += 1
            q = new_q
class RBTree(object):
    """A set of comparable values stored in a red-black tree.

    Duplicate values are rejected. The rebalancing routines follow the case
    analysis of https://en.wikipedia.org/wiki/Red%E2%80%93black_tree; the
    removal cases are numbered 1-6 in the comments below.
    """

    def __init__(self):
        self._root = None

    def contains(self, val):
        """Return True if ``val`` is stored in the tree."""
        n, _ = self._find_node(val)
        return n is not None

    def insert(self, val):
        """Insert ``val``.

        Returns:
            True if the value was added, False if it was already present
            (the tree is left unchanged).
        """
        n, p = self._find_node(val)
        if n is not None:
            return False
        if p is None:
            # Empty tree: the new node is the (black) root; nothing to rebalance.
            self._root = RBNode(val)
            return True
        n = RBNode(val, is_red=True, parent=p)
        if val < p.val:
            p.left = n
        else:
            p.right = n
        self._fix_insertions(n)
        return True

    def remove(self, val):
        """Remove ``val``.

        If the matching node has a child, its value is replaced by that of its
        in-order successor (or predecessor when there is no right subtree) and
        that node, which has at most one child, is unlinked instead.

        Returns:
            True if the value was removed, False if it was not present.
        """
        n, _ = self._find_node(val)
        if n is None:
            return False
        t = n
        if n.right is not None:
            t = n.right
            while t.left is not None:
                t = t.left
        elif n.left is not None:
            t = n.left
            while t.right is not None:
                t = t.right
        n.val = t.val  # no-op when t is n
        # Rebalance while t is still linked in, then splice it out.
        self._fix_removals(t)
        child = t.left if t.left is not None else t.right
        if child is not None:
            child.parent = t.parent
        if t.parent is None:
            # t was the root: its (possibly None) child becomes the new root.
            self._root = child
        elif t.parent.left is t:
            t.parent.left = child
        else:
            t.parent.right = child
        # Detach t so it holds no stale references into the tree.
        t.parent = t.left = t.right = None
        return True

    def _find_node(self, val):
        """Search for ``val``.

        Returns:
            ``(node, parent)`` if found, otherwise ``(None, p)`` where ``p`` is
            the node under which ``val`` would be attached (None if empty).
        """
        n = self._root
        p = None
        while n is not None:
            if val < n.val:
                p, n = n, n.left
            elif val > n.val:
                p, n = n, n.right
            else:
                return (n, p)
        return (None, p)

    def _rotate(self, node, left):
        """Rotate around ``node`` (leftward if ``left``) and keep ``_root`` current.

        Returns the node that took ``node``'s place.
        """
        top = node.rotate_left() if left else node.rotate_right()
        if top.parent is None:
            self._root = top
        return top

    def _fix_insertions(self, n):
        """Restore the red-black invariants after linking in the red node ``n``."""
        while n.parent is not None and n.grandparent is not None:
            p = n.parent
            g = n.grandparent
            if not (n.is_red and p.is_red):
                break
            u = n.uncle
            if u is not None and u.is_red:
                # Red uncle: push the red up to the grandparent and continue there.
                g.is_red = True
                p.is_red = False
                u.is_red = False
                n = g
            elif p is g.left and n is p.right:
                # Inner grandchild: rotate it to the outside, then continue
                # from the old parent, which is now the lower red node.
                self._rotate(p, left=True)
                n = p
            elif p is g.right and n is p.left:
                self._rotate(p, left=False)
                n = p
            else:
                # Outer grandchild: recolour and rotate the grandparent away.
                p.is_red = False
                g.is_red = True
                n = self._rotate(g, left=p is not g.left)
        if self._root.is_red:
            self._root.is_red = False

    def _fix_removals(self, n):
        """Rebalance before unlinking ``n``, a node with at most one child.

        Removing a red node, or a black node whose child is red, needs no
        restructuring. Otherwise every path through ``n`` loses one black node
        and the deficit is pushed upward or absorbed. ``n`` itself stays in
        place; the caller unlinks it afterwards.
        """
        if n.is_red:
            return
        # 1 - a red child will replace n; colouring it black keeps the count.
        child = n.left if n.left is not None else n.right
        if child is not None and child.is_red:
            child.is_red = False
            return
        while n.parent is not None:
            p = n.parent
            s = n.sibling
            # 2 - red sibling: rotate it above the parent so n gets a black sibling.
            if s is not None and s.is_red:
                p.is_red = True
                s.is_red = False
                self._rotate(p, left=n is p.left)
                s = n.sibling
            # 3/4 - black sibling with two black children: make the sibling red.
            if s is not None and s.black() and s.left_black() and s.right_black():
                s.set_red()
                if p.black():
                    # 3 - parent's subtree is now short too; repeat one level up.
                    n = p
                    continue
                # 4 - red parent absorbs the missing black.
                p.set_black()
                break
            # 5 - black sibling whose inner child is red and outer child black:
            # rotate the sibling so the red child ends up on the outside.
            if s is not None and s.black():
                if (s is p.right and s.left is not None and s.left.red()
                        and s.right_black()):
                    s.left.set_black()
                    s.set_red()
                    self._rotate(s, left=False)
                elif (s is p.left and s.right is not None and s.right.red()
                        and s.left_black()):
                    s.right.set_black()
                    s.set_red()
                    self._rotate(s, left=True)
            # 6 - black sibling with a red outer child: rotate the parent toward n.
            s = n.sibling
            if s is p.right and s.right is not None and s.right.red():
                s.is_red = p.is_red
                s.right.set_black()
                p.set_black()
                self._rotate(p, left=True)
            elif s is p.left and s.left is not None and s.left.red():
                s.is_red = p.is_red
                s.left.set_black()
                p.set_black()
                self._rotate(p, left=False)
            break

    def print_me(self):
        """Print the tree level by level ('*' marks an empty position)."""
        if self._root is None:
            print(['*'])
        else:
            self._root.print_me()

    def check_me(self):
        """Assert the tree-wide invariants (debug aid; skipped under -O).

        The root must be black and parentless; the remaining invariants are
        checked recursively by ``RBNode.check_me``.
        """
        if self._root is None:
            return
        assert not self._root.is_red
        assert self._root.parent is None
        self._root.check_me()
if __name__ == '__main__':
    # Smoke test: insert 1..36 in a random order, then delete them all in a
    # fresh random order, printing and validating the tree after every step.
    import random

    tree = RBTree()
    print(tree.contains(3))
    values = list(range(1, 37))
    phases = (('# add ', tree.insert), ('# rm ', tree.remove))
    for label, operation in phases:
        random.shuffle(values)
        for v in values:
            print(label, v)
            assert operation(v)
            tree.print_me()
            tree.check_me()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment