/*
 * Red Black Tree
 *
 * GitHub Gist by @lstrihic (gist 412689882447178f773ce994aaf60afb),
 * created August 7, 2016; forked from zhenghaoz/redblacktree.hpp.
 */
#include <memory>
#include <iostream>
/**
 * Ordered key -> value map implemented as a red-black tree (CLRS, chapter 13).
 *
 * A single BLACK sentinel node `nil` stands in for every empty child and for
 * the root's parent, so the balancing code can read colours and parents
 * without null checks.
 *
 * Requirements: Key is default-constructible, copyable and ordered by
 * operator< (only operator< is used); Value is default-constructible,
 * copy-constructible and copy-assignable. The sentinel default-constructs
 * one Key and one Value.
 *
 * Ownership: nodes are held by std::shared_ptr, including the parent links,
 * so every parent/child pair is a reference cycle. The destructor breaks
 * those cycles explicitly, and copies are deep. No move operations are
 * declared, so "moving" a tree falls back to a (deep) copy.
 */
template <typename Key, typename Value>
class RedBlackTree
{
    template <typename T> using shared_ptr = std::shared_ptr<T>;
    struct Node;
    // Sentinel. Always BLACK. Its _parent is scratch space: transplant() and
    // del() write it so deleteFixup() can climb up from an empty child.
    shared_ptr<Node> nil = std::make_shared<Node>();
    shared_ptr<Node> root = nil;

    // Tree node. The default constructor builds the BLACK sentinel; the
    // five-argument constructor builds a RED node ready for insertion.
    struct Node {
        enum Color { RED, BLACK };
        Color _color;
        Key _key;
        Value _value;
        shared_ptr<Node> _parent, _left, _right;
        Node(): _color(BLACK), _key(), _value() {}
        // Initialisers are listed in declaration order (avoids -Wreorder).
        Node(const Key &key,
             const Value &value,
             const shared_ptr<Node> &left,
             const shared_ptr<Node> &right,
             const shared_ptr<Node> &parent):
            _color(RED), _key(key), _value(value), _parent(parent), _left(left), _right(right) {}
    };

    // Left rotation around x: x's right child y takes x's place and x becomes
    // y's left child. Precondition: x->_right != nil.
    // x is taken by value on purpose: callers pass members such as
    // ptr->_parent that the rotation itself rewrites.
    void leftRotate(shared_ptr<Node> x)
    {
        shared_ptr<Node> y = x->_right;
        // move y's left subtree under x
        x->_right = y->_left;
        // Must compare against nil (not test for null): the sentinel's
        // _parent has to keep the value deleteFixup() depends on.
        if (y->_left != nil)
            y->_left->_parent = x;
        // move y up into x's position
        y->_parent = x->_parent;
        if (x->_parent == nil)
            root = y;
        else if (x->_parent->_left == x)
            x->_parent->_left = y;
        else
            x->_parent->_right = y;
        // move x down to be y's left child
        x->_parent = y;
        y->_left = x;
    }

    // Mirror of leftRotate: x's left child y takes x's place and x becomes
    // y's right child. Precondition: x->_left != nil.
    void rightRotate(shared_ptr<Node> x)
    {
        shared_ptr<Node> y = x->_left;
        // move y's right subtree under x
        x->_left = y->_right;
        if (y->_right != nil)
            y->_right->_parent = x;
        // move y up into x's position
        y->_parent = x->_parent;
        if (x->_parent == nil)
            root = y;
        else if (x->_parent->_left == x)
            x->_parent->_left = y;
        else
            x->_parent->_right = y;
        // move x down to be y's right child
        x->_parent = y;
        y->_right = x;
    }

    // Returns the leftmost (minimum-key) node of the subtree rooted at ptr;
    // returns ptr itself when ptr is nil.
    shared_ptr<Node> min(shared_ptr<Node> ptr)
    {
        shared_ptr<Node> it = ptr;
        while (it != nil) {
            ptr = it;
            it = it->_left;
        }
        return ptr;
    }

    // Restores the red-black properties after the RED node ptr was linked in
    // (CLRS RB-INSERT-FIXUP). The loop runs while ptr and its parent are
    // both RED; a RED parent is never the root, so the grandparent exists.
    void insertFixup(shared_ptr<Node> ptr)
    {
        while (ptr->_parent->_color == Node::RED) {
            if (ptr->_parent == ptr->_parent->_parent->_left) {
                shared_ptr<Node> y = ptr->_parent->_parent->_right;
                // case 1: uncle is RED -> recolour and continue from grandparent
                if (y->_color == Node::RED) {
                    ptr->_parent->_color = Node::BLACK;
                    y->_color = Node::BLACK;
                    ptr->_parent->_parent->_color = Node::RED;
                    ptr = ptr->_parent->_parent;
                } else {
                    // case 2: ptr is an inner child -> rotate into case 3
                    if (ptr == ptr->_parent->_right) {
                        ptr = ptr->_parent;
                        leftRotate(ptr);
                    }
                    // case 3: ptr is an outer child -> recolour and rotate grandparent
                    ptr->_parent->_color = Node::BLACK;
                    ptr->_parent->_parent->_color = Node::RED;
                    rightRotate(ptr->_parent->_parent);
                }
            } else {
                // mirror image: 'left' and 'right' exchanged
                shared_ptr<Node> y = ptr->_parent->_parent->_left;
                if (y->_color == Node::RED) {
                    ptr->_parent->_color = Node::BLACK;
                    y->_color = Node::BLACK;
                    ptr->_parent->_parent->_color = Node::RED;
                    ptr = ptr->_parent->_parent;
                } else {
                    if (ptr == ptr->_parent->_left) {
                        ptr = ptr->_parent;
                        rightRotate(ptr);
                    }
                    ptr->_parent->_color = Node::BLACK;
                    ptr->_parent->_parent->_color = Node::RED;
                    leftRotate(ptr->_parent->_parent);
                }
            }
        }
        root->_color = Node::BLACK;
    }

    // Inserts (key, value) as a new RED node, or overwrites the stored value
    // if key is already present. A node is allocated only for a new key.
    void insert(const Key &key, const Value &value)
    {
        shared_ptr<Node> it = root, p = nil;
        // find the insert position (p ends as the new node's parent)
        while (it != nil) {
            p = it;
            if (key < it->_key)
                it = it->_left;
            else if (it->_key < key)
                it = it->_right;
            else {
                // key already present: update in place
                it->_value = value;
                return;
            }
        }
        shared_ptr<Node> nptr = std::make_shared<Node>(key, value, nil, nil, p);
        if (p == nil)
            root = nptr;
        else if (key < p->_key)
            p->_left = nptr;
        else
            p->_right = nptr;
        insertFixup(nptr);
    }

    // Returns the node holding key, or nil if the key is absent.
    shared_ptr<Node> find(const Key &key) const
    {
        shared_ptr<Node> it = root;
        while (it != nil) {
            if (key < it->_key)
                it = it->_left;
            else if (it->_key < key)
                it = it->_right;
            else
                return it;
        }
        return nil;
    }

    // Replaces the subtree rooted at u with the one rooted at v in u's parent.
    // v->_parent is set even when v is nil; deleteFixup() relies on that.
    void transplant(shared_ptr<Node> u, shared_ptr<Node> v)
    {
        if (u->_parent == nil)
            root = v;
        else if (u == u->_parent->_left)
            u->_parent->_left = v;
        else
            u->_parent->_right = v;
        v->_parent = u->_parent;
    }

    // Restores the red-black properties after a BLACK node was removed;
    // ptr carries the "extra black" (CLRS RB-DELETE-FIXUP). ptr may be nil,
    // in which case its parent comes from nil->_parent.
    void deleteFixup(shared_ptr<Node> ptr)
    {
        while (ptr != root && ptr->_color == Node::BLACK) {
            if (ptr == ptr->_parent->_left) {
                shared_ptr<Node> w = ptr->_parent->_right;
                // case 1: sibling RED -> rotate so the sibling becomes BLACK
                if (w->_color == Node::RED) {
                    w->_color = Node::BLACK;
                    ptr->_parent->_color = Node::RED;
                    leftRotate(ptr->_parent);
                    w = ptr->_parent->_right;
                }
                // case 2: both of the sibling's children BLACK -> push the extra black up
                if (w->_left->_color == Node::BLACK && w->_right->_color == Node::BLACK) {
                    w->_color = Node::RED;
                    ptr = ptr->_parent;
                } else {
                    // case 3: sibling's outer child BLACK -> rotate into case 4
                    if (w->_right->_color == Node::BLACK) {
                        w->_left->_color = Node::BLACK;
                        w->_color = Node::RED;
                        rightRotate(w);
                        w = ptr->_parent->_right;
                    }
                    // case 4: sibling's outer child RED -> final rotation
                    w->_color = ptr->_parent->_color;
                    ptr->_parent->_color = Node::BLACK;
                    w->_right->_color = Node::BLACK;
                    leftRotate(ptr->_parent);
                    ptr = root;
                }
            } else {
                // exact mirror image of the branch above
                shared_ptr<Node> w = ptr->_parent->_left;
                if (w->_color == Node::RED) {
                    w->_color = Node::BLACK;
                    ptr->_parent->_color = Node::RED;
                    rightRotate(ptr->_parent);
                    w = ptr->_parent->_left;
                }
                if (w->_right->_color == Node::BLACK && w->_left->_color == Node::BLACK) {
                    w->_color = Node::RED;
                    ptr = ptr->_parent;
                } else {
                    if (w->_left->_color == Node::BLACK) {
                        w->_right->_color = Node::BLACK;
                        w->_color = Node::RED;
                        leftRotate(w);
                        w = ptr->_parent->_left;
                    }
                    w->_color = ptr->_parent->_color;
                    ptr->_parent->_color = Node::BLACK;
                    w->_left->_color = Node::BLACK;
                    rightRotate(ptr->_parent);
                    ptr = root;
                }
            }
        }
        ptr->_color = Node::BLACK;
    }

    // Unlinks ptr from the tree (CLRS RB-DELETE) and rebalances if a BLACK
    // node disappeared from some root-to-leaf path.
    void del(shared_ptr<Node> ptr)
    {
        shared_ptr<Node> y = ptr, x;
        typename Node::Color y_original_color = y->_color;
        if (y->_left == nil) {
            x = ptr->_right;
            transplant(ptr, ptr->_right);
        } else if (y->_right == nil) {
            x = ptr->_left;
            transplant(ptr, ptr->_left);
        } else {
            // two children: splice out the in-order successor y instead
            y = min(ptr->_right);
            y_original_color = y->_color;
            x = y->_right;
            if (y->_parent == ptr)
                x->_parent = y; // x may be nil: record where deleteFixup starts
            else {
                transplant(y, y->_right);
                y->_right = ptr->_right;
                y->_right->_parent = y;
            }
            transplant(ptr, y);
            y->_left = ptr->_left;
            y->_left->_parent = y;
            y->_color = ptr->_color;
        }
        if (y_original_color == Node::BLACK)
            deleteFixup(x);
        // The removed node still points into the tree; drop those links so
        // it cannot keep live nodes (or itself, via cycles) alive.
        ptr->_parent.reset();
        ptr->_left.reset();
        ptr->_right.reset();
    }

    // Copies the subtree rooted at src (whose sentinel is srcNil) using this
    // tree's nil. Only child links are set here, so a throwing allocation
    // leaves no reference cycles behind; linkParents() finishes the job.
    shared_ptr<Node> cloneSubtree(const shared_ptr<Node> &src, const shared_ptr<Node> &srcNil) const
    {
        if (src == srcNil)
            return nil;
        shared_ptr<Node> left = cloneSubtree(src->_left, srcNil);
        shared_ptr<Node> right = cloneSubtree(src->_right, srcNil);
        shared_ptr<Node> copy = std::make_shared<Node>(src->_key, src->_value, left, right, nil);
        copy->_color = src->_color;
        return copy;
    }

    // Sets the _parent link of every real node below node (does not allocate).
    void linkParents(const shared_ptr<Node> &node)
    {
        if (node == nil)
            return;
        if (node->_left != nil)
            node->_left->_parent = node;
        if (node->_right != nil)
            node->_right->_parent = node;
        linkParents(node->_left);
        linkParents(node->_right);
    }

    // Drops every link inside the subtree rooted at node so the shared_ptr
    // parent/child cycles stop keeping nodes alive. Recursion depth is the
    // tree height, which a red-black tree keeps logarithmic.
    void release(shared_ptr<Node> node)
    {
        if (node == nil)
            return;
        release(node->_left);
        release(node->_right);
        node->_parent.reset();
        node->_left.reset();
        node->_right.reset();
    }
public:
    RedBlackTree() {}

    // Deep copy: the new tree owns its own nodes and sentinel.
    RedBlackTree(const RedBlackTree &other)
    {
        root = cloneSubtree(other.root, other.nil);
        linkParents(root);
    }

    // Copy-and-swap: the temporary's destructor frees this tree's old nodes.
    RedBlackTree &operator=(const RedBlackTree &other)
    {
        if (this != &other) {
            RedBlackTree copy(other);
            nil.swap(copy.nil);
            root.swap(copy.root);
        }
        return *this;
    }

    ~RedBlackTree()
    {
        release(root);
        // nil->_parent may reference a tree node, a removed node or nil
        // itself (after removing the last key); clear it to break the cycle.
        nil->_parent.reset();
    }

    // Returns a pointer to the value stored for key, or nullptr if absent.
    // Do not use the pointer after removing the key or destroying the tree.
    Value *get(const Key &key)
    {
        shared_ptr<Node> it = find(key);
        return it == nil ? nullptr : &it->_value;
    }

    // Read-only lookup for const trees; same contract as the overload above.
    const Value *get(const Key &key) const
    {
        shared_ptr<Node> it = find(key);
        return it == nil ? nullptr : &it->_value;
    }

    // Inserts key with value, or overwrites the value if key already exists.
    void put(const Key &key, const Value &value)
    {
        insert(key, value);
    }

    // Removes key if present; does nothing otherwise.
    void remove(const Key &key)
    {
        shared_ptr<Node> it = find(key);
        if (it != nil)
            del(it);
    }
};
// End of gist (scraped page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment").