Skip to content

Instantly share code, notes, and snippets.

@upsuper
Last active October 3, 2021 21:28
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save upsuper/6332576 to your computer and use it in GitHub Desktop.
Save upsuper/6332576 to your computer and use it in GitHub Desktop.
A red-black tree implemented in C++ (requires C++11)
#ifndef RBTREE_RBTREE_H_
#define RBTREE_RBTREE_H_
#include <cstddef>
#include <cassert>
#include <utility>
namespace upsuper {
namespace learning {
// A macro that disables TypeName's copy constructor and copy assignment
// operator by declaring them deleted (C++11). Any attempt to copy or assign
// is rejected at compile time, regardless of which access section the macro
// is placed in. The older "declare private, never define" pattern only
// failed at link time when a member or friend tried to copy.
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;      \
  TypeName& operator=(const TypeName&) = delete
// A red-black tree holding a set of unique keys.
//
// Requirements on Key, as used by this header:
// - default-constructible, because the sentinel node holds a Key;
// - copy-constructible, because Put copies the key into a new node;
// - assignable from an rvalue Key (used by Remove);
// - comparable with ==, !=, < and >.
//
// Every leaf link and the root's parent point at one shared sentinel node,
// nil_, which is always black.
template <class Key>
class RBTree {
 public:
  typedef std::size_t size_type;

  // Creates an empty tree; only the sentinel node is allocated.
  RBTree() : nil_(new Node), count_(0) {
    root_ = nil_;
    nil_->parent = nil_;
    nil_->left = nil_;
    nil_->right = nil_;
    nil_->color = kBlack;
  }

  // Frees every node, including the sentinel.
  ~RBTree() {
    FreeSubtree(root_);
    delete nil_;
  }

  // Nodes are owned through raw pointers, so a memberwise copy would free
  // them twice. Copying is therefore rejected at compile time. The injected
  // class name is used deliberately: a template-id such as RBTree<Key> is
  // no longer permitted as a constructor name since C++20.
  RBTree(const RBTree&) = delete;
  RBTree& operator=(const RBTree&) = delete;

  // Inserts key. Returns false (tree unchanged) if an equal key is present.
  bool Put(const Key& key);

  // Removes key. Returns false (tree unchanged) if no equal key is present.
  bool Remove(const Key& key);

  // Returns true if a key equal to `key` is stored in the tree.
  bool Contains(const Key& key) const {
    const Node *node = FindNodeOrParent(key);
    return !IsNil(node) && node->key == key;
  }

  // Number of stored keys.
  size_type Count() const {
    return count_;
  }

  // True if the tree holds no keys.
  bool Empty() const {
    return count_ == 0;
  }

 protected:
  enum Color { kRed, kBlack };

  struct Node {
    Node *parent;
    Node *left;
    Node *right;
    Color color;
    Key key;
  };

  // Root of the tree, or the sentinel when the tree is empty.
  const Node *GetRoot() const {
    return root_;
  }
  bool IsNil(const Node *node) const {
    return node == nil_;
  }
  bool IsRed(const Node *node) const {
    return node->color == kRed;
  }
  bool IsBlack(const Node *node) const {
    return node->color == kBlack;
  }

 private:
  // The sentinel must stay black, so only real nodes may be painted red.
  void SetRed(Node *node) {
    assert(node != nil_);
    node->color = kRed;
  }
  void SetBlack(Node *node) {
    node->color = kBlack;
  }
  bool IsLeftChild(const Node *node) const {
    return node->parent->left == node;
  }
  bool IsRightChild(const Node *node) const {
    return node->parent->right == node;
  }

  // Link `child` under `node`. When child is the sentinel, its parent
  // pointer is left untouched.
  void SetLeft(Node *node, Node *child) {
    assert(!IsNil(node));
    node->left = child;
    if (!IsNil(child))
      child->parent = node;
  }
  void SetRight(Node *node, Node *child) {
    assert(!IsNil(node));
    node->right = child;
    if (!IsNil(child))
      child->parent = node;
  }

  // Returns the other child of node's parent; node must not be the root.
  // Written as a single return so that control can never flow off the end
  // of the function when the assertion is compiled out (which would be UB).
  Node *GetSibling(const Node *node) const {
    assert(IsLeftChild(node) || IsRightChild(node));
    return IsLeftChild(node) ? node->parent->right : node->parent->left;
  }

  // Puts new_child where child currently hangs (or makes it the root when
  // child has no parent) and returns new_child. child's own links are left
  // untouched.
  Node *ReplaceChild(Node *child, Node *new_child) {
    if (IsNil(child->parent)) {
      root_ = new_child;
      new_child->parent = nil_;
    } else if (IsLeftChild(child)) {
      SetLeft(child->parent, new_child);
    } else if (IsRightChild(child)) {
      SetRight(child->parent, new_child);
    } else {
      assert(false);
    }
    return new_child;
  }

  // Rotations. Besides relinking, they swap the colors of the two nodes
  // involved; FixInsert and FixRemove rely on this. Each returns the node
  // that moved up.
  Node *LeftRotate(Node *node) {
    assert(node != nil_ && node->right != nil_);
    Node *child = node->right;
    ReplaceChild(node, child);
    SetRight(node, child->left);
    SetLeft(child, node);
    std::swap(node->color, child->color);
    return child;
  }
  Node *RightRotate(Node *node) {
    assert(node != nil_ && node->left != nil_);
    Node *child = node->left;
    ReplaceChild(node, child);
    SetLeft(node, child->right);
    SetRight(child, node);
    std::swap(node->color, child->color);
    return child;
  }

  // Rotates node above its parent and returns node: a right rotation when
  // node is a left child, a left rotation otherwise. node must not be the
  // root. A single return keeps NDEBUG builds free of UB.
  Node *ReverseRotate(Node *node) {
    assert(IsLeftChild(node) || IsRightChild(node));
    return IsLeftChild(node) ? RightRotate(node->parent)
                             : LeftRotate(node->parent);
  }

  // Returns the node holding key if present. Otherwise returns the node
  // under which key would be attached (the sentinel when the tree is empty).
  Node *FindNodeOrParent(const Key& key) const {
    Node *node = root_;
    Node *parent = nil_;
    while (!IsNil(node)) {
      if (node->key == key) return node;
      parent = node;
      node = node->key > key ? node->left : node->right;
    }
    return parent;
  }

  void FixInsert(Node *node);
  void FixRemove(Node *node);

  // Recursively deletes all real nodes of the subtree rooted at root.
  void FreeSubtree(Node *root) {
    if (root != nil_) {
      FreeSubtree(root->left);
      FreeSubtree(root->right);
      delete root;
    }
  }

  Node *root_;
  Node *nil_;  // shared black sentinel, owned by the tree
  size_type count_;
};
/* Public */
// Inserts a copy of `key`. Returns true if it was added, or false if an
// equal key was already stored (the tree is left unchanged in that case).
template <class Key>
bool RBTree<Key>::Put(const Key& key) {
  Node *const attach_point = FindNodeOrParent(key);
  const bool tree_was_empty = IsNil(attach_point);
  if (!tree_was_empty && attach_point->key == key)
    return false;  // duplicate key

  // New nodes start red with sentinel links; FixInsert repairs any violation.
  Node *fresh = new Node{nil_, nil_, nil_, kRed, key};
  if (tree_was_empty)
    root_ = fresh;
  else if (key < attach_point->key)
    SetLeft(attach_point, fresh);
  else
    SetRight(attach_point, fresh);

  FixInsert(fresh);
  ++count_;
  return true;
}
// Removes the key equal to `key`.
// Returns false (leaving the tree unchanged) if no such key is stored,
// true after the key has been removed.
template <class Key>
bool RBTree<Key>::Remove(const Key& key) {
Node *node = FindNodeOrParent(key);
Node *child;
if (IsNil(node) || node->key != key)
return false;
// Choose the node that will actually be unlinked (`node`) and the child
// that takes its place (`child`). A node with at most one real child is
// unlinked directly.
if (IsNil(node->right)) {
child = node->left;
} else if (IsNil(node->left)) {
child = node->right;
} else {
// Two children: move the key of the in-order successor (leftmost node
// of the right subtree) into this node, then unlink the successor
// instead. The successor has no left child, so its right child
// replaces it.
Node *sub = node->right;
while (!IsNil(sub->left))
sub = sub->left;
node->key = std::move(sub->key);
node = sub;
child = sub->right;
}
// If the unlinked node has no real child, it stays linked for now and acts
// as the placeholder "child" FixRemove starts from; it is detached only
// after the fix-up. Otherwise its child is spliced into its position.
child = IsNil(child) ? node : ReplaceChild(node, child);
// Only removing a black node can break the black-height invariant.
if (IsBlack(node))
FixRemove(child);
// Detach the placeholder (see above) now that rebalancing is done.
if (node == child)
ReplaceChild(node, nil_);
delete node;
--count_;
return true;
}
/* Private */
// Restores the red-black invariants after Put linked the red node `node`
// in as a leaf. The possible violations are a red node with a red parent,
// and a red root.
//
// Note: LeftRotate/RightRotate also swap the colors of the two nodes they
// rotate, and this routine relies on that for its recoloring.
template <class Key>
void RBTree<Key>::FixInsert(Node *node) {
// Loop while node and its parent are both red. The sentinel (the root's
// parent) is black, so the loop stops once node is the root.
while (!IsBlack(node) && !IsBlack(node->parent)) {
Node *parent = node->parent;
// parent is red and therefore not the root, so it has a sibling slot
// (possibly the sentinel).
Node *uncle = GetSibling(parent);
if (IsRed(uncle)) {
// Red uncle: push the grandparent's blackness down to its children
// and continue checking from the grandparent.
SetBlack(uncle);
SetBlack(parent);
SetRed(parent->parent);
node = parent->parent;
} else { // IsBlack(uncle)
// Black uncle. If node is an inner grandchild (zig-zag shape), first
// rotate it above its parent so the red-red pair lines up outward.
if (IsLeftChild(node) != IsLeftChild(parent))
parent = ReverseRotate(node);
// Rotate at the grandparent. The color swap makes the new subtree top
// black and the old grandparent red. node becomes that black top,
// which ends the loop.
node = ReverseRotate(parent);
}
}
// If the walk ended at the root, make sure the root is black.
if (IsNil(node->parent))
SetBlack(node);
}
// Restores the red-black invariants after Remove unlinked a black node.
// `node` marks the position that lost one black (it carries an "extra"
// black). It is either the removed node's child or, when the removed node
// had no real child, the removed node itself, which is still linked in.
//
// Relies on LeftRotate/RightRotate swapping the colors of the two nodes
// they rotate.
template <class Key>
void RBTree<Key>::FixRemove(Node *node) {
// A red node or the root can absorb the extra black directly (see the
// SetBlack after the loop).
while (!IsRed(node) && !IsNil(node->parent)) {
Node *sibling = GetSibling(node);
if (IsRed(sibling)) {
// Red sibling: rotate it above the parent. The color swap makes it
// black and the parent red, and node's new sibling is black.
ReverseRotate(sibling);
sibling = GetSibling(node);
}
if (IsBlack(sibling->left) && IsBlack(sibling->right)) {
// Black sibling with two black children: paint the sibling red, so its
// side also loses one black, and move the extra black up to the parent.
SetRed(sibling);
node = node->parent;
} else {
// Black sibling with at least one red child. The sibling's outer child
// is the one farther from node. If it is not red, rotate the red inner
// child above the sibling; the color swap makes the new sibling black
// and leaves the old sibling red as its outer child.
if (IsLeftChild(sibling) && !IsRed(sibling->left))
sibling = LeftRotate(sibling);
else if (IsRightChild(sibling) && !IsRed(sibling->right))
sibling = RightRotate(sibling);
// Rotate the sibling above the parent. Via the color swap, the sibling
// takes the parent's color and the parent becomes black. node's parent
// is now a child of the sibling, so GetSibling(node->parent) is the
// sibling's red outer child. Making it `node` ends the loop, and the
// SetBlack below paints it black.
ReverseRotate(sibling);
node = GetSibling(node->parent);
}
}
SetBlack(node);
}
} // namespace learning
} // namespace upsuper
#endif // RBTREE_RBTREE_H_
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <random>
#include <set>
#include <string>
#include <vector>

#include "rbtree_tester.hpp"
using std::cout;
using std::cerr;
using std::ends;
using std::endl;
using upsuper::learning::RBTreeTester;
// Prints the tree as it was before the failing operation (`orig_tree`, as
// produced by PrintTree), then the tree in its current state, and
// terminates the process with exit status 1.
// std::exit flushes std::cout, so '\n' is sufficient here.
template <class Key>
[[noreturn]] void PrintTreeAndExit(
    const std::vector<std::string>& orig_tree, const RBTreeTester<Key>& tree) {
  for (const auto& line : orig_tree)
    cout << line << '\n';
  // Bind by reference: the range-for keeps the returned vector alive.
  for (const auto& line : tree.PrintTree())
    cout << line << '\n';
  std::exit(1);
}
// Inserts 0, 1, ..., n-1 in ascending order into an empty tree. After each
// insertion it checks the element count, membership of the new key and the
// red-black invariants; on failure it prints both trees and exits.
void SequenceInsert(RBTreeTester<int>& tree, const int n) {
  assert(tree.Empty());
  int value = 0;
  while (value < n) {
    const auto before = tree.PrintTree();
    tree.Put(value);
    const bool healthy =
        tree.Count() == value + 1 && tree.Contains(value) && tree.Verify();
    if (!healthy) {
      cout << "SequenceInsert: " << value << endl;
      PrintTreeAndExit(before, tree);
    }
    ++value;
  }
}
// Removes 0, 1, ..., n-1 in ascending order from a tree holding exactly
// those n keys. After each removal it checks the element count, absence of
// the removed key and the red-black invariants; on failure it prints both
// trees and exits.
void SequenceRemove(RBTreeTester<int>& tree, const int n) {
  using size_type = RBTreeTester<int>::size_type;
  // Compare in the tree's unsigned size type explicitly (no sign mixing).
  assert(tree.Count() == static_cast<size_type>(n));
  for (int i = 0; i < n; ++i) {
    auto orig_tree = tree.PrintTree();
    tree.Remove(i);
    if (tree.Count() != static_cast<size_type>(n - i - 1) ||
        tree.Contains(i) || !tree.Verify()) {
      // Report under this test's own name (was mislabelled "SequenceInsert").
      cout << "SequenceRemove: " << i << endl;
      PrintTreeAndExit(orig_tree, tree);
    }
  }
}
// Inserts n-1, n-2, ..., 0 in descending order into an empty tree. After
// each insertion it checks the element count, membership of the new key and
// the red-black invariants; on failure it prints both trees and exits.
void ReverseInsert(RBTreeTester<int>& tree, const int n) {
  assert(tree.Empty());
  for (int key = n; key-- > 0;) {
    const auto snapshot = tree.PrintTree();
    tree.Put(key);
    const bool healthy =
        tree.Count() == n - key && tree.Contains(key) && tree.Verify();
    if (!healthy) {
      cout << "ReverseInsert: " << key << endl;
      PrintTreeAndExit(snapshot, tree);
    }
  }
}
// Removes n-1, n-2, ..., 0 in descending order from a tree holding exactly
// those n keys. After each removal it checks the element count, absence of
// the removed key and the red-black invariants; on failure it prints both
// trees and exits.
void ReverseRemove(RBTreeTester<int>& tree, const int n) {
  assert(tree.Count() == n);
  for (int key = n; key-- > 0;) {
    const auto snapshot = tree.PrintTree();
    tree.Remove(key);
    const bool healthy =
        tree.Count() == key && !tree.Contains(key) && tree.Verify();
    if (!healthy) {
      cout << "ReverseRemove: " << key << endl;
      PrintTreeAndExit(snapshot, tree);
    }
  }
}
// Performs 5*n random operations on values drawn from [0, n): 80% insertions
// and 20% removals. Each operation is mirrored on a std::set reference, and
// afterwards the tree is cross-checked against it:
// - Put's/Remove's return value against whether the reference changed;
// - membership of the touched value;
// - the element count;
// - the red-black invariants.
// On failure, the random seed is printed so the run can be reproduced.
void RandomOperations(RBTreeTester<int>& tree, const int n) {
  assert(tree.Empty());
  // uniform_int_distribution(0, n - 1) requires n >= 1; nothing to do else.
  if (n <= 0) return;
  std::set<int> ref;
  std::random_device rd;
  const auto seed = rd();
  std::mt19937 gen(seed);
  std::bernoulli_distribution op_dist(0.8);
  std::uniform_int_distribution<> val_dist(0, n - 1);
  for (int i = 0; i < n * 5; ++i) {
    auto orig_tree = tree.PrintTree();
    const bool add_item = op_dist(gen);
    const int val = val_dist(gen);
    bool ref_changed = false;
    bool tree_changed = false;
    if (add_item) {
      ref_changed = ref.insert(val).second;
      tree_changed = tree.Put(val);
    } else {
      ref_changed = ref.erase(val) != 0;
      tree_changed = tree.Remove(val);
    }
    const bool ref_has_val = ref.count(val) != 0;
    if (tree_changed != ref_changed || tree.Contains(val) != ref_has_val ||
        tree.Count() != ref.size() || !tree.Verify()) {
      cout << "(Seed: " << seed << ") " <<
          (add_item ? "Add " : "Remove ") << val << endl;
      PrintTreeAndExit(orig_tree, tree);
    }
  }
}
int main() {
RBTreeTester<int> tree;
const int n = 1000;
cout << "SequenceInsert & SequenceRemove" << endl;
SequenceInsert(tree, n);
SequenceRemove(tree, n);
cout << "SequenceInsert & ReverseRemove" << endl;
SequenceInsert(tree, n);
ReverseRemove(tree, n);
cout << "ReverseInsert & SequenceRemove" << endl;
ReverseInsert(tree, n);
SequenceRemove(tree, n);
cout << "ReverseInsert & ReverseRemove" << endl;
ReverseInsert(tree, n);
ReverseRemove(tree, n);
cout << "RandomOperations" << endl;
RandomOperations(tree, n);
return 0;
}
#ifndef RBTREE_RBTREE_TESTER_H_
#define RBTREE_RBTREE_TESTER_H_
#include <string>
#include <vector>
#include "rbtree.hpp"
namespace upsuper {
namespace learning {
// Test helper that derives from RBTree to reach its protected internals.
// It can check the red-black invariants (Verify) and render the tree as
// text (PrintTree). PrintTree requires std::to_string to accept Key.
template <class Key>
class RBTreeTester : public RBTree<Key> {
 public:
  using size_type = typename RBTree<Key>::size_type;
  using RBTree<Key>::Count;

  // Returns true if the tree structure, key ordering, coloring and node
  // count are all consistent.
  bool Verify() const;

  // Returns an ASCII drawing of the tree, one string per line; empty when
  // the tree is empty.
  std::vector<std::string> PrintTree() const {
    std::vector<std::string> tree;
    if (!IsNil(GetRoot())) {
      tree.push_back("");
      BuildPrintTree(GetRoot(), tree);
    }
    // Plain return enables NRVO; std::move here would inhibit copy elision.
    return tree;
  }

 private:
  using Node = typename RBTree<Key>::Node;
  using RBTree<Key>::GetRoot;
  using RBTree<Key>::IsNil;
  using RBTree<Key>::IsRed;
  using RBTree<Key>::IsBlack;

  // Returns the black height of the subtree at `node`, or -1 on any
  // violation. Keys must lie strictly between *min and *max (when non-null).
  // Real nodes visited are added to *count.
  int Traverse(const Node *node, const Key *min, const Key *max,
               size_type *count) const;

  // Appends the drawing of the subtree at `node` to `tree`.
  void BuildPrintTree(const Node *node, std::vector<std::string>& tree) const;
};
/* Public */
// Checks every structural invariant of the tree:
// - the root is black and its parent is the sentinel;
// - the sentinel is black and all of its links point to itself;
// - the subtree checks in Traverse pass (ordering, no red node with a red
//   child, equal black heights);
// - the number of reachable nodes equals Count().
template <class Key>
bool RBTreeTester<Key>::Verify() const {
  const Node *root = GetRoot();
  if (!IsBlack(root) || !IsNil(root->parent))
    return false;
  const Node *nil = root->parent;
  if (!IsBlack(nil) || !IsNil(nil->parent) ||
      !IsNil(nil->left) || !IsNil(nil->right))
    return false;
  size_type count = 0;
  // Traverse returns a signed int (-1 on failure). Keep it signed rather
  // than converting it to the unsigned size_type before comparing.
  const int black_height = Traverse(root, nullptr, nullptr, &count);
  if (black_height < 0) return false;
  return count == Count();
}
/* Private */
// Recursively validates the subtree rooted at `node`. It returns the
// subtree's black height (sentinel leaves count as 0), or -1 if:
// - a key falls outside the open interval (*min, *max);
// - a red node has a non-black child;
// - the two child subtrees have different black heights.
// Each real node visited increments *count.
template <class Key>
int RBTreeTester<Key>::Traverse(const Node *node,
                                const Key *min, const Key *max,
                                size_type *count) const {
  if (IsNil(node)) return 0;

  // Ordering: strictly greater than every left-ancestor bound and strictly
  // less than every right-ancestor bound.
  if ((min != nullptr && node->key <= *min) ||
      (max != nullptr && node->key >= *max))
    return -1;

  // A red node must have only black children.
  if (IsRed(node) && (!IsBlack(node->left) || !IsBlack(node->right)))
    return -1;

  const int lhs = Traverse(node->left, min, &node->key, count);
  const int rhs = Traverse(node->right, &node->key, max, count);
  if (lhs == -1 || lhs != rhs)
    return -1;

  ++*count;
  return lhs + (IsBlack(node) ? 1 : 0);
}
// Appends an ASCII drawing of the subtree rooted at `node` to `tree`, one
// string per output line. The node's label, " <key>(R) " or " <key>(B) ",
// is appended to the current last line. The left child is drawn on the same
// line after a '-'. The right child starts a new line that has a backslash
// at the column just past this node's label, and '|' connectors are filled
// in above it.
template<class Key>
void RBTreeTester<Key>::BuildPrintTree(
const Node *node, std::vector<std::string>& tree) const {
// `line` refers into `tree`; it is not used after the recursive call
// below, which may push_back and invalidate the reference.
auto& line = tree.back();
line.append(" ").append(std::to_string(node->key))
.append(IsRed(node) ? "(R)" : "(B)").append(" ");
// Column just past this node's label; the right-child connector goes here.
auto len = line.size();
if (!IsNil(node->left)) {
line.append("-");
BuildPrintTree(node->left, tree);
}
if (!IsNil(node->right)) {
// Walk upward from the last line, turning blank cells in column `len`
// into '|' until a non-blank cell is reached. At the latest this is
// this node's own line, which holds '-' there or ends at `len`.
for (auto iter = tree.rbegin(); (*iter)[len] == ' '; ++iter)
(*iter)[len] = '|';
std::string line2(len, ' ');
line2.append("\\");
tree.push_back(line2);
BuildPrintTree(node->right, tree);
}
}
} // namespace learning
} // namespace upsuper
#endif // RBTREE_RBTREE_TESTER_H_
@yujianyuanhaha
Copy link

rbtree.hpp:172:24: error: expected ';' at end of declaration
Node *node = new Node{nil_, nil_, nil_, kRed, key};
^
;
1 error generated.

@nhnguyen99
Copy link

why are the copy constructor and assignment operator disallowed in this code?

@upsuper
Copy link
Author

upsuper commented Apr 1, 2020

why are the copy constructor and assignment operator disallowed in this code?

I can't recall. It was either that I was too lazy to implement them properly, or that I thought they could be very misleading to use.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment