Skip to content

Instantly share code, notes, and snippets.

@natsugiri
Last active August 29, 2015 14:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save natsugiri/531cda2b5382b2ba8aa3 to your computer and use it in GitHub Desktop.
Save natsugiri/531cda2b5382b2ba8aa3 to your computer and use it in GitHub Desktop.
RBST Set
// Ordered set of unique keys stored in a randomized binary search tree (RBST).
//
// Requirements on T: default-constructible, copy-constructible, assignable,
// and comparable with both operator== and operator<.
//
// Balance comes from std::rand(): on the way down, a new key becomes the root
// of the current subtree with probability 1 / (subtree size + 1). The tree
// shape therefore depends on the std::rand() sequence (seed with std::srand).
//
// The set owns its nodes: they are released by the destructor, and copying a
// set deep-copies the whole tree so two sets never share nodes.
template<class T> struct RBSTSet {
    struct Node {
        T val;        // key stored in this node
        int s;        // number of nodes in the subtree rooted here (incl. this)
        Node *l, *r;  // left / right child, NULL when absent
        Node(const T &val=T()): val(val), s(1), l(NULL), r(NULL) {}
        // Child selector: ch(false) is the left child, ch(true) the right one.
        inline Node*& ch(bool b) { return b? r: l; }
        // Recompute s from the children's sizes; call after a child changes.
        inline void resize() { s = 1 + size(l) + size(r); }
    };

    Node *root;  // root of the tree, NULL for an empty set

    RBSTSet(): root(NULL) {}

    // Deep copy: the new set gets its own nodes with the same shape.
    RBSTSet(const RBSTSet &other): root(clone(other.root)) {}

    // Deep-copy assignment. The replacement tree is built before the old one
    // is released, so *this is left untouched if cloning throws.
    RBSTSet& operator=(const RBSTSet &other) {
        if (this != &other) {
            Node *fresh = clone(other.root);
            destroy(root);
            root = fresh;
        }
        return *this;
    }

    // Frees every node still held by the set.
    ~RBSTSet() { destroy(root); }

    // Number of keys in the set.
    int size() const { return size(root); }

    // Number of nodes in the subtree rooted at x (0 for NULL).
    static int size(const Node *x) { return x? x->s: 0; }

    // Frees the whole subtree rooted at x (no-op for NULL).
    static void destroy(Node *x) {
        if (!x) return;
        destroy(x->l);
        destroy(x->r);
        delete x;
    }

    // Returns a freshly allocated copy of the subtree rooted at x, preserving
    // shape and sizes. If an allocation or a copy of T throws, the partially
    // built copy is released before the exception propagates.
    static Node* clone(const Node *x) {
        if (!x) return NULL;
        Node *y = new Node(x->val);
        y->s = x->s;
        try {
            y->l = clone(x->l);
            y->r = clone(x->r);
        } catch (...) {
            destroy(y);
            throw;
        }
        return y;
    }

    // Rotates the b-side child of x above x and returns the new subtree root.
    // b == true raises the right child (left rotation); b == false raises the
    // left child (right rotation). x->ch(b) must not be NULL.
    Node* rot(Node *x, bool b) {
        Node *y = x->ch(b);
        x->ch(b) = y->ch(!b); x->resize();
        y->ch(!b) = x; y->resize();
        return y;
    }

    // Inserts v into the subtree x and rotates on the way back up so that v
    // ends as the root of the returned subtree. If v is already the key of x,
    // x is returned unchanged.
    Node* insert_root(Node *x, const T &v) {
        if (!x) return new Node(v);
        if (x->val == v) return x;
        bool b = (x->val < v);
        x->ch(b) = insert_root(x->ch(b), v); x->resize();
        x = rot(x, b);
        return x;
    }

    // Adds v to the set; does nothing if v is already present.
    // The membership test is done up front so that a duplicate insert never
    // restructures the tree (the recursive helpers rotate on their way back
    // even when they end up finding the key instead of creating a node).
    void insert(const T &v) {
        if (find(v)) return;
        root = insert(root, v);
    }

    // Recursive insertion into subtree x; returns the new subtree root.
    // With probability 1 / (size(x) + 1) the key becomes the root of this
    // subtree, otherwise it is inserted into the appropriate child.
    Node* insert(Node *x, const T &v) {
        if (!x) return new Node(v);
        if (v == x->val) return x;
        if (rand() % (size(x) + 1) == 0) {
            x = insert_root(x, v);
        } else {
            bool b = (x->val < v);
            x->ch(b) = insert(x->ch(b), v); x->resize();
        }
        return x;
    }

    // Removes the leftmost (smallest) node of the non-empty subtree x, stores
    // its key in v and returns the new subtree root.
    Node* erase_first(Node *x, T &v) {
        if (!x->l) {
            v = x->val;
            Node *y = x->r;
            delete x; return y;
        }
        x->l = erase_first(x->l, v); x->resize();
        return x;
    }

    // Removes v from the set; does nothing if v is not present.
    void erase(const T &v) { root = erase(root, v); }

    // Recursive erase from subtree x; returns the new subtree root.
    Node* erase(Node *x, const T &v) {
        if (!x) return NULL;
        if (v == x->val) {
            if (!x->r) {
                // No right subtree: the left child takes this node's place.
                Node *y = x->l;
                delete x; return y;
            }
            // Otherwise overwrite this key with its in-order successor, which
            // is unlinked from the right subtree.
            x->r = erase_first(x->r, x->val); x->resize();
            return x;
        }
        bool b = (x->val < v);
        x->ch(b) = erase(x->ch(b), v); x->resize();
        return x;
    }

    // Returns a pointer to the stored key equal to v, or NULL if absent.
    // The pointer is invalidated by any later modification of the set.
    const T* find(const T &v) const {
        const Node *x = root;
        while (x) {
            if (x->val == v) return &x->val;
            x = (x->val < v) ? x->r : x->l;
        }
        return NULL;
    }

    // Returns the n-th smallest key (0-based). When n is out of range
    // (negative or >= size()) a default-constructed T is returned.
    const T at(int n) const {
        const Node *x = root;
        while (x) {
            const int left = size(x->l);
            if (left == n) return x->val;
            if (left > n) x = x->l;
            else { n -= left + 1; x = x->r; }
        }
        return T();
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment