Skip to content

Instantly share code, notes, and snippets.

@natsugiri
Last active August 29, 2015 14:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save natsugiri/187f327b4eb7ac166731 to your computer and use it in GitHub Desktop.
Save natsugiri/187f327b4eb7ac166731 to your computer and use it in GitHub Desktop.
AVL Tree Set
// Ordered set of unique values stored in an AVL tree. Every node also keeps
// the size of its subtree, which lets at() select the n-th smallest value.
//
// T must be default-constructible, copy-constructible and provide both
// operator< and operator==.
template<class T> struct AVLTreeSet {
    struct Node {
        T val;        // stored value
        int s, h;     // s: node count of this subtree, h: its height (a leaf has h == 1)
        Node *l, *r;  // left / right child, NULL when absent
        Node(const T &val=T()): val(val) { s = h = 1; l = r = NULL; }
        // ch(false) is the left child, ch(true) the right child.
        inline Node*& ch(bool b) { return b? r: l; }
        // Recomputes s and h from the children, which must already be up to date.
        inline void resize() {
            const int hl = height(l), hr = height(r);
            s = 1 + size(l) + size(r);
            h = 1 + (hl > hr? hl: hr);
        }
    };

    Node *root;  // NULL when the set is empty

    AVLTreeSet(): root(NULL) {}
    // The set owns its nodes, so copies are deep: two sets never share nodes
    // and the destructor can free everything exactly once.
    AVLTreeSet(const AVLTreeSet &o): root(clone(o.root)) {}
    AVLTreeSet& operator=(const AVLTreeSet &o) {
        if (this != &o) {
            Node *fresh = clone(o.root);  // copy before releasing the old tree
            destroy(root);
            root = fresh;
        }
        return *this;
    }
    ~AVLTreeSet() { destroy(root); }

    // Number of stored values.
    int size() const { return size(root); }
    // NULL-safe accessors: an empty subtree has size 0 and height 0.
    static int size(Node *x) { return x? x->s: 0; }
    static int height(Node *x) { return x? x->h: 0; }
    // Balance factor height(left) - height(right); 0 for an empty subtree.
    static int diff(Node *x) { return x? height(x->l) - height(x->r): 0; }

    // Single rotation around x. b == true raises the right child (left
    // rotation), b == false raises the left child. The raised child must exist.
    // Returns the new subtree root.
    Node* rot(Node *x, bool b) {
        Node *y = x->ch(b);
        x->ch(b) = y->ch(!b); x->resize();
        y->ch(!b) = x; y->resize();
        return y;
    }

    // Restores the AVL invariant at x when its balance factor is +-2 and both
    // children are already balanced (single or double rotation as needed).
    // Returns the new subtree root.
    Node* balance(Node *x) {
        if (diff(x) > 1) {
            if (diff(x->l) < 0) { x->l = rot(x->l, 1); }
            x = rot(x, 0);
        } else if (diff(x) < -1) {
            if (diff(x->r) > 0) { x->r = rot(x->r, 0); }
            x = rot(x, 1);
        }
        return x;
    }

    // Inserts v; does nothing when an equal value is already stored.
    void insert(const T &v) { root = insert(root, v); }
    Node* insert(Node *x, const T &v) {
        if (!x) return new Node(v);
        if (v == x->val) return x;
        bool b = (x->val < v);
        x->ch(b) = insert(x->ch(b), v); x->resize();
        return balance(x);
    }

    // Removes v; does nothing when v is not stored.
    void erase(const T &v) { root = erase(root, v); }
    Node* erase(Node *x, const T &v) {
        if (!x) return NULL;
        if (v == x->val) {
            Node *y = merge(x->l, x->r);
            delete x;
            return y;
        }
        bool b = x->val < v;
        x->ch(b) = erase(x->ch(b), v); x->resize();
        return balance(x);
    }

    // Concatenates two AVL subtrees where every value in x is smaller than
    // every value in y, and returns the balanced result.
    //
    // The smallest node of y is detached and reused as the pivot of a
    // height-aware join. Hanging y below the rightmost node of x instead would
    // create height gaps that one rotation per level cannot repair.
    Node* merge(Node *x, Node *y) {
        if (!x) return y;
        if (!y) return x;
        Node *mid = NULL;
        y = popMin(y, mid);
        return join(x, mid, y);
    }

    // Detaches the smallest node of the non-empty subtree x into `out` and
    // returns the rebalanced remainder. out's child pointers are left stale;
    // the caller is expected to overwrite them.
    Node* popMin(Node *x, Node *&out) {
        if (!x->l) { out = x; return x->r; }
        x->l = popMin(x->l, out); x->resize();
        return balance(x);
    }

    // Builds a balanced tree from a, the single node mid and b, where
    // values(a) < mid->val < values(b). Descends the taller side until the
    // heights differ by at most one, links mid there and rebalances upwards.
    Node* join(Node *a, Node *mid, Node *b) {
        if (height(a) > height(b) + 1) {
            a->r = join(a->r, mid, b); a->resize();
            return balance(a);
        }
        if (height(b) > height(a) + 1) {
            b->l = join(a, mid, b->l); b->resize();
            return balance(b);
        }
        mid->l = a; mid->r = b; mid->resize();
        return mid;
    }

    // Returns a pointer to the stored value equal to v, or NULL if absent.
    // The pointer stays valid until that value is erased or the set is destroyed.
    const T* find(const T &v) const {
        Node *x = root;
        while (x) {
            if (x->val == v) return &x->val;
            x = x->ch(x->val < v);
        }
        return NULL;
    }

    // Returns the n-th smallest value (0-based). For n outside [0, size())
    // a default-constructed T is returned.
    const T at(int n) const {
        Node *x = root;
        while (x) {
            if (size(x->l) == n) return x->val;
            if (size(x->l) > n) x = x->l;
            else { n -= size(x->l)+1; x = x->r; }
        }
        return T();
    }

    // Returns a deep copy of the subtree rooted at x (NULL for an empty one).
    static Node* clone(const Node *x) {
        if (!x) return NULL;
        Node *c = new Node(x->val);
        c->s = x->s; c->h = x->h;
        c->l = clone(x->l);
        c->r = clone(x->r);
        return c;
    }

    // Frees every node of the subtree rooted at x.
    static void destroy(Node *x) {
        if (!x) return;
        destroy(x->l);
        destroy(x->r);
        delete x;
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment