Skip to content

Instantly share code, notes, and snippets.

@scturtle
Last active September 21, 2022 07:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save scturtle/727855532ccb466fcea577526ecdfa33 to your computer and use it in GitHub Desktop.
Save scturtle/727855532ccb466fcea577526ecdfa33 to your computer and use it in GitHub Desktop.
Red-black tree
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <random>
#include <utility>
#include <vector>
// Node colour. The numeric value (Black == 0, Red == 1) is what TreeNode
// stores in the low bit of its packed parent word.
enum Color { Black, Red };

// Child slot selector. The numeric value indexes TreeNode::m_child.
enum Leaf { Left, Right };

// Returns the opposite side: !Left == Right and !Right == Left.
inline Leaf operator!(Leaf leaf) { return static_cast<Leaf>(1 - leaf); }

// Intrusive tree link that is embedded inside user objects.
//
// The parent pointer and the colour share a single word: bit 0 of m_parent
// holds the Color and the remaining bits hold the parent's address. This
// only works because a TreeNode address always has bit 0 clear (the struct
// contains pointers, so its alignment is at least 2).
struct TreeNode {
  uintptr_t m_parent = 0;                    // parent address | colour bit
  TreeNode *m_child[2] = {nullptr, nullptr}; // indexed by Leaf

  // Mutable access to a child slot, so callers can both read and relink it.
  inline TreeNode *&child(Leaf leaf) { return m_child[leaf]; }
  inline TreeNode *&left() { return child(Left); }
  inline TreeNode *&right() { return child(Right); }

  // Parent node with the colour bit masked off; nullptr for a root.
  inline TreeNode *parent() const {
    return reinterpret_cast<TreeNode *>(m_parent & ~uintptr_t{1});
  }

  // Replaces the parent address while preserving the current colour bit.
  inline void set_parent(TreeNode *parent) {
    m_parent = reinterpret_cast<uintptr_t>(parent) | (m_parent & uintptr_t{1});
  }

  // Colour stored in bit 0 of the packed word.
  inline Color color() const {
    return static_cast<Color>(m_parent & uintptr_t{1});
  }
  inline bool is(Color c) const { return color() == c; }

  // Replaces the colour bit while preserving the parent address.
  inline void set(Color c) {
    m_parent = (m_parent & ~uintptr_t{1}) | static_cast<uintptr_t>(c);
  }
};
// Recovers the address of the enclosing T from a pointer to one of its data
// members. `member` names the data member that `ptr` points at; the member's
// byte offset inside T is measured and subtracted from `ptr`.
// NOTE(review): the offset is measured through a null T*, the classic
// offsetof-style idiom. It works on the usual flat-address targets but is not
// sanctioned by the standard - confirm it is acceptable for your toolchain
// (sanitizers may flag it).
template <typename T, typename U>
inline T *member_to_object(U T::*member, U *ptr) {
  const uintptr_t member_offset =
      reinterpret_cast<uintptr_t>(&(static_cast<T *>(nullptr)->*member));
  const uintptr_t member_address = reinterpret_cast<uintptr_t>(ptr);
  return reinterpret_cast<T *>(member_address - member_offset);
}
// Maps a user object to the tree link embedded in it. T must expose a data
// member named `m_treenode`. A null object maps to a null link.
template <typename T> inline TreeNode *to_treenode(T *obj) {
  if (obj == nullptr)
    return nullptr;
  return &obj->m_treenode;
}
// Inverse of to_treenode(): maps an embedded tree link back to the T object
// that contains it as `m_treenode`. A null link maps to a null object.
template <typename T> inline T *from_treenode(TreeNode *node) {
  if (node == nullptr)
    return nullptr;
  return member_to_object(&T::m_treenode, node);
}
// Key-agnostic red-black tree core that works purely on intrusive TreeNode
// links. It never compares elements itself: callers locate a position with
// navigate() (supplying a comparison callback) and then call insert() or
// remove(), which relink nodes and restore the colour invariants.
struct RBTreeBase {
  using N = TreeNode;
  N *root = nullptr; // nullptr while the tree is empty
  int size = 0;      // number of linked nodes; updated by insert()/remove()

  // Returns the leftmost (leaf == Left) or rightmost (leaf == Right) node,
  // or nullptr when the tree is empty.
  N *minmax(Leaf leaf) const {
    if (!root)
      return nullptr;
    N *r = root;
    while (r->child(leaf))
      r = r->child(leaf);
    return r;
  }

  // Returns the in-order neighbour of x: the successor for leaf == Right, the
  // predecessor for leaf == Left. Returns nullptr when x is null or when x
  // has no neighbour on that side.
  N *step(N *x, Leaf leaf) const {
    // RBTree::prev()/next() turn a null object into a null node, so a null
    // argument must not be dereferenced here.
    if (!x)
      return nullptr;
    if (x->child(leaf)) {
      // neighbour is the extreme node of the subtree on that side
      x = x->child(leaf);
      while (x->child(!leaf))
        x = x->child(!leaf);
    } else {
      // otherwise climb until we leave a subtree from the opposite side
      N *p = x->parent();
      while (p && x == p->child(leaf)) {
        x = p;
        p = p->parent();
      }
      x = p;
    }
    return x;
  }

  // Walks down from the root, steered by lambda(node):
  //   < 0  -> continue into the left child,
  //   > 0  -> continue into the right child,
  //   == 0 -> remember the node in x and continue into child(leaf_on_equal).
  // On return, x is the last node that compared equal (nullptr if none), p is
  // the last node visited (nullptr if the tree is empty), and the result is a
  // reference to the empty child slot (or to root) where the walk ended,
  // which is the place a new node would be linked.
  template <typename Lambda>
  N *&navigate(N *&x, N *&p, Leaf leaf_on_equal, Lambda lambda) {
    x = p = nullptr;
    N **r = &root;
    while (*r) {
      int direction = lambda(*r);
      p = *r;
      if (direction < 0)
        r = &(*r)->left();
      else if (direction > 0)
        r = &(*r)->right();
      else {
        x = *r;
        r = &(*r)->child(leaf_on_equal);
      }
    }
    return *r;
  }

  // Lifts x's child on side `leaf` into x's position; x becomes that child's
  // `!leaf` child. x->child(leaf) must be non-null. Colours are not touched.
  void rotate(N *x, Leaf leaf) {
    N *c = x->child(leaf);
    N *cc = c->child(!leaf);
    // the inner grandchild changes parent from c to x
    x->child(leaf) = cc;
    if (cc)
      cc->set_parent(x);
    c->child(!leaf) = x;
    // hook c into whatever slot used to hold x
    N *p = x->parent();
    if (p) {
      if (p->left() == x)
        p->left() = c;
      else
        p->right() = c;
    } else
      root = c;
    x->set_parent(c);
    c->set_parent(p);
  }

  // Finishes linking x below parent p and rebalances. This function does NOT
  // store x into p's child slot (or into root): the caller must already have
  // written x into the slot returned by navigate(). x's own link fields are
  // overwritten.
  void insert(N *x, N *p) {
    size += 1;
    x->left() = x->right() = nullptr;
    x->set_parent(p);
    x->set(Red);
    // repair red-red violations between x and its parent
    while ((p = x->parent()) && p->is(Red)) {
      N *g = p->parent(); // red always has parent
      Leaf leaf = p == g->left() ? Right : Left; // side of the uncle
      N *u = g->child(leaf);
      // if uncle is red, flip all color and up
      if (u && u->is(Red)) {
        p->set(Black);
        u->set(Black);
        g->set(Red);
        x = g; // continue
      } else {
        // g->p->x not in the same direction
        if (p->child(leaf) == x) {
          rotate(p, leaf);
          std::swap(p, x);
        }
        // p red, g black, u black, g->p->x same direction
        rotate(g, !leaf);
        g->set(Red);
        p->set(Black); // break
      }
    }
    root->set(Black);
  }

  // Restores the black-height invariant after a black node was unlinked.
  // x is the node that took the removed node's place (may be nullptr) and p
  // is its parent; x is treated as carrying one extra unit of black.
  void fixup_remove(N *x, N *p) {
    // x (may be null) has extra black
    while (root != x && (x == nullptr || x->is(Black))) {
      Leaf leaf = x == p->left() ? Right : Left; // side of the sibling
      N *s = p->child(leaf);
      // if sibling is red, make it black
      if (s->is(Red)) {
        s->set(Black);
        p->set(Red);
        rotate(p, leaf);
        s = p->child(leaf);
      }
      // null children count as black
      Color clr[2] = {s->child(Left) ? s->child(Left)->color() : Black,
                      s->child(Right) ? s->child(Right)->color() : Black};
      // if both children are black, move black up and continue
      if (clr[Left] == Black && clr[Right] == Black) {
        s->set(Red);
        x = p;
        p = x->parent(); // continue
      } else {
        // make sure the sibling's child on the sibling's own side is red
        if (clr[leaf] == Black) {
          s->child(!leaf)->set(Black);
          s->set(Red);
          rotate(s, !leaf);
          s = p->child(leaf);
        }
        // rotate sibling up but keep parent's color
        s->set(p->color());
        p->set(Black);
        s->child(leaf)->set(Black);
        rotate(p, leaf);
        x = root; // done, new root's color may need fixing
      }
    }
    if (x)
      x->set(Black);
  }

  // Unlinks x from the tree and rebalances. x must currently be linked into
  // this tree. x's own link fields are left untouched (stale) afterwards.
  void remove(N *x) {
    size -= 1;
    Color color;
    N *child, *parent;
    if (x->left() && x->right()) {
      // two children: the in-order successor takes over x's position
      N *old = x;
      x = x->right();
      while (x->left())
        x = x->left();
      // x's original state
      color = x->color();
      child = x->right();
      parent = x->parent();
      // move x to old's place
      N *p = old->parent();
      if (p) {
        if (p->left() == old)
          p->left() = x;
        else
          p->right() = x;
      } else
        root = x;
      x->set_parent(p);
      x->set(old->color());
      x->left() = old->left();
      x->left()->set_parent(x);
      if (parent != old) {
        x->right() = old->right();
        x->right()->set_parent(x);
      }
      // remove x from its original place
      if (parent == old) {
        // successor was old's direct right child: it keeps its right subtree
        parent = x;
      } else {
        parent->left() = child;
        if (child)
          child->set_parent(parent);
      }
    } else {
      // at most one child: splice that child (or nullptr) into x's slot
      color = x->color();
      child = x->left() ? x->left() : x->right();
      parent = x->parent();
      if (child)
        child->set_parent(parent);
      if (parent) {
        if (parent->left() == x)
          parent->left() = child;
        else
          parent->right() = child;
      } else
        root = child;
    }
    // losing a black node shortens one side's black height
    if (color == Black)
      fixup_remove(child, parent);
  }

  // Asserts the invariants of the subtree rooted at n (no red node has a red
  // parent; both sides have equal black height) and returns its black height,
  // counting the null leaf as one black node.
  static int check_subtree(N *n) {
    if (!n)
      return 1;
    if (n->parent())
      assert(n->is(Black) || n->parent()->is(Black));
    const int left_height = check_subtree(n->left());
    const int right_height = check_subtree(n->right());
    assert(left_height == right_height);
    (void)right_height; // only read by the assert above
    return left_height + (n->is(Black) ? 1 : 0);
  }

  // Debug validation of the red-black invariants: black root, no red-red
  // parent/child pair, and the same number of black nodes on every path from
  // the root to a null leaf. Has no effect when assert() is compiled out.
  void check() const {
    if (root)
      assert(root->is(Black));
    check_subtree(root);
  }
};
enum SearchMode { Exact, LowerBound, UpperBound };
template <typename T> struct RBTree {
using N = TreeNode;
RBTreeBase tree;
int size() { return tree.size; }
T *first() const { return from_treenode<T>(tree.minmax(Left)); }
T *last() const { return from_treenode<T>(tree.minmax(Right)); }
T *prev(T *obj) const {
return from_treenode<T>(tree.step(to_treenode(obj), Left));
}
T *next(T *obj) const {
return from_treenode<T>(tree.step(to_treenode(obj), Right));
}
template <typename Compare> T *search(SearchMode mode, Compare compare) {
N *x, *p;
auto cmp = [&compare](N *t) { return compare(from_treenode<T>(t)); };
N *&pos = tree.navigate(x, p, mode == UpperBound ? Right : Left, cmp);
if (mode != Exact) {
if (mode == UpperBound && x)
x = tree.step(x, Right);
else if (!x && p) {
x = &p->left() == &pos ? p : tree.step(p, Right);
}
}
return from_treenode<T>(x);
}
T *find(T *obj, SearchMode mode = Exact) {
return search(mode, [obj](T *cur) {
return *obj < *cur ? -1 : *cur < *obj ? 1 : 0;
});
}
T *insert(T *obj, bool allow_dups = false) {
auto cmp = [obj](N *n) {
T *cur = from_treenode<T>(n);
return *obj < *cur ? -1 : *cur < *obj ? 1 : 0;
};
N *x, *p;
N *&pos = tree.navigate(x, p, Right, cmp);
if (x && !allow_dups)
return from_treenode<T>(x);
pos = to_treenode(obj);
tree.insert(pos, p);
return obj;
}
void remove(T *obj) { tree.remove(to_treenode(obj)); }
};
// Sample element type used by main(): an int key plus the embedded tree link
// that RBTree<MyInt> reaches through the `m_treenode` member name.
struct MyInt {
  int m_val;
  TreeNode m_treenode;

  MyInt(int val) : m_val(val) {}

  // Orders elements by their integer key.
  bool operator<(const MyInt &rhs) const { return rhs.m_val > m_val; }
};
int main() {
std::random_device rd;
std::default_random_engine rng(rd());
// std::default_random_engine rng(42);
RBTree<MyInt> t;
std::vector<int> v;
for (int i = 1; i <= 1000; ++i)
for (int j = 0; j < 10; ++j)
v.push_back(i);
std::shuffle(v.begin(), v.end(), rng);
for (int i = 0; i < v.size(); ++i) {
t.insert(new MyInt(v[i]), true);
t.tree.check();
}
int cnt = 0;
for (MyInt *it = t.first(); it; it = t.next(it))
++cnt;
assert(cnt == v.size());
std::shuffle(v.begin(), v.end(), rng);
for (int i = 0; i < v.size(); ++i) {
MyInt tmp(v[i]);
MyInt *found = t.find(&tmp);
assert(found);
t.remove(found);
delete found;
t.tree.check();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment