Last active
September 21, 2022 07:53
-
-
Save scturtle/727855532ccb466fcea577526ecdfa33 to your computer and use it in GitHub Desktop.
red black tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <random>
#include <vector>
/// Node colour. It is stored in bit 0 of TreeNode::m_parent, so the values
/// must stay 0 and 1 (a zeroed m_parent therefore reads as Black).
enum Color { Black = 0, Red = 1 };

/// Child-slot selector; the values double as indices into TreeNode::m_child.
enum Leaf { Left = 0, Right = 1 };

/// Opposite side: maps Left to Right and Right to Left.
inline Leaf operator!(Leaf leaf) {
  return static_cast<Leaf>(static_cast<int>(Right) - static_cast<int>(leaf));
}
struct TreeNode { | |
uintptr_t m_parent = 0; | |
TreeNode *m_child[2] = {0, 0}; | |
inline TreeNode *&child(Leaf leaf) { return m_child[leaf]; } | |
inline TreeNode *&left() { return child(Left); } | |
inline TreeNode *&right() { return child(Right); } | |
inline TreeNode *parent() { return (TreeNode *)(m_parent & (~1)); } | |
inline void set_parent(TreeNode *parent) { | |
m_parent = (uintptr_t)parent | (m_parent & 1); | |
} | |
inline Color color() { return Color(m_parent & 1); } | |
inline bool is(Color c) { return color() == c; } | |
inline void set(Color c) { m_parent = (m_parent & (~1)) | uintptr_t(c); } | |
}; | |
/// container_of-style helper: given `ptr`, which must point at the `member`
/// subobject of a live T, returns a pointer to that enclosing T.
///
/// NOTE(review): the member offset is obtained by applying `member` to a null
/// T*, which is formally undefined behaviour in ISO C++ even though mainstream
/// compilers accept it. `offsetof` is the portable alternative when the member
/// name is known and T is standard-layout.
template <typename T, typename U>
inline T *member_to_object(U T::*member, U *ptr) {
  // Address of `member` inside a T located at address 0, i.e. its byte offset.
  char *member_at_null = reinterpret_cast<char *>(
      &(static_cast<T *>(nullptr)->*member));
  // Subtracting the offset from the member's real address yields the
  // enclosing object's address.
  auto object_address = reinterpret_cast<char *>(ptr) - member_at_null;
  return reinterpret_cast<T *>(object_address);
}
/// Maps an element to its embedded hook `obj->m_treenode`; nullptr maps to
/// nullptr.
template <typename T> inline TreeNode *to_treenode(T *obj) {
  if (obj == nullptr)
    return nullptr;
  return &obj->m_treenode;
}
/// Inverse of to_treenode(): recovers the T whose `m_treenode` member is
/// `node`; nullptr maps to nullptr.
template <typename T> inline T *from_treenode(TreeNode *node) {
  if (node == nullptr)
    return nullptr;
  return member_to_object(&T::m_treenode, node);
}
struct RBTreeBase { | |
using N = TreeNode; | |
N *root = nullptr; | |
int size = 0; | |
N *minmax(Leaf leaf) const { | |
if (!root) | |
return nullptr; | |
N *r = root; | |
while (r->child(leaf)) | |
r = r->child(leaf); | |
return r; | |
} | |
N *step(N *x, Leaf leaf) const { | |
if (x->child(leaf)) { | |
x = x->child(leaf); | |
while (x->child(!leaf)) | |
x = x->child(!leaf); | |
} else { | |
N *p = x->parent(); | |
while (p && x == p->child(leaf)) { | |
x = p; | |
p = p->parent(); | |
} | |
x = p; | |
} | |
return x; | |
} | |
template <typename Lambda> | |
N *&navigate(N *&x, N *&p, Leaf leaf_on_equal, Lambda lambda) { | |
x = p = nullptr; | |
N **r = &root; | |
while (*r) { | |
int direction = lambda(*r); | |
p = *r; | |
if (direction < 0) | |
r = &(*r)->left(); | |
else if (direction > 0) | |
r = &(*r)->right(); | |
else { | |
x = *r; | |
r = &(*r)->child(leaf_on_equal); | |
} | |
} | |
return *r; | |
} | |
void rotate(N *x, Leaf leaf) { | |
N *c = x->child(leaf); | |
N *cc = c->child(!leaf); | |
x->child(leaf) = cc; | |
if (cc) | |
cc->set_parent(x); | |
c->child(!leaf) = x; | |
N *p = x->parent(); | |
if (p) { | |
if (p->left() == x) | |
p->left() = c; | |
else | |
p->right() = c; | |
} else | |
root = c; | |
x->set_parent(c); | |
c->set_parent(p); | |
} | |
void insert(N *x, N *p) { | |
size += 1; | |
x->left() = x->right() = nullptr; | |
x->set_parent(p); | |
x->set(Red); | |
while ((p = x->parent()) && p->is(Red)) { | |
N *g = p->parent(); // red always has parent | |
Leaf leaf = p == g->left() ? Right : Left; | |
N *u = g->child(leaf); | |
// if uncle is red, flip all color and up | |
if (u && u->is(Red)) { | |
p->set(Black); | |
u->set(Black); | |
g->set(Red); | |
x = g; // continue | |
} else { | |
// g->p->x not in the same direction | |
if (p->child(leaf) == x) { | |
rotate(p, leaf); | |
std::swap(p, x); | |
} | |
// p red, g black, u black, g->p->x same direction | |
rotate(g, !leaf); | |
g->set(Red); | |
p->set(Black); // break | |
} | |
} | |
root->set(Black); | |
} | |
void fixup_remove(N *x, N *p) { | |
// x (may be null) has extra black | |
while (root != x && (x == nullptr || x->is(Black))) { | |
Leaf leaf = x == p->left() ? Right : Left; | |
N *s = p->child(leaf); | |
// if sibling is red, make it black | |
if (s->is(Red)) { | |
s->set(Black); | |
p->set(Red); | |
rotate(p, leaf); | |
s = p->child(leaf); | |
} | |
Color clr[2] = {s->child(Left) ? s->child(Left)->color() : Black, | |
s->child(Right) ? s->child(Right)->color() : Black}; | |
// if both children is black, move black up and continue | |
if (clr[Left] == Black && clr[Right] == Black) { | |
s->set(Red); | |
x = p; | |
p = x->parent(); // continue | |
} else { | |
// make sibling's same direction child red | |
if (clr[leaf] == Black) { | |
s->child(!leaf)->set(Black); | |
s->set(Red); | |
rotate(s, !leaf); | |
s = p->child(leaf); | |
} | |
// rotate sibling up but keep parent's color | |
s->set(p->color()); | |
p->set(Black); | |
s->child(leaf)->set(Black); | |
rotate(p, leaf); | |
x = root; // done, new root's color may need fixed | |
} | |
} | |
if (x) | |
x->set(Black); | |
} | |
void remove(N *x) { | |
size -= 1; | |
Color color; | |
N *child, *parent; | |
if (x->left() && x->right()) { | |
N *old = x; | |
x = x->right(); | |
while (x->left()) | |
x = x->left(); | |
// x's original state | |
color = x->color(); | |
child = x->right(); | |
parent = x->parent(); | |
// move x to old's place | |
N *p = old->parent(); | |
if (p) { | |
if (p->left() == old) | |
p->left() = x; | |
else | |
p->right() = x; | |
} else | |
root = x; | |
x->set_parent(p); | |
x->set(old->color()); | |
x->left() = old->left(); | |
x->left()->set_parent(x); | |
if (parent != old) { | |
x->right() = old->right(); | |
x->right()->set_parent(x); | |
} | |
// remove x from it's original place | |
if (parent == old) { | |
parent = x; | |
} else { | |
parent->left() = child; | |
if (child) | |
child->set_parent(parent); | |
} | |
} else { | |
color = x->color(); | |
child = x->left() ? x->left() : x->right(); | |
parent = x->parent(); | |
if (child) | |
child->set_parent(parent); | |
if (parent) { | |
if (parent->left() == x) | |
parent->left() = child; | |
else | |
parent->right() = child; | |
} else | |
root = child; | |
} | |
if (color == Black) | |
fixup_remove(child, parent); | |
} | |
void check() { | |
if (root) | |
assert(root->is(Black)); | |
int cur = 0, max = -1; | |
std::function<void(N *)> dfs = [&](N *n) { | |
if (n && n->parent()) | |
assert(n->is(Black) || n->parent()->is(Black)); | |
if (!n || n->is(Black)) | |
cur += 1; | |
if (n == nullptr) { | |
if (max == -1) | |
max = cur; | |
assert(cur == max); | |
} else { | |
dfs(n->left()); | |
dfs(n->right()); | |
} | |
if (!n || n->is(Black)) | |
cur -= 1; | |
}; | |
dfs(root); | |
} | |
}; | |
/// Lookup flavour for RBTree::search() / RBTree::find():
///   Exact      - the leftmost element comparing equal to the key, else nullptr
///   LowerBound - the first element not ordered before the key
///   UpperBound - the first element ordered after the key
enum SearchMode {
  Exact = 0,
  LowerBound = 1,
  UpperBound = 2,
};
template <typename T> struct RBTree { | |
using N = TreeNode; | |
RBTreeBase tree; | |
int size() { return tree.size; } | |
T *first() const { return from_treenode<T>(tree.minmax(Left)); } | |
T *last() const { return from_treenode<T>(tree.minmax(Right)); } | |
T *prev(T *obj) const { | |
return from_treenode<T>(tree.step(to_treenode(obj), Left)); | |
} | |
T *next(T *obj) const { | |
return from_treenode<T>(tree.step(to_treenode(obj), Right)); | |
} | |
template <typename Compare> T *search(SearchMode mode, Compare compare) { | |
N *x, *p; | |
auto cmp = [&compare](N *t) { return compare(from_treenode<T>(t)); }; | |
N *&pos = tree.navigate(x, p, mode == UpperBound ? Right : Left, cmp); | |
if (mode != Exact) { | |
if (mode == UpperBound && x) | |
x = tree.step(x, Right); | |
else if (!x && p) { | |
x = &p->left() == &pos ? p : tree.step(p, Right); | |
} | |
} | |
return from_treenode<T>(x); | |
} | |
T *find(T *obj, SearchMode mode = Exact) { | |
return search(mode, [obj](T *cur) { | |
return *obj < *cur ? -1 : *cur < *obj ? 1 : 0; | |
}); | |
} | |
T *insert(T *obj, bool allow_dups = false) { | |
auto cmp = [obj](N *n) { | |
T *cur = from_treenode<T>(n); | |
return *obj < *cur ? -1 : *cur < *obj ? 1 : 0; | |
}; | |
N *x, *p; | |
N *&pos = tree.navigate(x, p, Right, cmp); | |
if (x && !allow_dups) | |
return from_treenode<T>(x); | |
pos = to_treenode(obj); | |
tree.insert(pos, p); | |
return obj; | |
} | |
void remove(T *obj) { tree.remove(to_treenode(obj)); } | |
}; | |
struct MyInt { | |
int m_val; | |
TreeNode m_treenode; | |
MyInt(int val) : m_val{val} {} | |
bool operator<(const MyInt &rhs) const { return m_val < rhs.m_val; } | |
}; | |
int main() { | |
std::random_device rd; | |
std::default_random_engine rng(rd()); | |
// std::default_random_engine rng(42); | |
RBTree<MyInt> t; | |
std::vector<int> v; | |
for (int i = 1; i <= 1000; ++i) | |
for (int j = 0; j < 10; ++j) | |
v.push_back(i); | |
std::shuffle(v.begin(), v.end(), rng); | |
for (int i = 0; i < v.size(); ++i) { | |
t.insert(new MyInt(v[i]), true); | |
t.tree.check(); | |
} | |
int cnt = 0; | |
for (MyInt *it = t.first(); it; it = t.next(it)) | |
++cnt; | |
assert(cnt == v.size()); | |
std::shuffle(v.begin(), v.end(), rng); | |
for (int i = 0; i < v.size(); ++i) { | |
MyInt tmp(v[i]); | |
MyInt *found = t.find(&tmp); | |
assert(found); | |
t.remove(found); | |
delete found; | |
t.tree.check(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment