Skip to content

Instantly share code, notes, and snippets.

@Fugoes
Last active April 15, 2019 00:39
Show Gist options
  • Save Fugoes/69ebbe1b75c4ec96b914f72bc5e6d104 to your computer and use it in GitHub Desktop.
Save Fugoes/69ebbe1b75c4ec96b914f72bc5e6d104 to your computer and use it in GitHub Desktop.
AVL Tree
#ifndef CPP_AVL_HPP
#define CPP_AVL_HPP
#include <cstddef>
#include <cstdint>
#include <utility>
/// One node of an avl_tree.
///
/// cs_[0] / cs_[1] are the left / right children, p_ is the parent (nullptr
/// for the root), c_idx_ records which child slot of p_ this node occupies,
/// and bf_ is the balance factor that avl_tree maintains as
/// height(right) - height(left).
template<typename K>
struct avl_node {
  using node_t = avl_node<K>;
  node_t *cs_[2]{nullptr, nullptr};
  node_t *p_{nullptr};
  int8_t bf_{0};
  uint8_t c_idx_{0};
  K k_;

  /// Takes the key by value and moves it into place, so an rvalue argument
  /// is never copied.
  explicit avl_node(K k) : k_(std::move(k)) {}

  /// Stores `child` in slot `c_idx` (0 = left, 1 = right). When `child` is
  /// non-null its parent link and slot index are updated to point here; a
  /// null `child` simply clears the slot.
  void set_child(int8_t c_idx, avl_node *child) {
    cs_[c_idx] = child;
    if (child != nullptr) {
      child->p_ = this;
      child->c_idx_ = c_idx;
    }
  }

  /// Same as set_child, but the caller guarantees `child` is non-null, so
  /// the null check is skipped.
  void set_child_nonnull(int8_t c_idx, avl_node *child) {
    cs_[c_idx] = child;
    child->p_ = this;
    child->c_idx_ = c_idx;
  }
};
template<typename K, typename K_OP>
struct avl_tree {
using node_t = avl_node<K>;
node_t *root_;
size_t n_{0};
avl_tree() : root_(nullptr) {}
node_t *begin() const {
node_t *prev = nullptr;
for (node_t *node = root_; node != nullptr; node = node->cs_[0]) {
prev = node;
}
return prev;
}
node_t *end() const { return nullptr; }
node_t *rbegin() const {
node_t *prev = nullptr;
for (node_t *node = root_; node != nullptr; node = node->cs_[1]) {
prev = node;
}
return prev;
}
node_t *rend() const { return nullptr; }
node_t *next(node_t *node) const {
if (node->cs_[1] != nullptr) {
node = node->cs_[1];
while (node->cs_[0] != nullptr) node = node->cs_[0];
} else {
for (;;) {
uint8_t c_idx = node->c_idx_;
node = node->p_;
if (node == nullptr) return nullptr;
if (c_idx == 0) break;
}
}
return node;
}
node_t *prev(node_t *node) const {
if (node->cs_[0] != nullptr) {
node = node->cs_[0];
while (node->cs_[1] != nullptr) node = node->cs_[1];
} else {
for (;;) {
uint8_t c_idx = node->c_idx_;
node = node->p_;
if (node == nullptr) return nullptr;
if (c_idx == 1) break;
}
}
return node;
}
unsigned int rotate(node_t *node, node_t *parent, unsigned int reason) {
// return 0 if h(node) does not change
// return 1 if h(node) decrease by 1
int l = reason;
int r = 1 - l;
auto *P = node;
auto *C = P->cs_[l];
auto *N = C->cs_[r];
if (C->bf_ != 1 - 2 * l) {
/* P
* +-----+-----+
* C
* +--+--+
* N
* becomes:
* C
* +-----+-----+
* P
* +--+--+
* N
*
* 1) l(0), P(-2), C( 0) -> P(-1), C(+1), return 0
* 2) l(0), P(-2), C(-1) -> P( 0), C( 0), return 1
* 1) l(1), P(+2), C( 0) -> P(+1), C(-1), return 0
* 2) l(1), P(+2), C(+1) -> P( 0), C( 0), return 1
*/
if (parent == nullptr) root_ = C;
else parent->cs_[P->c_idx_] = C;
C->p_ = parent;
C->c_idx_ = P->c_idx_;
C->set_child_nonnull(r, P);
P->set_child(l, N);
C->bf_ += 1 - 2 * l;
P->bf_ = -C->bf_;
return C->bf_ == 0;
} else {
/* P
* +-----+-----+
* C
* +--+--+
* N
* +-+-+
* Nl Nr
* becomes:
* N
* +-----+-----+
* C P
* +--+--+ +--+--+
* Nl Nr
*
* 1) l(0), P(-2), C(+1), N(-1) -> P(+1), C( 0), N( 0), return 1
* 2) l(0), P(-2), C(+1), N( 0) -> P( 0), C( 0), N( 0), return 1
* 3) l(0), P(-2), C(+1), N(+1) -> P( 0), C(-1), N( 0), return 1
* 4) l(1), P(+2), C(-1), N(+1) -> P(-1), C( 0), N( 0), return 1
* 5) l(1), P(+2), C(-1), N( 0) -> P( 0), C( 0), N( 0), return 1
* 6) l(1), P(+2), C(-1), N(-1) -> P( 0), C(+1), N( 0), return 1
*/
auto *Nl = N->cs_[l];
auto *Nr = N->cs_[r];
if (parent == nullptr) root_ = N;
else parent->cs_[P->c_idx_] = N;
N->p_ = parent;
N->c_idx_ = P->c_idx_;
N->set_child_nonnull(l, C);
N->set_child_nonnull(r, P);
C->set_child(r, Nl);
P->set_child(l, Nr);
P->bf_ = (N->bf_ == 2 * l - 1) ? -N->bf_ : 0;
C->bf_ = (N->bf_ == 1 - 2 * l) ? -N->bf_ : 0;
N->bf_ = 0;
return 1;
}
}
avl_node<K> *find(const K k) {
if (root_ == nullptr) return nullptr;
auto *cursor = root_;
for (;;) {
if (K_OP::eq(k, cursor->k_)) return cursor;
int l = !K_OP::le(k, cursor->k_);
if (cursor->cs_[l] != nullptr) {
cursor = cursor->cs_[l];
} else {
return cursor;
}
}
}
void insert(const K k) {
n_++;
auto *node = new node_t{k};
if (root_ == nullptr) {
root_ = node;
} else {
auto *cursor = root_;
for (;;) {
int l = !K_OP::le(k, cursor->k_);
if (cursor->cs_[l] != nullptr) {
cursor = cursor->cs_[l];
} else {
cursor->cs_[l] = node;
node->p_ = cursor;
node->c_idx_ = l;
cursor->bf_ += 2 * l - 1;
break;
}
}
while (!(cursor->bf_ == 0 || cursor->p_ == nullptr)) {
if (cursor->p_->bf_ == 2 * cursor->c_idx_ - 1) {
rotate(cursor->p_, cursor->p_->p_, cursor->c_idx_);
break;
}
cursor->p_->bf_ += 2 * cursor->c_idx_ - 1;
cursor = cursor->p_;
}
}
}
void erase(K k) {
n_--;
auto *target = root_;
while (!K_OP::eq(k, target->k_)) target = target->cs_[!K_OP::le(k, target->k_)];
node_t *to_delete = target;
unsigned int l = (target->bf_ + 1) / 2;
if (target->cs_[0] != nullptr && target->cs_[1] != nullptr) {
int r = 1 - l;
to_delete = target->cs_[l];
while (to_delete->cs_[r] != nullptr) to_delete = to_delete->cs_[r];
}
target->k_ = std::move(to_delete->k_);
if (to_delete->p_ == nullptr) {
root_ = to_delete->cs_[l];
if (root_ != nullptr) root_->p_ = nullptr;
delete to_delete;
return;
}
auto *cursor = to_delete->p_;
auto c_idx = to_delete->c_idx_;
cursor->set_child(c_idx, to_delete->cs_[l]);
delete to_delete;
l = !c_idx;
while (cursor) {
auto *p = cursor->p_;
auto idx = cursor->c_idx_;
if (cursor->bf_ == 2 * l - 1) {
if (!rotate(cursor, p, l)) return;
} else {
cursor->bf_ += 2 * l - 1;
if (cursor->bf_ != 0) return;
}
cursor = p;
l = !idx;
}
}
};
#endif //CPP_AVL_HPP
#include <iostream>
#include "../include/avl.hpp"
/// Key operations for avl_tree over int32_t keys.
///   eq(a, b) -- true when the keys are equal;
///   le(a, b) -- ordering predicate used to descend left; despite its name it
///               is a strict less-than (a < b).
struct int32_op {
  static bool eq(int32_t lhs, int32_t rhs) { return !(lhs != rhs); }
  static bool le(int32_t lhs, int32_t rhs) { return rhs > lhs; }
};
// Smoke test for avl_tree: builds a tree of 1..100, prints it forward,
// erases 1, prints it backward, then reports (1/0) whether each of 1..101
// is found.
int main() {
  avl_tree<int32_t, int32_op> tree;
  for (int32_t i = 100; i > 0; i--) {
    tree.insert(i);
  }
  // In-order traversal: prints 1..100 ascending. '\n' instead of std::endl
  // avoids flushing the stream on every line.
  for (auto *it = tree.begin(); it != tree.end(); it = tree.next(it)) {
    std::cout << it->k_ << '\n';
  }
  tree.erase(1);
  // Reverse traversal: prints 100..2 descending.
  for (auto *it = tree.rbegin(); it != tree.rend(); it = tree.prev(it)) {
    std::cout << it->k_ << '\n';
  }
  // find() returns the last visited node when the key is absent (and nullptr
  // only on an empty tree), so check for null and compare the key.
  for (int32_t i = 1; i <= 101; i++) {
    const auto *x = tree.find(i);
    std::cout << (x != nullptr && x->k_ == i) << '\n';
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment