Skip to content

Instantly share code, notes, and snippets.

@davide99
Created August 29, 2020 19:42
Show Gist options
  • Save davide99/872265cd3a8be9236a6fd3e58759e900 to your computer and use it in GitHub Desktop.
Save davide99/872265cd3a8be9236a6fd3e58759e900 to your computer and use it in GitHub Desktop.
#ifndef BST_H
#define BST_H
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
//https://gist.github.com/phoemur/6dd18d608438373185f6a2457662c1c2
// Maximum height difference tolerated between the two subtrees of any node
// before balance() performs a rotation.
static constexpr int ALLOWED_IMBALANCE = 1;

/**
 * AVL tree mapping a value to the index of the interval that contains it.
 *
 * Built from breakpoints b0 < b1 < ... < bn, it stores the half-open
 * intervals (b0, b1], (b1, b2], ..., (b(n-1), bn] with indices 0 .. n-1.
 * search(x) returns the index i such that b_i < x <= b_(i+1), or -1 when x
 * lies outside every interval.
 *
 * T must be constructible from the iterator's value type, copy-constructible,
 * move-assignable, and support <, <= and >=.
 */
template<typename T>
class AVLTree {
public:
    /// Half-open interval (left, right] tagged with its position in the input.
    struct Interval {
        T right, left;
        std::int32_t index;

        template<typename X = T>
        Interval(X &&l, X &&r, std::int32_t id)
            : right{std::forward<X>(r)}, left{std::forward<X>(l)}, index{id} {}

        /// True when this interval lies entirely before `other`. The intervals
        /// are half-open, so adjacent ones such as (a, b] and (b, c] do not
        /// overlap and must compare as ordered -- hence `<=`, not `<`.
        bool operator<=(const Interval &other) const {
            return this->right <= other.left;
        }

        /// True when `val` lies before this interval (val <= left).
        bool operator<<(const T &val) const {
            return val <= this->left;
        }

        /// True when `val` lies after this interval (val > right).
        bool operator>>(const T &val) const {
            return this->right < val;
        }
    };

private:
    struct Node {
        Interval interval;
        std::unique_ptr<Node> left;
        std::unique_ptr<Node> right;
        std::int32_t height;

        // X is deduced as Interval (rvalue) or Interval& (lvalue). Forwarding
        // with X moves only from rvalues, so clone() copies the source node's
        // interval instead of stealing it.
        template<typename X>
        Node(X &&ele, std::unique_ptr<Node> &&lt, std::unique_ptr<Node> &&rt, std::int32_t h = 0)
            : interval{std::forward<X>(ele)}, left{std::move(lt)}, right{std::move(rt)}, height{h} {}
    };

    std::unique_ptr<Node> root;
    std::int32_t sz;

public:
    /**
     * Builds the tree from the breakpoints in [first, last).
     *
     * Consecutive breakpoints (b_i, b_(i+1)) become interval i. Fewer than two
     * breakpoints (including an empty range) yield an empty tree. The range is
     * walked once with ++ only, so any input iterator is accepted.
     *
     * @throws std::logic_error if a breakpoint is not strictly greater than
     *         the previous one.
     */
    template<typename Iter>
    AVLTree(Iter first, Iter last) : root{nullptr}, sz{0} {
        using c_tp = typename std::iterator_traits<Iter>::value_type;
        static_assert(std::is_constructible<T, c_tp>::value, "Type mismatch");

        if (first == last)
            return;

        T left(*first);
        std::int32_t i = 0;
        for (++first; first != last; ++first, ++i) {
            T right(*first);
            if (left >= right)
                throw std::logic_error("Left boundary can't be greater than or equal to the right one");
            insert(Interval(left, right, i));
            // The right boundary of this interval is the left one of the next.
            left = std::move(right);
        }
    }

    /// Deep copy: every node of `other` is cloned.
    AVLTree(const AVLTree &other) : root{clone(other.root)}, sz{other.sz} {}

    AVLTree(AVLTree &&) noexcept = default;

    AVLTree &operator=(const AVLTree &other) {
        if (this != &other) {
            // Clone first so a throwing copy leaves *this unchanged.
            auto copy = clone(other.root);
            root = std::move(copy);
            sz = other.sz;
        }
        return *this;
    }

    AVLTree &operator=(AVLTree &&) noexcept = default;

    ~AVLTree() noexcept = default;

    /// Number of intervals stored in the tree.
    std::int32_t size() const noexcept {
        return sz;
    }

    /// Returns the index of the interval (left, right] containing x, or -1.
    std::int32_t search(const T &x) const noexcept {
        return search(x, root);
    }

private:
    /// Inserts an interval; sz counts only intervals that were actually stored.
    void insert(Interval &&first) {
        if (insert_util(std::move(first), root))
            ++sz;
    }

    /// Recursively deep-copies the subtree rooted at `node`, keeping heights.
    std::unique_ptr<Node> clone(const std::unique_ptr<Node> &node) const {
        if (!node)
            return nullptr;
        return std::make_unique<Node>(node->interval, clone(node->left), clone(node->right), node->height);
    }

    /// Height of a subtree; an empty subtree has height -1.
    std::int32_t height(const std::unique_ptr<Node> &node) const noexcept {
        return node == nullptr ? -1 : node->height;
    }

    /// Iterative BST descent looking for the interval that contains x.
    std::int32_t search(const T &x, const std::unique_ptr<Node> &node) const noexcept {
        const Node *t = node.get();
        while (t != nullptr) {
            if (t->interval >> x) {          // x lies after this interval
                t = t->right.get();
            } else if (t->interval << x) {   // x lies before this interval
                t = t->left.get();
            } else {                          // left < x <= right
                return t->interval.index;
            }
        }
        return -1;
    }

    /// Inserts x into the subtree t, rebalancing on the way back up.
    /// Returns false (and stores nothing) if x overlaps an existing interval.
    bool insert_util(Interval &&x, std::unique_ptr<Node> &t) {
        bool inserted = false;
        if (t == nullptr) {
            t = std::make_unique<Node>(std::move(x), nullptr, nullptr);
            inserted = true;
        } else if (x <= t->interval) {
            inserted = insert_util(std::move(x), t->left);
        } else if (t->interval <= x) {
            inserted = insert_util(std::move(x), t->right);
        }
        balance(t);
        return inserted;
    }

    /// Restores the AVL invariant at t and recomputes its height.
    void balance(std::unique_ptr<Node> &t) noexcept {
        if (t == nullptr)
            return;
        if (height(t->left) - height(t->right) > ALLOWED_IMBALANCE) {
            if (height(t->left->left) >= height(t->left->right))
                rotateWithLeftChild(t);   // left-left case
            else
                doubleWithLeftChild(t);   // left-right case
        } else if (height(t->right) - height(t->left) > ALLOWED_IMBALANCE) {
            if (height(t->right->right) >= height(t->right->left))
                rotateWithRightChild(t);  // right-right case
            else
                doubleWithRightChild(t);  // right-left case
        }
        t->height = std::max(height(t->left), height(t->right)) + 1;
    }

    /// Single right rotation: k2's left child becomes the subtree root.
    void rotateWithLeftChild(std::unique_ptr<Node> &k2) noexcept {
        auto k1 = std::move(k2->left);
        k2->left = std::move(k1->right);
        k2->height = std::max(height(k2->left), height(k2->right)) + 1;
        k1->height = std::max(height(k1->left), k2->height) + 1;
        k1->right = std::move(k2);
        k2 = std::move(k1);
    }

    /// Single left rotation: k1's right child becomes the subtree root.
    void rotateWithRightChild(std::unique_ptr<Node> &k1) noexcept {
        auto k2 = std::move(k1->right);
        k1->right = std::move(k2->left);
        k1->height = std::max(height(k1->left), height(k1->right)) + 1;
        k2->height = std::max(height(k2->right), k1->height) + 1;
        k2->left = std::move(k1);
        k1 = std::move(k2);
    }

    /// Left-right double rotation.
    void doubleWithLeftChild(std::unique_ptr<Node> &k3) noexcept {
        rotateWithRightChild(k3->left);
        rotateWithLeftChild(k3);
    }

    /// Right-left double rotation.
    void doubleWithRightChild(std::unique_ptr<Node> &k1) noexcept {
        rotateWithLeftChild(k1->right);
        rotateWithRightChild(k1);
    }
};
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment