// Source: GitHub gist justiceHui/a8406518891ca0dba9c7518f8224378b (created January 7, 2024).
// Note from the GitHub page: this file may contain bidirectional Unicode text that could be
// interpreted or compiled differently than it appears; review it in an editor that reveals
// hidden Unicode characters.
#include <bits/stdc++.h>
using namespace std;
/// Ordered set of unique keys of type T, implemented as an AVL tree with
/// parent pointers and cached subtree height / balance factor / size.
///
/// Keys are compared with `operator<` only. insert/erase/find walk a single
/// root-to-leaf path; insert/erase then refresh and rebalance every ancestor
/// up to the root.
///
/// The tree owns its nodes: copying makes an independent deep copy, moving
/// transfers the nodes, and the destructor frees them all.
template<typename T>
class avl_tree{
    struct avl_node{
        avl_node *l, *r, *p;  // left child, right child, parent (nullptr at the root)
        int dep, bf, sz;      // height (leaf = 0), height(l) - height(r), subtree size
        T key;
        explicit avl_node(const T &key) : l(nullptr), r(nullptr), p(nullptr), dep(0), bf(0), sz(1), key(key) {}
        // Height / size of a possibly-empty subtree (an empty subtree has height -1).
        friend int get_depth(avl_node *x){ return x ? x->dep : -1; }
        friend int get_size(avl_node *x){ return x ? x->sz : 0; }
        // Requires a parent (i.e. !is_root()).
        bool is_left() const { return this == p->l; }
        bool is_root() const { return !p; }
        // Recompute the cached fields from the children, which must already be up to date.
        void update(){
            dep = std::max({-1, get_depth(l), get_depth(r)}) + 1;
            bf = get_depth(l) - get_depth(r);
            sz = 1 + get_size(l) + get_size(r);
        }
    };
    using node_ptr = avl_node*;
    node_ptr root;

    // Frees every node of the subtree rooted at x (x may be null).
    // Recursion depth equals the subtree height, which the AVL balance keeps small.
    static void destroy(node_ptr x){
        if(!x) return;
        destroy(x->l);
        destroy(x->r);
        delete x;
    }

    // Returns a deep copy of the subtree rooted at src, attached to `parent`.
    // If an allocation or key copy throws, everything built so far is freed
    // before the exception propagates.
    static node_ptr clone(node_ptr src, node_ptr parent){
        if(!src) return nullptr;
        node_ptr x = new avl_node(src->key);
        x->p = parent;
        x->dep = src->dep; x->bf = src->bf; x->sz = src->sz;
        try{
            x->l = clone(src->l, x);
            x->r = clone(src->r, x);
        }
        catch(...){
            destroy(x);
            throw;
        }
        return x;
    }

    // Single rotation lifting x above its parent while preserving in-order
    // sequence. Fixes all affected links (including `root`) and refreshes the
    // cached fields of the two nodes involved. Requires x to have a parent.
    void rotate(node_ptr x){
        node_ptr par = x->p;
        if(x->is_left()){
            par->l = x->r;             // x's inner subtree moves under par
            if(x->r) x->r->p = par;
            x->r = par;
        }
        else{
            par->r = x->l;
            if(x->l) x->l->p = par;
            x->l = par;
        }
        if(par->is_root()) root = x;
        else (par->is_left() ? par->p->l : par->p->r) = x;
        x->p = par->p;
        par->p = x;
        par->update();
        x->update();
    }

    // Walks from x up to the root, refreshing cached fields and restoring
    // |bf| <= 1 at each node with a single (LL/RR) or double (LR/RL) rotation.
    void iterative_balance(node_ptr x){
        while(x){
            x->update();
            node_ptr y = nullptr;  // node to lift
            int cnt = 0;           // how many times to rotate it
            if(x->bf == 2){
                if(x->l->bf != -1) y = x->l, cnt = 1;
                else y = x->l->r, cnt = 2;
            }
            else if(x->bf == -2){
                if(x->r->bf != 1) y = x->r, cnt = 1;
                else y = x->r->l, cnt = 2;
            }
            x = x->p;  // taken before rotating: the rotations keep this ancestor in place
            for(int i=0; i<cnt; i++) rotate(y);
        }
    }

    // Inserts key as a new leaf unless an equivalent key exists; rebalances upward.
    bool internal_insert(const T &key){
        if(!root){ root = new avl_node(key); return true; }
        node_ptr x = root;
        node_ptr *slot = nullptr;  // the null child pointer that will receive the new leaf
        while(true){
            if(key < x->key){
                if(x->l) x = x->l;
                else{ slot = &x->l; break; }
            }
            else if(x->key < key){
                if(x->r) x = x->r;
                else{ slot = &x->r; break; }
            }
            else return false;  // equivalent key already stored
        }
        *slot = new avl_node(key);
        (*slot)->p = x;
        iterative_balance(x);
        return true;
    }

    // Returns the node holding a key equivalent to `key`, or nullptr.
    node_ptr internal_find(const T &key) const{
        node_ptr x = root;
        while(x){
            if(key < x->key) x = x->l;
            else if(x->key < key) x = x->r;
            else return x;
        }
        return nullptr;
    }

    // Removes the key if present and rebalances from the removed node's parent.
    bool internal_erase(const T &key){
        node_ptr x = internal_find(key);
        if(!x) return false;
        if(x->l && x->r){
            // Two children: swap the key with the in-order successor (leftmost
            // node of the right subtree, which has no left child) and remove
            // that node instead.
            node_ptr y = x->r;
            while(y->l) y = y->l;
            using std::swap;
            swap(x->key, y->key);
            x = y;
        }
        // x now has at most one child: splice it out of the tree.
        node_ptr child = x->l ? x->l : x->r;
        if(child) child->p = x->p;
        if(x->is_root()) root = child;
        else (x->is_left() ? x->p->l : x->p->r) = child;
        node_ptr t = x->p;
        delete x;
        iterative_balance(t);
        return true;
    }

    // Recursively verifies the subtree rooted at x: strict BST order against
    // the open bounds (lo, hi) inherited from all ancestors (nullptr means
    // unbounded), parent links, cached dep/bf/sz, and the AVL condition.
    bool assertion(node_ptr x, const T *lo = nullptr, const T *hi = nullptr) const{
        // every key must respect all ancestors, not just its direct parent
        if(lo && !(*lo < x->key)) return false;
        if(hi && !(x->key < *hi)) return false;
        // check edge links, then recurse so children are validated first
        if(x->l && (x->l->p != x || !assertion(x->l, lo, &x->key))) return false;
        if(x->r && (x->r->p != x || !assertion(x->r, &x->key, hi))) return false;
        // from now on, children have correct depth, balance factor and size
        if(x->dep != std::max({-1, get_depth(x->l), get_depth(x->r)}) + 1) return false;
        if(x->bf != get_depth(x->l) - get_depth(x->r)) return false;
        if(x->sz != 1 + get_size(x->l) + get_size(x->r)) return false;
        // AVL balance condition
        return std::abs(x->bf) <= 1;
    }

    // Debug dump in pre-order, one "key : left right" line per node, -1 for a
    // missing child. Only compiles for key types usable with the int -1 in ?:.
    void print(node_ptr x) const{
        std::cout << x->key << " : " << (x->l ? x->l->key : -1) << " " << (x->r ? x->r->key : -1) << "\n";
        if(x->l) print(x->l);
        if(x->r) print(x->r);
    }
public:
    avl_tree() : root(nullptr) {}
    avl_tree(const avl_tree &other) : root(clone(other.root, nullptr)) {}
    avl_tree(avl_tree &&other) noexcept : root(std::exchange(other.root, nullptr)) {}
    // Copy-and-swap: serves as both copy and move assignment; self-assignment safe.
    avl_tree& operator=(avl_tree other) noexcept{
        std::swap(root, other.root);
        return *this;
    }
    ~avl_tree(){ destroy(root); }

    // Returns true iff the key was absent (and has now been inserted).
    bool insert(const T &key){ return internal_insert(key); }
    // Returns true iff the key was present (and has now been removed).
    bool erase(const T &key){ return internal_erase(key); }
    bool find(const T &key) const { return internal_find(key) != nullptr; }
    int size() const { return get_size(root); }
    // Full structural self-check over every node; intended for tests.
    bool assertion() const { return !root || (root->is_root() && assertion(root)); }
    void print() const{
        if(root) print(root);
        std::cout << "\n";
    }
};
bool self_stress_test(int n, unsigned seed, bool check_balance){ | |
mt19937 gen(seed); | |
avl_tree<int> T; int S = 0; | |
vector<int> V(n*4), C(n); | |
for(int i=0; i<V.size(); i++) V[i] = i / 4; | |
shuffle(V.begin(), V.end(), gen); | |
for(int i=0; i<V.size(); i++){ | |
bool op_ins = C[V[i]] < 2, op_del = C[V[i]] >= 2, op_flag = C[V[i]] % 2 == 0; | |
bool pre_exist = C[V[i]] == 1 || C[V[i]] == 2, post_exist = C[V[i]] < 2; | |
C[V[i]] += 1; | |
if(T.find(V[i]) != pre_exist) return false; | |
if(op_ins && T.insert(V[i]) != op_flag) return false; | |
if(op_del && T.erase(V[i]) != op_flag) return false; | |
if(T.find(V[i]) != post_exist) return false; | |
if(op_flag) S += op_ins ? 1 : -1; | |
if(S != T.size()) return false; | |
if(check_balance && !T.assertion()) return false; | |
} | |
return true; | |
} | |
bool stress_with_std_set(int n, unsigned seed, bool check_balance){ // O(n log n) | |
mt19937 gen(seed); | |
uniform_int_distribution<int> rnd(1, 50); | |
set<int> T1; avl_tree<int> T2; | |
for(int i=0; i<n; i++){ | |
int now = rnd(gen); | |
bool pre_exist = T1.count(now), post_exist = !pre_exist; | |
bool op_ins = post_exist, op_del = pre_exist, op_flag = true; | |
if(T2.find(now) != pre_exist) return false; | |
if(op_ins){ | |
T1.insert(now); | |
if(T2.insert(now) != op_flag) return false; | |
} | |
if(op_del){ | |
T1.erase(now); | |
if(T2.erase(now) != op_flag) return false; | |
} | |
if(T2.find(now) != post_exist) return false; | |
if(T1.size() != T2.size()) return false; | |
if(check_balance && !T2.assertion()) return false; | |
} | |
return true; | |
} | |
int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    // The stress tests run as ordinary statements, not inside assert():
    // with -DNDEBUG an assert's argument is never evaluated, so the tests would
    // silently be skipped and the program would still exit successfully.
    const unsigned seed = 0x917917;
    auto report = [](const char *name, bool ok){
        if(!ok) cerr << name << " failed\n";
        return ok;
    };
    // && stops at the first failure, just as the original chain of asserts did.
    const bool ok =
        report("self_stress_test(5000, check_balance=true)", self_stress_test(5000, seed, true)) &&
        report("self_stress_test(500000, check_balance=false)", self_stress_test(500000, seed, false)) &&
        report("stress_with_std_set(5000, check_balance=true)", stress_with_std_set(5000, seed, true)) &&
        report("stress_with_std_set(500000, check_balance=false)", stress_with_std_set(500000, seed, false));
    return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}
// End of the gist page (commenting on the gist requires signing in to GitHub).