Skip to content

Instantly share code, notes, and snippets.

@slaykovsky
Created June 1, 2020 01:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save slaykovsky/3c91603a17597c378584ee2fd07a4154 to your computer and use it in GitHub Desktop.
Save slaykovsky/3c91603a17597c378584ee2fd07a4154 to your computer and use it in GitHub Desktop.
#pragma once
#include <algorithm>
#include <cstddef>
#include <memory>
#include <optional>
/// Ordered key/value map implemented as a left-leaning red-black BST
/// (the LLRB algorithm as presented in Sedgewick & Wayne, "Algorithms",
/// 4th ed., section 3.3).
///
/// Keys are compared with `operator<` only: two keys are treated as the same
/// key when neither is less than the other.
///
/// NOTE(review): nodes are owned through std::shared_ptr and modified in
/// place, so copying a `tree` yields two objects that share (and both
/// mutate) the same nodes. Move a tree rather than copying it.
template <typename k_type, typename v_type>
class tree
{
    // Colour of the link that points from a node's parent to that node.
    enum color : bool
    {
        RED = true,
        BLACK = false,
    };

    struct node
    {
        const k_type k;
        v_type v;
        std::shared_ptr<node> left{ nullptr };
        std::shared_ptr<node> right{ nullptr };
        std::size_t size;   // number of nodes in the subtree rooted here
        color c{ BLACK };   // colour of the link from the parent

        // Initialisers are listed in declaration order (k, v, size, c).
        explicit node(const k_type& key, const v_type& value, const color link_color, const std::size_t subtree_size)
            : k{ key }, v{ value }, size{ subtree_size }, c{ link_color }
        {
        }
    };

    std::shared_ptr<node> root;

    // The other link colour.
    static color opposite(const color c)
    {
        return c == RED ? BLACK : RED;
    }

    // True when `n` exists and its parent link is red; null links count as black.
    bool is_red(const std::shared_ptr<node>& n) const
    {
        return n != nullptr && n->c == RED;
    }

    // Node count of the subtree rooted at `n`; 0 for an empty subtree.
    std::size_t size(const std::shared_ptr<node>& n) const
    {
        return n == nullptr ? 0 : n->size;
    }

    // Recompute `n->size` from its children, whose sizes must already be correct.
    void update_size(const std::shared_ptr<node>& n) const
    {
        n->size = size(n->left) + size(n->right) + 1;
    }

    // Height in edges of the subtree rooted at `n`; -1 for an empty subtree.
    long height(const std::shared_ptr<node>& n) const
    {
        if (n == nullptr)
        {
            return -1;
        }
        return 1 + std::max(height(n->left), height(n->right));
    }

    // Make the red right link of `h` lean left. Returns the new subtree root.
    std::shared_ptr<node> rotate_left(std::shared_ptr<node> h)
    {
        std::shared_ptr<node> x = h->right;
        h->right = x->left;
        x->left = h;
        x->c = h->c;
        h->c = RED;
        // x now roots everything h used to root, so it inherits h's count;
        // h's count must be recomputed from its new (smaller) set of children.
        x->size = h->size;
        update_size(h);
        return x;
    }

    // Make the red left link of `h` lean right. Returns the new subtree root.
    std::shared_ptr<node> rotate_right(std::shared_ptr<node> h)
    {
        std::shared_ptr<node> x = h->left;
        h->left = x->right;
        x->right = h;
        x->c = h->c;
        h->c = RED;
        // Same size bookkeeping as rotate_left, mirrored.
        x->size = h->size;
        update_size(h);
        return x;
    }

    // Toggle the colour of `h` and of both of its (non-null) children.
    // Insertion uses this to split a temporary 4-node (black parent, red
    // children). Deletion uses it in the opposite direction (red parent,
    // black children) to borrow a red link, which is why it must toggle
    // rather than force specific colours.
    void flip_colors(const std::shared_ptr<node>& h)
    {
        h->c = opposite(h->c);
        h->left->c = opposite(h->left->c);
        h->right->c = opposite(h->right->c);
    }

    // Restore the LLRB invariants at `n` after a recursive insert/delete
    // below it, then refresh its size. Returns the (possibly new) subtree root.
    std::shared_ptr<node> balance(std::shared_ptr<node> n)
    {
        if (is_red(n->right) && !is_red(n->left))
        {
            n = rotate_left(n);
        }
        if (is_red(n->left) && is_red(n->left->left))
        {
            n = rotate_right(n);
        }
        if (is_red(n->left) && is_red(n->right))
        {
            flip_colors(n);
        }
        update_size(n);
        return n;
    }

    // Insert (k, v) into the subtree rooted at `n`, or overwrite the value if
    // the key is already present. Returns the new subtree root.
    std::shared_ptr<node> put(std::shared_ptr<node> n, const k_type& k, const v_type& v)
    {
        if (n == nullptr)
        {
            return std::make_shared<node>(k, v, RED, 1);
        }
        if (k < n->k)
        {
            n->left = put(n->left, k, v);
        }
        else if (n->k < k)
        {
            n->right = put(n->right, k, v);
        }
        else
        {
            n->v = v;
        }
        return balance(n);
    }

    // Assuming `n` is red and both children are black, make n->left or one
    // of its children red so the delete can continue down the left spine.
    std::shared_ptr<node> move_red_left(std::shared_ptr<node> n)
    {
        flip_colors(n);
        if (is_red(n->right->left))
        {
            n->right = rotate_right(n->right);
            n = rotate_left(n);
            flip_colors(n);
        }
        return n;
    }

    // Mirror of move_red_left for descending to the right. Not called by any
    // member of this class; it is kept for a future delete_max/delete.
    std::shared_ptr<node> move_red_right(std::shared_ptr<node> n)
    {
        flip_colors(n);
        if (is_red(n->left->left))
        {
            n = rotate_right(n);
            flip_colors(n);
        }
        return n;
    }

    // Remove the smallest key from the non-empty subtree rooted at `n`.
    // Returns the new subtree root (null if the subtree becomes empty).
    std::shared_ptr<node> delete_min(std::shared_ptr<node> n)
    {
        if (n->left == nullptr)
        {
            return nullptr;
        }
        if (!is_red(n->left) && !is_red(n->left->left))
        {
            n = move_red_left(n);
        }
        n->left = delete_min(n->left);
        return balance(n);
    }

    // Leftmost node of the non-empty subtree rooted at `n`.
    const node* min_node(const std::shared_ptr<node>& n) const
    {
        const node* current = n.get();
        while (current->left != nullptr)
        {
            current = current->left.get();
        }
        return current;
    }

    // Rightmost node of the non-empty subtree rooted at `n`.
    const node* max_node(const std::shared_ptr<node>& n) const
    {
        const node* current = n.get();
        while (current->right != nullptr)
        {
            current = current->right.get();
        }
        return current;
    }

public:
    /// Insert `k` -> `v`, replacing the value if `k` is already present.
    void put(const k_type& k, const v_type& v)
    {
        root = put(root, k, v);
        root->c = BLACK;
    }

    /// Value stored under `k`, or std::nullopt if `k` is absent.
    std::optional<v_type> get_key(const k_type& k) const
    {
        // Raw pointers for the walk: no shared_ptr refcount traffic per step.
        const node* current = root.get();
        while (current != nullptr)
        {
            if (k < current->k)
            {
                current = current->left.get();
            }
            else if (current->k < k)
            {
                current = current->right.get();
            }
            else
            {
                return current->v;
            }
        }
        return std::nullopt;
    }

    /// True when the tree holds no keys.
    bool empty() const
    {
        return root == nullptr;
    }

    /// Height in edges (-1 for an empty tree, 0 for a single node).
    long height() const
    {
        return height(root);
    }

    /// Number of keys stored.
    std::size_t size() const
    {
        return size(root);
    }

    /// Remove the smallest key. Does nothing on an empty tree.
    void delete_min()
    {
        if (empty())
        {
            return;
        }
        // If both children of the root are black, temporarily colour the root
        // red so move_red_left has a red link to push down.
        if (!is_red(root->left) && !is_red(root->right))
        {
            root->c = RED;
        }
        root = delete_min(root);
        if (!empty())
        {
            root->c = BLACK;
        }
    }

    /// Smallest key, or std::nullopt if the tree is empty.
    std::optional<k_type> min() const
    {
        if (empty())
        {
            return std::nullopt;
        }
        return min_node(root)->k;
    }

    /// Largest key, or std::nullopt if the tree is empty.
    std::optional<k_type> max() const
    {
        if (empty())
        {
            return std::nullopt;
        }
        return max_node(root)->k;
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment