Skip to content

Instantly share code, notes, and snippets.

@kkew3
Created December 6, 2021 03:43
Show Gist options
  • Save kkew3/56097f7b194eb3d9d1b3ce0d1ef64be6 to your computer and use it in GitHub Desktop.
Save kkew3/56097f7b194eb3d9d1b3ce0d1ef64be6 to your computer and use it in GitHub Desktop.
The book "Introduction to Algorithms, Third Edition" presents red-black tree operations on nodes that include a parent pointer `.p`. This gist instead gives a C++ implementation that does not require the parent pointer; each operation keeps an explicit stack of ancestors.
#include "rbtree.h"
rbnode *left_rotate(rbnode *x)
{
    // Precondition (unchecked): x != nullptr and x->right != nullptr.
    // The right child of x is lifted into x's place; x becomes its left
    // child and inherits the lifted node's former left subtree.
    rbnode *pivot = x->right;
    rbnode *inner = pivot->left;
    pivot->left = x;
    x->right = inner;
    // No parent pointers exist, so the caller must store the returned
    // subtree root in x's former parent (or in the tree's root).
    return pivot;
}
rbnode *right_rotate(rbnode *y)
{
    // Precondition (unchecked): y != nullptr and y->left != nullptr.
    // The left child of y is lifted into y's place; y becomes its right
    // child and inherits the lifted node's former right subtree.
    rbnode *pivot = y->left;
    rbnode *inner = pivot->right;
    pivot->right = y;
    y->left = inner;
    // The caller must re-link the returned subtree root into y's former
    // parent (or the tree's root).
    return pivot;
}
/// Deep-copy a tree (keys, colors and shape) iteratively.
/// Returns nullptr for an empty tree. If an allocation throws part-way
/// through, every node copied so far is released before the exception
/// propagates, so a failed copy leaks nothing.
rbnode *copy_tree(rbnode *root)
{
    if (!root) {
        return nullptr;
    }
    rbnode *root_copied = new rbnode(root->key, root->color);
    try {
        // stack[i] is a source node whose children still need copying and
        // stack_copied[i] is its already-allocated, already-linked copy.
        std::vector<rbnode *> stack {root}, stack_copied {root_copied};
        while (!stack.empty()) {
            rbnode *curr = stack.back();
            stack.pop_back();
            rbnode *curr_copied = stack_copied.back();
            stack_copied.pop_back();
            if (curr->right) {
                stack.push_back(curr->right);
                curr_copied->right = new rbnode(
                    curr->right->key, curr->right->color);
                stack_copied.push_back(curr_copied->right);
            }
            if (curr->left) {
                stack.push_back(curr->left);
                curr_copied->left = new rbnode(
                    curr->left->key, curr->left->color);
                stack_copied.push_back(curr_copied->left);
            }
        }
    } catch (...) {
        // Each new node is linked under its parent copy as soon as `new`
        // succeeds, so the partial copy is always a well-formed tree that
        // free_tree can reclaim.
        free_tree(root_copied);
        throw;
    }
    return root_copied;
}
/// Delete every node of the tree rooted at `root`; nullptr is a no-op.
/// Uses no auxiliary container, so it never allocates. This matters because
/// rbtree's destructor reaches this function through clear().
void free_tree(rbnode *root)
{
    // Whenever the current node has a left child, rotate right so that child
    // moves up. Once the current node has no left child, delete it and
    // continue with its right subtree. Each rotation moves one node onto the
    // right spine for good, so there are at most n rotations and n deletions.
    while (root) {
        if (rbnode *lifted = root->left) {
            root->left = lifted->right;
            lifted->right = root;
            root = lifted;
        } else {
            rbnode *next = root->right;
            delete root;
            root = next;
        }
    }
}
bool trees_equal(rbnode *root1, rbnode *root2)
{
    // Two trees are equal when they have the same shape and the same key at
    // every position. Node colors are not compared. Traversal is iterative,
    // with two stacks advanced in lockstep.
    std::vector<rbnode *> lhs, rhs;
    // Schedule a pair of corresponding nodes for comparison. Returns false if
    // exactly one of them is null, meaning the shapes differ.
    auto enqueue = [&lhs, &rhs](rbnode *a, rbnode *b) -> bool {
        if (!a && !b) return true;
        if (!a || !b) return false;
        lhs.push_back(a);
        rhs.push_back(b);
        return true;
    };
    if (!enqueue(root1, root2)) return false;
    while (!lhs.empty()) {
        rbnode *a = lhs.back();
        rbnode *b = rhs.back();
        lhs.pop_back();
        rhs.pop_back();
        if (a->key != b->key) return false;
        if (!enqueue(a->right, b->right)) return false;
        if (!enqueue(a->left, b->left)) return false;
    }
    return true;
}
rbtree &rbtree::operator=(const rbtree &other)
{
    // The replacement is built before anything is released, which keeps
    // self-assignment (other is *this) safe.
    rbnode *const replacement = copy_tree(other.root);
    clear();                // releases the current nodes and nulls root
    root = replacement;
    return *this;
}
bool operator==(const rbtree &t1, const rbtree &t2)
{
    // Equality is structural (shape + keys), delegated to trees_equal.
    rbnode *const lhs_root = t1.root;
    rbnode *const rhs_root = t2.root;
    return trees_equal(lhs_root, rhs_root);
}
bool operator!=(const rbtree &t1, const rbtree &t2)
{
    // Defined as the exact negation of operator==.
    return !(t1 == t2);
}
bool rbtree::contains(const int key) const
{
rbnode *z = root;
while (z && z->key != key) {
if (key < z->key) z = z->left;
else z = z->right;
}
return z ? true : false;
}
/// Insert `key` as a new node, then restore the red-black properties.
/// The new node is attached at the BST leaf position found by descent (an
/// equal key goes to the right subtree, so duplicates are kept). The
/// recoloring/rotation fixup follows CLRS RB-INSERT-FIXUP, with the stack
/// `p` of ancestors standing in for parent pointers.
void rbtree::insert(const int key)
{
    std::vector<rbnode *> p; // ancestors stack
    rbnode *x = root; // the spot to insert
    while (x) {
        p.push_back(x);
        if (key < x->key) x = x->left;
        else x = x->right;
    }
    if (p.empty()) {
        // the node to insert is root.
        // note that the root is black.
        root = new rbnode(key, rbcolor::black);
        return;
    }
    // non-root nodes are created red (rbnode's default color)
    rbnode *z = new rbnode(key);
    if (key < p.back()->key) p.back()->left = z;
    else p.back()->right = z;
    // FIXUP
    // Loop invariant: p.back() is z's parent, p[p.size() - 2] is z's
    // grandparent, and p[p.size() - 3] (if any) is the great-grandparent.
    while (!p.empty() && p.back()->color == rbcolor::red) {
        // If p.back() exists and p.back()->color is red, p.back() won't point
        // to the root, since root->color must be black. Thus, p.size() must be
        // at least two.
        if (p.back() == p[p.size() - 2]->left) {
            rbnode *uncle = p[p.size() - 2]->right;
            if (uncle && uncle->color == rbcolor::red) {
                // case 1: red uncle -- recolor and continue two levels up;
                // the grandparent becomes the new z
                p.back()->color = rbcolor::black;
                uncle->color = rbcolor::black;
                p[p.size() - 2]->color = rbcolor::red;
                z = p[p.size() - 2];
                p.erase(p.end() - 2, p.end());
            } else {
                if (z == p.back()->right) {
                    // case 2: z is an inner child -- rotate at the parent so
                    // z takes the parent's slot; then z becomes the parent on
                    // the stack and the old parent becomes z
                    p[p.size() - 2]->left = left_rotate(p.back());
                    p.back() = z;
                    z = z->left;
                }
                // case 3: recolor and rotate at the grandparent
                p.back()->color = rbcolor::black;
                p[p.size() - 2]->color = rbcolor::red;
                rbnode *h = right_rotate(p[p.size() - 2]);
                // re-link the rotated subtree into the great-grandparent;
                // the child test runs before the link is overwritten
                if (p.size() >= 3) {
                    if (p[p.size() - 2] == p[p.size() - 3]->left) {
                        p[p.size() - 3]->left = h;
                    } else {
                        p[p.size() - 3]->right = h;
                    }
                } else {
                    // p[p.size() - 2] is root
                    root = h;
                }
                // drop the old grandparent: p.back() is now the black
                // subtree root, so the loop terminates
                p.erase(p.end() - 2);
            }
        } else { // symmetric code of above
            rbnode *uncle = p[p.size() - 2]->left;
            if (uncle && uncle->color == rbcolor::red) {
                // case 1 (mirrored)
                p.back()->color = rbcolor::black;
                uncle->color = rbcolor::black;
                p[p.size() - 2]->color = rbcolor::red;
                z = p[p.size() - 2];
                p.erase(p.end() - 2, p.end());
            } else {
                if (z == p.back()->left) {
                    // case 2 (mirrored)
                    p[p.size() - 2]->right = right_rotate(p.back());
                    p.back() = z;
                    z = z->right;
                }
                // case 3 (mirrored)
                p.back()->color = rbcolor::black;
                p[p.size() - 2]->color = rbcolor::red;
                rbnode *h = left_rotate(p[p.size() - 2]);
                if (p.size() >= 3) {
                    if (p[p.size() - 2] == p[p.size() - 3]->left) {
                        p[p.size() - 3]->left = h;
                    } else {
                        p[p.size() - 3]->right = h;
                    }
                } else {
                    // p[p.size() - 2] is root
                    root = h;
                }
                p.erase(p.end() - 2);
            }
        }
    }
    // case 1 may have recolored the root red
    root->color = rbcolor::black;
}
/// Erase one node whose key equals `key`: the first such node met on the
/// search path from the root. A node with two children is replaced by its
/// in-order successor. The red-black properties are then restored as in
/// CLRS RB-DELETE-FIXUP, with the ancestors stack `p` standing in for
/// parent pointers.
/// @return true if a node was erased, false if no node has that key
bool rbtree::erase(const int key)
{
    std::vector<rbnode *> p; // ancestors stack
    // search for the node z to delete, and meanwhile record the ancestors of z
    rbnode *z = root;
    while (z && z->key != key) {
        p.push_back(z);
        if (key < z->key) z = z->left;
        else z = z->right;
    }
    // if z is not found (including the case where root is null)
    if (!z) {
        return false;
    }
    // color of the node physically removed from its position: z itself,
    // or z's successor y in the two-children case (overwritten below)
    rbcolor y_original_color = z->color;
    // x is the node that moves into the removed position; it may be null,
    // which is why its parent is tracked through p instead of through x
    rbnode *x;
    if (!z->left) {
        x = z->right; // could be null
        // now transplant x inplace of z
        // if z is root:
        if (p.empty()) {
            root = x;
        } else {
            // p.back() is the parent of z
            if (p.back()->left == z) p.back()->left = x;
            else p.back()->right = x;
        }
        // now p stores the ancestors of x (excluding x)
    } else if (!z->right) {
        x = z->left; // non-null here: the `!z->left` branch was not taken
        // transplant x inplace of z
        // if z is root
        if (p.empty()) {
            root = x;
        } else {
            // p.back() is the parent of z
            if (p.back()->left == z) p.back()->left = x;
            else p.back()->right = x;
        }
        // now p stores the ancestors of x (excluding x)
    } else {
        // find the successor y of z, updating the ancestors stack p
        rbnode *y = z->right;
        p.push_back(z);
        while (y->left) {
            p.push_back(y);
            y = y->left;
        }
        // now p stores the ancestors of y (excluding y)
        y_original_color = y->color;
        // if used the predecessor y of z, would let `x = y->left`
        x = y->right; // could be null
        // p.back() must exist since `p.push_back(z)` above;
        // if parent of y is not z:
        if (p.back() != z) {
            // since y is the successor of z, the following assert
            // must be true:
            // assert(p.back()->left == y);
            p.back()->left = x;
            y->right = z->right;
        }
        // prepare to transplant y inplace of z
        decltype(p.size()) z_index = p.size() - 1; // index of z in p
        while (p[z_index] != z) --z_index;
        // transplant y inplace of z
        // if z is root:
        if (z_index == 0) {
            root = y;
        } else {
            decltype(z_index) zp_index = z_index - 1; // index of z's parent
            if (p[zp_index]->left == z) p[zp_index]->left = y;
            else p[zp_index]->right = y;
        }
        // y now occupies z's slot in the ancestors stack as well
        p[z_index] = y;
        // now p stores the ancestors of x (excluding x)
        y->left = z->left;
        y->color = z->color;
    }
    delete z;
    if (y_original_color == rbcolor::black) {
        // FIXUP
        // while x != root and x->color == black
        // (p.empty() means x is at the root; a null x counts as black)
        while (!p.empty() && (!x || x->color == rbcolor::black)) {
            // Is x the left child of p.back()? A null x cannot be located by
            // pointer comparison. x can only be null on the first iteration
            // (case 2 below sets x to a non-null node; cases 3/4 end the
            // loop). In that situation x's sibling is non-null, so a
            // non-null right child means x occupies the left slot.
            if ((x && x == p.back()->left) || (!x && p.back()->right)) {
                rbnode *w = p.back()->right; // x's sibling, can't be null
                // case 1
                // red sibling: rotate at the parent so x gets a black
                // sibling, then fall through to cases 2-4
                if (w->color == rbcolor::red) {
                    w->color = rbcolor::black;
                    p.back()->color = rbcolor::red;
                    left_rotate(p.back()); // the return value will be w
                    if (p.size() >= 2) {
                        if (p[p.size() - 2]->left == p.back()) {
                            p[p.size() - 2]->left = w;
                        } else {
                            p[p.size() - 2]->right = w;
                        }
                    } else {
                        root = w;
                    }
                    // w is now the grandparent of x; record it in the stack
                    // just above x's parent
                    p.insert(p.end() - 1, w);
                    // `left_rotate` has changed p.back()'s `right` pointer;
                    // thus, the following statement is nontrivial
                    w = p.back()->right;
                }
                // case 2
                // both of w's children black: push the extra black up
                if ((!w->left || w->left->color == rbcolor::black)
                    && (!w->right || w->right->color == rbcolor::black)) {
                    w->color = rbcolor::red;
                    x = p.back();
                    p.pop_back();
                } else { // case 3
                    // w's far (right) child is black: rotate at w so that
                    // the far child becomes red
                    if (!w->right || w->right->color == rbcolor::black) {
                        // w->left->color is red; thus w->left is not null
                        w->left->color = rbcolor::black;
                        w->color = rbcolor::red;
                        p.back()->right = right_rotate(w);
                        w = p.back()->right;
                    }
                    // case 4
                    w->color = p.back()->color;
                    p.back()->color = rbcolor::black;
                    // after case 3 w must have a red right child
                    w->right->color = rbcolor::black;
                    rbnode *h = left_rotate(p.back());
                    if (p.size() >= 2) {
                        if (p[p.size() - 2]->left == p.back()) {
                            p[p.size() - 2]->left = h;
                        } else {
                            p[p.size() - 2]->right = h;
                        }
                    } else {
                        root = h;
                    }
                    // for clarity, maintain the ancestors stack even if it's
                    // to be cleared immediately
                    p.insert(p.end() - 1, h);
                    // p.front() is the tree's current root (the CLRS
                    // `x = T.root` step); clearing p ends the loop
                    x = p.front();
                    p.clear();
                }
            } else { // symmetric code of above
                rbnode *w = p.back()->left;
                // case 1 (mirrored)
                if (w->color == rbcolor::red) {
                    w->color = rbcolor::black;
                    p.back()->color = rbcolor::red;
                    right_rotate(p.back()); // the return value will be w
                    if (p.size() >= 2) {
                        if (p[p.size() - 2]->left == p.back()) {
                            p[p.size() - 2]->left = w;
                        } else {
                            p[p.size() - 2]->right = w;
                        }
                    } else {
                        root = w;
                    }
                    p.insert(p.end() - 1, w);
                    w = p.back()->left;
                }
                // case 2 (mirrored)
                if ((!w->left || w->left->color == rbcolor::black)
                    && (!w->right || w->right->color == rbcolor::black)) {
                    w->color = rbcolor::red;
                    x = p.back();
                    p.pop_back();
                } else {
                    // case 3 (mirrored): w->right is red, hence non-null
                    if (!w->left || w->left->color == rbcolor::black) {
                        w->right->color = rbcolor::black;
                        w->color = rbcolor::red;
                        p.back()->left = left_rotate(w);
                        w = p.back()->left;
                    }
                    // case 4 (mirrored)
                    w->color = p.back()->color;
                    p.back()->color = rbcolor::black;
                    w->left->color = rbcolor::black;
                    rbnode *h = right_rotate(p.back());
                    if (p.size() >= 2) {
                        if (p[p.size() - 2]->left == p.back()) {
                            p[p.size() - 2]->left = h;
                        } else {
                            p[p.size() - 2]->right = h;
                        }
                    } else {
                        root = h;
                    }
                    p.insert(p.end() - 1, h);
                    x = p.front();
                    p.clear();
                }
            }
        }
        if (x) {
            x->color = rbcolor::black;
        }
    }
    return true;
}
#ifndef _RBTREE_H_
#define _RBTREE_H_
#include <vector>
/// Node color of a red-black tree.
enum class rbcolor { black, red };
/// A red-black tree node. There is no parent pointer; algorithms that need
/// ancestors keep an explicit stack of them instead.
struct rbnode {
    int key;
    rbnode *left;
    rbnode *right;
    rbcolor color;
    /// Create a leaf node (both children null), red unless told otherwise.
    /// `explicit` prevents an accidental implicit int -> rbnode conversion.
    explicit rbnode(int key, rbcolor color = rbcolor::red):
        key(key), left(nullptr), right(nullptr), color(color)
    {}
};
/// Rotate the subtree rooted at x to the left.
/// x's right child takes x's place and x becomes its left child. Nodes have
/// no parent pointers, so the caller must store the returned pointer where x
/// used to hang (its parent's child link or the tree root).
/// x and x->right must not be nullptr; otherwise undefined
/// @return the new root of the rotated subtree (the former x->right)
rbnode *left_rotate(rbnode *x);
/// Rotate the subtree rooted at y to the right (mirror of left_rotate).
/// The caller must re-link the returned pointer into y's former parent.
/// y and y->left must not be nullptr; otherwise undefined
/// @return the new root of the rotated subtree (the former y->left)
rbnode *right_rotate(rbnode *y);
/// Copy a tree: a deep copy of keys, colors and shape.
/// @param root the root of tree from which to copy (may be nullptr)
/// @return a pointer to the root of the copied tree, or nullptr if root is
///         nullptr; the caller owns the new nodes (release with free_tree)
rbnode *copy_tree(rbnode *root);
/// Free a tree: delete every node reachable from root.
/// @param root the root of tree to free (nullptr is accepted and ignored)
void free_tree(rbnode *root);
/// Check if two trees are equal: same shape and the same key at every
/// position. Node colors are NOT compared. Two empty trees are equal.
/// @param root1 the root of the first tree
/// @param root2 the root of the second tree
/// @return true if equal else false
bool trees_equal(rbnode *root1, rbnode *root2);
class rbtree {
friend bool operator==(const rbtree &, const rbtree &);
friend bool operator!=(const rbtree &, const rbtree &);
friend void test_rbtree_insert_erase();
public:
rbtree(): root(nullptr) {}
rbtree(const rbtree &other): root(copy_tree(other.root)) {}
rbtree &operator=(const rbtree &);
~rbtree() { clear(); }
inline bool empty() const { return !root ? true : false; }
bool contains(const int key) const;
void insert(const int key);
/// Erase node with aid of the successor of key.
/// @return return true if a node has been erased else false
bool erase(const int key);
void clear() { free_tree(root); root = nullptr; }
private:
rbnode *root;
};
/// Structural equality of two trees (same shape and keys; node colors are
/// not compared), as implemented by trees_equal.
bool operator==(const rbtree &, const rbtree &);
/// Negation of operator==.
bool operator!=(const rbtree &, const rbtree &);
#endif // _RBTREE_H_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment