Skip to content

Instantly share code, notes, and snippets.

@huynhsamha
Last active April 21, 2018 12:33
Show Gist options
  • Save huynhsamha/400fd7e10b9a6325c57a8638612d53e9 to your computer and use it in GitHub Desktop.
Save huynhsamha/400fd7e10b9a6325c57a8638612d53e9 to your computer and use it in GitHub Desktop.
AVL Tree (HCMUT)
#pragma once
#include <iostream>
namespace shama {

/// Balance factor of an AVL node: height(right subtree) - height(left subtree).
/// LH: left subtree one level taller; EH: equal heights; RH: right subtree one level taller.
enum BalanceFactor { LH = -1, EH = 0, RH = 1 };

/// One node of an AVLTree. Holds a copy of its value and raw pointers to its
/// children. The destructor does not free the children; AVLTree::ClearAVLNode
/// releases whole subtrees.
template <class T>
struct AVLNode {
    T data;                  // stored value
    AVLNode* L = nullptr;    // left child
    AVLNode* R = nullptr;    // right child
    BalanceFactor balance;   // height(R) - height(L), one of LH / EH / RH
    /// Creates a leaf holding a copy of `value`; a leaf is always balanced (EH).
    AVLNode(T value)
        : data(value), L(nullptr), R(nullptr), balance(EH) {}
    ~AVLNode() {}
};

/// Self-balancing binary search tree (AVL tree).
///
/// Duplicate values are allowed: a value that is not less than a node's value
/// is inserted into that node's right subtree. T must support operator<,
/// operator> and operator== (and operator<< for printNLR). The tree owns every
/// node it allocates and frees them in clearAVLTree() and the destructor.
template <class T>
class AVLTree {
private:
    AVLNode<T>* m_root = nullptr;  // root node; nullptr when the tree is empty
    int m_count = 0;               // number of stored values (duplicates counted)
public:
    AVLTree() : m_root(nullptr), m_count(0) {}
    /// Deep copy. The implicitly generated copy would share nodes with `other`
    /// and both trees would delete them.
    AVLTree(const AVLTree& other) {
        m_root = CloneAVLNode(other.m_root);
        m_count = other.m_count;
    }
    /// Deep-copy assignment. The clone is built before anything is released,
    /// so *this is unchanged if copying a value throws.
    AVLTree& operator=(const AVLTree& other) {
        if (this != &other) {
            AVLNode<T>* copy = CloneAVLNode(other.m_root);
            clearAVLTree();
            m_root = copy;
            m_count = other.m_count;
        }
        return *this;
    }
    ~AVLTree() { clearAVLTree(); }

    /// Number of stored values.
    int size() const { return m_count; }
    /// Same as size().
    int count() const { return m_count; }
    /// Height in nodes (0 when empty); recomputed by visiting every node.
    int height() const { return Height(m_root); }
    /// Root node, or nullptr when the tree is empty.
    AVLNode<T>* root() const { return m_root; }

    /// Frees every node and leaves the tree empty.
    void clearAVLTree() {
        ClearAVLNode(m_root);
        m_count = 0;  // keep size()/count() consistent with the now-empty tree
    }
    /// Prints all values in pre-order (node, left, right) to std::cout.
    void printNLR() { PrintNLR(m_root); }

    /// Inserts a copy of `value` (duplicates allowed) and rebalances.
    /// @return the tree's root after insertion -- not the newly created node.
    AVLNode<T>* insert(T value) {
        AVLNode<T>* newNode = new AVLNode<T>(value);
        bool taller = false;
        m_count++;
        return InsertAVLNode(m_root, newNode, taller);
    }
    /// Removes one occurrence of `value`.
    /// @return true if a matching value was found and removed.
    bool remove(T value) {
        if (m_root == nullptr) return false;
        bool shorter = false, success = false;
        RemoveAVLNode(m_root, value, shorter, success);
        if (success) m_count--;
        return success;
    }
    /// @return a node whose data == value, or nullptr if there is none.
    AVLNode<T>* find(T value) { return Find(m_root, value); }
protected:
    int Height(AVLNode<T>* root) const;
    AVLNode<T>* RotateLeft(AVLNode<T>* &root);
    AVLNode<T>* RotateRight(AVLNode<T>* &root);
    AVLNode<T>* LeftBalance(AVLNode<T>* &root, bool &taller);
    AVLNode<T>* RightBalance(AVLNode<T>* &root, bool &taller);
    AVLNode<T>* InsertAVLNode(AVLNode<T>* &root, AVLNode<T>* newNode, bool &taller);
    AVLNode<T>* RemoveAVLNode(AVLNode<T>* &root, T target, bool &shorter, bool &success);
    AVLNode<T>* LeftBalanceRemove(AVLNode<T>* &root, bool &shorter);
    AVLNode<T>* RightBalanceRemove(AVLNode<T>* &root, bool &shorter);
    AVLNode<T>* Find(AVLNode<T>* root, T value);
    AVLNode<T>* CloneAVLNode(const AVLNode<T>* node);
    void ClearAVLNode(AVLNode<T>* &node);
    void PrintNLR(AVLNode<T>* root);
};

/// Deletes the subtree rooted at `node` (children first) and sets `node` to nullptr.
template <class T>
void AVLTree<T>::ClearAVLNode(AVLNode<T>* &node) {
    if (node) {
        ClearAVLNode(node->L);
        ClearAVLNode(node->R);
        delete node;
        node = nullptr;
    }
}

/// Recursively deep-copies the subtree rooted at `node`, preserving its shape
/// and balance factors. If copying a value throws, the partial copy is freed
/// and the exception is rethrown.
template <class T>
AVLNode<T>* AVLTree<T>::CloneAVLNode(const AVLNode<T>* node) {
    if (node == nullptr) return nullptr;
    AVLNode<T>* copy = new AVLNode<T>(node->data);
    copy->balance = node->balance;
    try {
        copy->L = CloneAVLNode(node->L);
        copy->R = CloneAVLNode(node->R);
    }
    catch (...) {
        ClearAVLNode(copy);  // children not yet cloned are still nullptr
        throw;
    }
    return copy;
}

/// Writes the subtree rooted at `node` to std::cout in pre-order (node, left,
/// right), each value followed by a space.
template <class T>
void AVLTree<T>::PrintNLR(AVLNode<T>* node) {
    if (node == nullptr) return;
    // Qualified: a header must not depend on the includer's `using namespace std`.
    std::cout << node->data << " ";
    PrintNLR(node->L);
    PrintNLR(node->R);
}

/// Number of nodes on the longest downward path from `root` (0 for nullptr).
/// Visits the whole subtree.
template <class T>
int AVLTree<T>::Height(AVLNode<T>* root) const {
    if (root == nullptr) return 0;
    const int left = Height(root->L);
    const int right = Height(root->R);
    return 1 + ((left > right) ? left : right);
}

/// Left rotation: lifts root->R above `root` and returns it as the new subtree
/// root. The reference `root` itself is not reassigned and balance factors are
/// not touched -- both are the caller's job. Requires root->R != nullptr.
template <class T>
AVLNode<T>* AVLTree<T>::RotateLeft(AVLNode<T>* &root) {
    AVLNode<T>* tmp = root->R;
    root->R = tmp->L;
    tmp->L = root;
    return tmp;
}

/// Right rotation: mirror of RotateLeft. Requires root->L != nullptr.
template <class T>
AVLNode<T>* AVLTree<T>::RotateRight(AVLNode<T>* &root) {
    AVLNode<T>* tmp = root->L;
    root->L = tmp->R;
    tmp->R = root;
    return tmp;
}

/// Rebalances `root` after an insertion made its LEFT subtree taller while
/// `root` was already left-high. Handles the left-left case with one right
/// rotation and the left-right case with a double rotation, fixes the balance
/// factors of every node that moved, stores the new subtree root in `root` and
/// returns it. The subtree is back to its pre-insertion height, so `taller` is
/// cleared.
template <class T>
AVLNode<T>* AVLTree<T>::LeftBalance(AVLNode<T>* &root, bool &taller) {
    AVLNode<T>* leftTree = root->L;
    if (leftTree->balance == LH) {
        // Left-left: both nodes end up balanced. Update BEFORE rotating; after
        // the rotation `root` already refers to leftTree and the old root
        // would keep a stale LH factor.
        root->balance = EH;
        leftTree->balance = EH;
        root = RotateRight(root);
    }
    else {
        // Left-right: leftTree's right child becomes the new subtree root.
        AVLNode<T>* rightTree = leftTree->R;
        if (rightTree->balance == LH) {
            root->balance = RH;
            leftTree->balance = EH;
        }
        else if (rightTree->balance == EH) {
            // Only when rightTree is the freshly inserted leaf: both ancestors end up balanced.
            root->balance = EH;
            leftTree->balance = EH;
        }
        else {
            root->balance = EH;
            leftTree->balance = LH;
        }
        rightTree->balance = EH;
        root->L = RotateLeft(leftTree);
        root = RotateRight(root);
    }
    taller = false;
    return root;
}

/// Mirror of LeftBalance: rebalances `root` after an insertion made its RIGHT
/// subtree taller while `root` was already right-high.
template <class T>
AVLNode<T>* AVLTree<T>::RightBalance(AVLNode<T>* &root, bool &taller) {
    AVLNode<T>* rightTree = root->R;
    if (rightTree->balance == RH) {
        // Right-right: update before rotating (see LeftBalance).
        root->balance = EH;
        rightTree->balance = EH;
        root = RotateLeft(root);
    }
    else {
        // Right-left: rightTree's left child becomes the new subtree root.
        AVLNode<T>* leftTree = rightTree->L;
        if (leftTree->balance == RH) {
            root->balance = LH;
            rightTree->balance = EH;
        }
        else if (leftTree->balance == EH) {
            // Only when leftTree is the freshly inserted leaf.
            root->balance = EH;
            rightTree->balance = EH;
        }
        else {
            root->balance = EH;
            rightTree->balance = RH;
        }
        leftTree->balance = EH;
        root->R = RotateRight(rightTree);
        root = RotateLeft(root);
    }
    taller = false;
    return root;
}

/// Recursively inserts `newNode` into the subtree referenced by `root`.
/// Values less than a node's value go left; all others (including equal ones)
/// go right. On return `taller` is true if the subtree gained a level, and
/// `root` (also returned) is the subtree root after any rebalancing.
template <class T>
AVLNode<T>* AVLTree<T>::InsertAVLNode(AVLNode<T>* &root, AVLNode<T>* newNode, bool &taller) {
    if (!root) {
        root = newNode; taller = true;
        return root;
    }
    if (newNode->data < root->data) {
        root->L = InsertAVLNode(root->L, newNode, taller);
        if (taller) {
            if (root->balance == LH) {
                root = LeftBalance(root, taller);
            }
            else if (root->balance == EH) {
                root->balance = LH;
            }
            else {
                root->balance = EH;
                taller = false;
            }
        }
    }
    else {
        root->R = InsertAVLNode(root->R, newNode, taller);
        if (taller) {
            if (root->balance == RH) {
                root = RightBalance(root, taller);
            }
            else if (root->balance == EH) {
                root->balance = RH;
            }
            else {
                root->balance = EH;
                taller = false;
            }
        }
    }
    return root;
}

/// Rebalances `root` after a removal shortened its LEFT subtree by one level.
/// `shorter` enters true and stays true only if the subtree rooted at `root`
/// also lost a level. Stores the new subtree root in `root` and returns it.
template <class T>
AVLNode<T>* AVLTree<T>::RightBalanceRemove(AVLNode<T>* &root, bool &shorter) {
    if (root->balance == LH) {
        // The taller side shrank: now balanced, and the subtree lost a level.
        root->balance = EH;
        return root;
    }
    if (root->balance == EH) {
        // Right side is now one level taller; overall height unchanged.
        root->balance = RH;
        shorter = false;
        return root;
    }
    // root was RH and is now right-heavy by two levels.
    AVLNode<T>* rightTree = root->R;
    if (rightTree->balance == LH) {
        // Right-left: rightTree's left child becomes the new subtree root.
        AVLNode<T>* leftTree = rightTree->L;
        if (leftTree->balance == LH) {
            rightTree->balance = RH;
            root->balance = EH;
        }
        else if (leftTree->balance == EH) {
            root->balance = EH;  // both of root's new subtrees have equal height
            rightTree->balance = EH;
        }
        else {
            root->balance = LH;
            rightTree->balance = EH;
        }
        leftTree->balance = EH;
        root->R = RotateRight(rightTree);
        root = RotateLeft(root);
    }
    else if (rightTree->balance == RH) {
        // Right-right: single rotation, both balanced, subtree still shorter.
        root->balance = EH;
        rightTree->balance = EH;
        root = RotateLeft(root);
    }
    else {
        // rightTree balanced: single rotation keeps the original height.
        root->balance = RH;
        rightTree->balance = LH;
        shorter = false;
        root = RotateLeft(root);
    }
    return root;
}

/// Mirror of RightBalanceRemove: rebalances `root` after a removal shortened
/// its RIGHT subtree by one level.
template <class T>
AVLNode<T>* AVLTree<T>::LeftBalanceRemove(AVLNode<T>* &root, bool &shorter) {
    if (root->balance == RH) {
        root->balance = EH;
        return root;
    }
    if (root->balance == EH) {
        root->balance = LH;
        shorter = false;
        return root;
    }
    AVLNode<T>* leftTree = root->L;
    if (leftTree->balance == RH) {
        // Left-right: leftTree's right child becomes the new subtree root.
        AVLNode<T>* rightTree = leftTree->R;
        if (rightTree->balance == RH) {
            leftTree->balance = LH;
            root->balance = EH;
        }
        else if (rightTree->balance == EH) {
            root->balance = EH;  // both of root's new subtrees have equal height
            leftTree->balance = EH;
        }
        else {
            root->balance = RH;
            leftTree->balance = EH;
        }
        rightTree->balance = EH;
        root->L = RotateLeft(leftTree);
        root = RotateRight(root);
    }
    else if (leftTree->balance == LH) {
        root->balance = EH;
        leftTree->balance = EH;
        root = RotateRight(root);
    }
    else {
        root->balance = LH;
        leftTree->balance = RH;
        shorter = false;
        root = RotateRight(root);
    }
    return root;
}

/// Recursively removes one node whose value equals `target` from the subtree
/// referenced by `root`. Sets `success` when a node was removed and `shorter`
/// when the subtree lost a level; stores and returns the new subtree root.
/// A node with two children takes its in-order predecessor's value, and that
/// value is then removed from the left subtree.
template <class T>
AVLNode<T>* AVLTree<T>::RemoveAVLNode(AVLNode<T>* &root, T target, bool &shorter, bool &success) {
    if (root == nullptr) {
        shorter = false; success = false;
        return nullptr;
    }
    if (target < root->data) {
        root->L = RemoveAVLNode(root->L, target, shorter, success);
        if (shorter) {
            root = RightBalanceRemove(root, shorter);
        }
    }
    else if (target > root->data) {
        root->R = RemoveAVLNode(root->R, target, shorter, success);
        if (shorter) {
            root = LeftBalanceRemove(root, shorter);
        }
    }
    else if (root->L == nullptr || root->R == nullptr) {
        // At most one child: splice it into this node's place.
        AVLNode<T>* targetNode = root;
        root = (root->L == nullptr) ? root->R : root->L;
        delete targetNode;
        success = true;
        shorter = true;
    }
    else {
        AVLNode<T>* pred = root->L;
        while (pred->R) pred = pred->R;
        root->data = pred->data;
        root->L = RemoveAVLNode(root->L, pred->data, shorter, success);
        if (shorter) {
            root = RightBalanceRemove(root, shorter);
        }
    }
    return root;
}

/// Binary search: returns a node in the subtree whose data == value, or nullptr.
template <class T>
AVLNode<T>* AVLTree<T>::Find(AVLNode<T>* root, T value) {
    if (root == nullptr) return nullptr;
    if (root->data == value) return root;
    if (value < root->data) return Find(root->L, value);
    return Find(root->R, value);
}

}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment