Skip to content

Instantly share code, notes, and snippets.

@justiceHui
Created January 7, 2024 09:56
Show Gist options
  • Save justiceHui/a8406518891ca0dba9c7518f8224378b to your computer and use it in GitHub Desktop.
Save justiceHui/a8406518891ca0dba9c7518f8224378b to your computer and use it in GitHub Desktop.
#include <bits/stdc++.h>
using namespace std;
/**
 * AVL tree: a self-balancing binary search tree holding unique keys of type T.
 *
 * Keys are compared with `operator<` only (a and b are treated as equal when neither a < b
 * nor b < a). Every node caches its subtree height, balance factor and subtree size, so
 * size() is O(1) and insert / erase / find walk a single root-to-leaf path.
 *
 * The tree owns its nodes: copying deep-copies them, moving transfers them, and the
 * destructor frees them.
 */
template<typename T>
class avl_tree{
    struct avl_node{
        avl_node *l, *r, *p;  // left child, right child, parent (nullptr for the root)
        int dep;              // height of this subtree; a single node has height 0
        int bf;               // balance factor: height(l) - height(r)
        int sz;               // number of nodes in this subtree
        T key;
        explicit avl_node(T key) : l(nullptr), r(nullptr), p(nullptr), dep(0), bf(0), sz(1), key(std::move(key)) {}
        // Hidden friends (found via ADL) so that a null child can be queried uniformly:
        // an empty subtree has height -1 and size 0.
        friend int get_depth(avl_node *x){ return x ? x->dep : -1; }
        friend int get_size(avl_node *x){ return x ? x->sz : 0; }
        // Precondition: p != nullptr.
        bool is_left() const { return this == p->l; }
        bool is_root() const { return !p; }
        // Recomputes the cached fields from the children, whose own fields must be current.
        void update(){
            dep = std::max(get_depth(l), get_depth(r)) + 1;
            bf = get_depth(l) - get_depth(r);
            sz = 1 + get_size(l) + get_size(r);
        }
    };
    using node_ptr = avl_node*;
    node_ptr root;

    /// Frees every node of the subtree rooted at x (x may be null).
    /// Recursion depth is bounded by the tree height, which is O(log n) for an AVL tree.
    static void destroy(node_ptr x){
        if(!x) return;
        destroy(x->l);
        destroy(x->r);
        delete x;
    }

    /// Deep-copies the subtree rooted at src, attaching the copy under `parent`.
    /// Preserves shape and cached fields. On an exception, the partial copy is freed and
    /// the exception is rethrown.
    static node_ptr clone(node_ptr src, node_ptr parent){
        if(!src) return nullptr;
        node_ptr dst = new avl_node(src->key);
        dst->dep = src->dep; dst->bf = src->bf; dst->sz = src->sz;
        dst->p = parent;
        try{
            dst->l = clone(src->l, dst);
            dst->r = clone(src->r, dst);
        }
        catch(...){
            destroy(dst);
            throw;
        }
        return dst;
    }

    /// Rotates x above its parent (a single left or right rotation), keeping the in-order
    /// key sequence. Refreshes the cached fields of both nodes. Precondition: x has a parent.
    void rotate(node_ptr x){
        node_ptr par = x->p;
        if(x->is_left()){
            // x's right subtree becomes par's left subtree; par becomes x's right child.
            par->l = x->r;
            if(x->r) x->r->p = par;
            x->r = par;
        }
        else{
            // Mirror image: x's left subtree becomes par's right subtree.
            par->r = x->l;
            if(x->l) x->l->p = par;
            x->l = par;
        }
        // Hook x into the slot par used to occupy (under the grandparent, or at the root).
        if(par->is_root()) root = x;
        else if(par->is_left()) par->p->l = x;
        else par->p->r = x;
        x->p = par->p;
        par->p = x;
        par->update();  // par is now x's child, so it must be refreshed before x
        x->update();
    }

    /// Walks from x up to the root, refreshing cached fields and rotating wherever the
    /// balance factor has reached +-2.
    void iterative_balance(node_ptr x){
        while(x != nullptr){
            x->update();
            node_ptr y = nullptr;  // node to rotate upward
            int cnt = 0;           // 1 = single rotation, 2 = double rotation
            if(x->bf == 2){        // left side too tall
                if(x->l->bf != -1) y = x->l, cnt = 1;  // left child not right-heavy: single rotation
                else y = x->l->r, cnt = 2;             // left child right-heavy: double rotation
            }
            else if(x->bf == -2){  // right side too tall (mirror image)
                if(x->r->bf != 1) y = x->r, cnt = 1;
                else y = x->r->l, cnt = 2;
            }
            x = x->p;  // next ancestor, captured before the rotations restructure this level
            for(int i=0; i<cnt; i++) rotate(y);
        }
    }

    /// Inserts key if absent. Returns true on insertion, false if an equal key exists.
    bool internal_insert(const T &key){
        if(!root){ root = new avl_node(key); return true; }
        node_ptr x = root;
        node_ptr *slot = nullptr;  // the null child pointer of x where the new node will hang
        while(true){
            if(key < x->key){
                if(x->l) x = x->l;
                else{ slot = &x->l; break; }
            }
            else if(x->key < key){
                if(x->r) x = x->r;
                else{ slot = &x->r; break; }
            }
            else return false;  // equal key already present
        }
        *slot = new avl_node(key);
        (*slot)->p = x;
        iterative_balance(x);
        return true;
    }

    /// Returns the node holding a key equal to `key`, or nullptr if there is none.
    avl_node* internal_find(const T &key) const{
        node_ptr x = root;
        while(x){
            if(key < x->key) x = x->l;
            else if(x->key < key) x = x->r;
            else return x;
        }
        return nullptr;
    }

    /// Removes the key equal to `key`. Returns true if it was present, false otherwise.
    bool internal_erase(const T &key){
        node_ptr x = internal_find(key);
        if(!x) return false;
        if(x->l && x->r){
            // Two children: move the in-order successor's key into x, then remove the
            // successor node instead. The successor has no left child.
            node_ptr y = x->r;
            while(y->l) y = y->l;
            using std::swap;
            swap(x->key, y->key);
            x = y;
        }
        // x now has at most one child: splice it out of the tree.
        node_ptr child = x->l ? x->l : x->r;
        node_ptr par = x->p;
        if(child) child->p = par;
        if(!par) root = child;
        else if(par->l == x) par->l = child;
        else par->r = child;
        delete x;
        iterative_balance(par);
        return true;
    }

    /// Validates the subtree rooted at x (non-null). Checks BST ordering between parent and
    /// children, parent links, cached height / balance / size fields, and |balance| <= 1.
    bool assertion(node_ptr x) const{
        // check recursive
        if(x->l && !assertion(x->l)) return false;
        if(x->r && !assertion(x->r)) return false;
        // check binary search tree structure
        if(x->l && !(x->l->key < x->key)) return false;
        if(x->r && !(x->key < x->r->key)) return false;
        // check edge link
        if(x->l && x->l->p != x) return false;
        if(x->r && x->r->p != x) return false;
        // from now, children have correct depth and balance factor
        if(x->dep != std::max(get_depth(x->l), get_depth(x->r)) + 1) return false;
        if(x->bf != get_depth(x->l) - get_depth(x->r)) return false;
        if(x->sz != 1 + get_size(x->l) + get_size(x->r)) return false;
        // from now, x has correct depth and balance factor
        if(std::abs(x->bf) > 1) return false;
        // nice avl tree!
        return true;
    }

    /// Prints "key : left right" for each node in pre-order, using -1 for a missing child.
    /// Works for any T that can be streamed to std::ostream.
    void print(node_ptr x) const{
        std::cout << x->key << " : ";
        if(x->l) std::cout << x->l->key; else std::cout << -1;
        std::cout << " ";
        if(x->r) std::cout << x->r->key; else std::cout << -1;
        std::cout << "\n";
        if(x->l) print(x->l);
        if(x->r) print(x->r);
    }
public:
    avl_tree() : root(nullptr) {}
    /// Deep copy: the new tree owns independent copies of all nodes.
    avl_tree(const avl_tree &other) : root(clone(other.root, nullptr)) {}
    /// Move: steals the nodes and leaves `other` empty.
    avl_tree(avl_tree &&other) noexcept : root(other.root) { other.root = nullptr; }
    /// Copy- and move-assignment (copy-and-swap). The old nodes are freed when `other` dies.
    avl_tree& operator=(avl_tree other) noexcept{
        std::swap(root, other.root);
        return *this;
    }
    ~avl_tree(){ destroy(root); }

    /// Inserts key; returns false if an equal key was already present.
    bool insert(const T &key){ return internal_insert(key); }
    /// Erases key; returns false if no equal key was present.
    bool erase(const T &key){ return internal_erase(key); }
    /// Returns whether an equal key is stored.
    bool find(const T &key) const { return internal_find(key) != nullptr; }
    /// Number of stored keys.
    int size() const { return get_size(root); }
    /// Returns true when the whole tree satisfies the BST, link, cache and AVL invariants.
    bool assertion() const { return !root || (root->p == nullptr && assertion(root)); }
    /// Dumps the tree to std::cout (see the private print), followed by a blank line.
    void print() const{
        if(root) print(root);
        std::cout << "\n";
    }
};
bool self_stress_test(int n, unsigned seed, bool check_balance){
mt19937 gen(seed);
avl_tree<int> T; int S = 0;
vector<int> V(n*4), C(n);
for(int i=0; i<V.size(); i++) V[i] = i / 4;
shuffle(V.begin(), V.end(), gen);
for(int i=0; i<V.size(); i++){
bool op_ins = C[V[i]] < 2, op_del = C[V[i]] >= 2, op_flag = C[V[i]] % 2 == 0;
bool pre_exist = C[V[i]] == 1 || C[V[i]] == 2, post_exist = C[V[i]] < 2;
C[V[i]] += 1;
if(T.find(V[i]) != pre_exist) return false;
if(op_ins && T.insert(V[i]) != op_flag) return false;
if(op_del && T.erase(V[i]) != op_flag) return false;
if(T.find(V[i]) != post_exist) return false;
if(op_flag) S += op_ins ? 1 : -1;
if(S != T.size()) return false;
if(check_balance && !T.assertion()) return false;
}
return true;
}
bool stress_with_std_set(int n, unsigned seed, bool check_balance){ // O(n log n)
mt19937 gen(seed);
uniform_int_distribution<int> rnd(1, 50);
set<int> T1; avl_tree<int> T2;
for(int i=0; i<n; i++){
int now = rnd(gen);
bool pre_exist = T1.count(now), post_exist = !pre_exist;
bool op_ins = post_exist, op_del = pre_exist, op_flag = true;
if(T2.find(now) != pre_exist) return false;
if(op_ins){
T1.insert(now);
if(T2.insert(now) != op_flag) return false;
}
if(op_del){
T1.erase(now);
if(T2.erase(now) != op_flag) return false;
}
if(T2.find(now) != post_exist) return false;
if(T1.size() != T2.size()) return false;
if(check_balance && !T2.assertion()) return false;
}
return true;
}
// Runs every stress test and reports each failure on stderr. Returns EXIT_FAILURE if any
// test fails. The tests are deliberately NOT wrapped in assert(): under NDEBUG the assert
// expression is not evaluated, so the tests would silently never run.
int main(){
    std::ios_base::sync_with_stdio(false); std::cin.tie(nullptr);
    const unsigned seed = 0x917917;
    struct test_result{ const char *name; bool passed; };
    // Elements of a braced initializer list are evaluated in order, so the tests run in
    // the listed sequence.
    const test_result results[] = {
        {"self_stress_test(5000, check_balance)", self_stress_test(5000, seed, true)},
        {"self_stress_test(500000)", self_stress_test(500000, seed, false)},
        {"stress_with_std_set(5000, check_balance)", stress_with_std_set(5000, seed, true)},
        {"stress_with_std_set(500000)", stress_with_std_set(500000, seed, false)},
    };
    bool all_passed = true;
    for(const test_result &r : results){
        if(!r.passed){
            std::cerr << "FAILED: " << r.name << "\n";
            all_passed = false;
        }
    }
    return all_passed ? EXIT_SUCCESS : EXIT_FAILURE;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment