Skip to content

Instantly share code, notes, and snippets.

@mahfuzsust
Created August 6, 2021 13:24
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mahfuzsust/e67043588cbf42414ab3ecc0db1412fd to your computer and use it in GitHub Desktop.
Save mahfuzsust/e67043588cbf42414ab3ecc0db1412fd to your computer and use it in GitHub Desktop.
AVL (Balanced Binary Search Tree)
/**
 * Self-balancing binary search tree (AVL tree) layered on top of {@link BST}.
 *
 * <p>Insertion and deletion are delegated to {@link BST}; afterwards the cached
 * height and balance factor of the affected {@link TreeNode}s are refreshed and
 * every node whose balance factor leaves the range [-1, 1] is repaired with a
 * single or double rotation. The rotation is chosen purely from balance factors
 * (never from key comparisons), so the same logic is valid after an insertion
 * and after a deletion, and is unaffected by duplicate keys.
 *
 * @param <T> key type
 */
public class AVL<T extends Comparable<T>> extends BST<T> {

    /**
     * Rotates the subtree rooted at {@code node} to the left: the right child
     * becomes the new subtree root and {@code node} becomes its left child.
     *
     * @param node subtree root to rotate; must have a right child
     */
    private void rotateLeft(TreeNode<T> node) {
        TreeNode<T> pivot = node.getRight();
        TreeNode<T> parent = node.getParent();
        TreeNode<T> inner = pivot.getLeft();

        // The pivot's left subtree sits between node and pivot in key order,
        // so it becomes node's new right subtree.
        node.setRight(inner);
        if (inner != null) {
            inner.setParent(node);
        }

        pivot.setLeft(node);
        node.setParent(pivot);
        replaceChild(parent, node, pivot);

        // Only these two nodes changed children; refresh the lower one first.
        node.updateBalanceFactor();
        pivot.updateBalanceFactor();
    }

    /**
     * Rotates the subtree rooted at {@code node} to the right: the left child
     * becomes the new subtree root and {@code node} becomes its right child.
     *
     * @param node subtree root to rotate; must have a left child
     */
    private void rotateRight(TreeNode<T> node) {
        TreeNode<T> pivot = node.getLeft();
        TreeNode<T> parent = node.getParent();
        TreeNode<T> inner = pivot.getRight();

        // The pivot's right subtree sits between pivot and node in key order,
        // so it becomes node's new left subtree.
        node.setLeft(inner);
        if (inner != null) {
            inner.setParent(node);
        }

        pivot.setRight(node);
        node.setParent(pivot);
        replaceChild(parent, node, pivot);

        // Only these two nodes changed children; refresh the lower one first.
        node.updateBalanceFactor();
        pivot.updateBalanceFactor();
    }

    /**
     * Hooks {@code newChild} into the slot previously occupied by
     * {@code oldChild}, updating {@code root} when there is no parent.
     *
     * @param parent   former parent of {@code oldChild}, or {@code null} if it was the root
     * @param oldChild node being displaced
     * @param newChild node taking its place
     */
    private void replaceChild(TreeNode<T> parent, TreeNode<T> oldChild, TreeNode<T> newChild) {
        newChild.setParent(parent);
        if (parent == null) {
            root = newChild;
        } else if (parent.getLeft() == oldChild) {
            parent.setLeft(newChild);
        } else {
            parent.setRight(newChild);
        }
    }

    /**
     * Refreshes the cached height/balance factor of {@code node} and, if it is
     * out of balance, applies the matching rotation. The heights cached in both
     * children must already be up to date.
     *
     * @param node node to check; must not be {@code null}
     */
    private void rebalance(TreeNode<T> node) {
        node.updateBalanceFactor();
        int balance = node.getBalanceFactor();
        if (balance > 1) {
            // Left-heavy. A right-leaning left child is the left-right case
            // and needs a preliminary left rotation of that child.
            if (node.getLeft().getBalanceFactor() < 0) {
                rotateLeft(node.getLeft());
            }
            rotateRight(node);
        } else if (balance < -1) {
            // Right-heavy. A left-leaning right child is the right-left case
            // and needs a preliminary right rotation of that child.
            if (node.getRight().getBalanceFactor() > 0) {
                rotateRight(node.getRight());
            }
            rotateLeft(node);
        }
    }

    /**
     * Rebalances every node on the path from {@code start} up to the root.
     *
     * @param start lowest node to examine; {@code null} is a no-op
     */
    private void rebalanceUpward(TreeNode<T> start) {
        TreeNode<T> current = start;
        while (current != null) {
            rebalance(current);
            // After a rotation the parent is the new subtree root; visiting it
            // again is harmless and keeps the walk on the path to the root.
            current = current.getParent();
        }
    }

    /**
     * Rebalances the whole subtree under {@code node} bottom-up (post-order),
     * refreshing every cached height on the way.
     *
     * @param node subtree root; {@code null} is a no-op
     */
    private void rebalanceSubtree(TreeNode<T> node) {
        if (node == null) {
            return;
        }
        rebalanceSubtree(node.getLeft());
        // Re-read the right child only now: rotations below never move it,
        // but reading late keeps this independent of that detail.
        rebalanceSubtree(node.getRight());
        rebalance(node);
    }

    /**
     * Inserts {@code key} and restores the AVL balance on the path from the new
     * node to the root.
     *
     * @param key      key to insert
     * @param callback optional; receives the newly created node after rebalancing
     */
    @Override
    public void insert(T key, Callback<TreeNode<T>> callback) {
        super.insert(key, (node) -> {
            rebalanceUpward(node);
            if (callback != null) {
                callback.success(node);
            }
        });
    }

    /**
     * Deletes {@code key} (if present) and restores the AVL balance.
     *
     * <p>The node handed back by {@link BST#delete} still carries its old parent
     * reference. When the removed node had two children, the structural change
     * (the unlinked successor) lies somewhere below that parent, and its exact
     * position is not exposed, so the subtree under the parent is rebalanced
     * bottom-up first and then the remaining ancestors are walked up to the root.
     *
     * @param key      key to delete
     * @param callback optional; receives the detached node after rebalancing
     */
    @Override
    public void delete(T key, Callback<TreeNode<T>> callback) {
        super.delete(key, (node) -> {
            TreeNode<T> start = node.getParent();
            if (start == null) {
                // The removed node was the root; the whole tree may be affected.
                start = root;
            }
            rebalanceSubtree(start);
            rebalanceUpward(start);
            if (callback != null) {
                callback.success(node);
            }
        });
    }
}
/**
 * Plain (unbalanced) binary search tree whose nodes keep parent links.
 *
 * <p>Keys that compare equal to an existing key are inserted into the right
 * subtree. Mutating operations report the affected node through an optional
 * {@link Callback}.
 *
 * @param <T> key type
 */
public class BST<T extends Comparable<T>> {
    /** Root of the tree, or {@code null} while the tree is empty. */
    protected TreeNode<T> root = null;

    public BST() {
    }

    /** @return the root node, or {@code null} if the tree is empty */
    public TreeNode<T> getRoot() {
        return root;
    }

    /**
     * Searches the subtree under {@code node} for {@code key}.
     *
     * @return the first node on the search path whose key compares equal, or {@code null}
     */
    private TreeNode<T> findNode(TreeNode<T> node, T key) {
        TreeNode<T> item = node;
        while (item != null) {
            int compareTo = key.compareTo(item.getKey());
            if (compareTo == 0) {
                return item;
            }
            item = compareTo < 0 ? item.getLeft() : item.getRight();
        }
        return null;
    }

    /**
     * Computes the height of the subtree under {@code node} by full traversal.
     *
     * @return 0 for {@code null}, otherwise the number of nodes on the longest downward path
     */
    public int height(TreeNode<T> node) {
        if (node == null) {
            return 0;
        }
        return 1 + Math.max(height(node.getLeft()), height(node.getRight()));
    }

    /** @return {@code true} if some node's key compares equal to {@code key} */
    public boolean find(T key) {
        return findNode(root, key) != null;
    }

    /**
     * Inserts a new node holding {@code key}.
     *
     * @param key      key to insert; equal keys go to the right subtree
     * @param callback optional; invoked with the new node once it is linked in
     */
    public void insert(T key, Callback<TreeNode<T>> callback) {
        TreeNode<T> node = new TreeNode<>(key);
        if (root == null) {
            root = node;
            report(callback, node);
            return;
        }
        TreeNode<T> item = root;
        while (true) {
            if (key.compareTo(item.getKey()) < 0) {
                if (item.getLeft() == null) {
                    item.setLeft(node);
                    break;
                }
                item = item.getLeft();
            } else {
                if (item.getRight() == null) {
                    item.setRight(node);
                    break;
                }
                item = item.getRight();
            }
        }
        node.setParent(item);
        report(callback, node);
    }

    /** Visits the subtree under {@code node} in left, node, right order. */
    public void inorder(TreeNode<T> node, Callback<TreeNode<T>> callback) {
        if (node == null) {
            return;
        }
        inorder(node.getLeft(), callback);
        callback.success(node);
        inorder(node.getRight(), callback);
    }

    /** Visits the subtree under {@code node} in node, left, right order. */
    public void preorder(TreeNode<T> node, Callback<TreeNode<T>> callback) {
        if (node == null) {
            return;
        }
        callback.success(node);
        preorder(node.getLeft(), callback);
        preorder(node.getRight(), callback);
    }

    /** Visits the subtree under {@code node} in left, right, node order. */
    public void postorder(TreeNode<T> node, Callback<TreeNode<T>> callback) {
        if (node == null) {
            return;
        }
        postorder(node.getLeft(), callback);
        postorder(node.getRight(), callback);
        callback.success(node);
    }

    /**
     * Puts {@code replace} (possibly {@code null}) into the slot occupied by
     * {@code node}. {@code node}'s own parent reference is left untouched.
     */
    private void replaceNode(TreeNode<T> node, TreeNode<T> replace) {
        TreeNode<T> parent = node.getParent();
        if (parent == null) {
            // node is the root; replace may be null when the last node is removed.
            root = replace;
        } else if (parent.getLeft() == node) {
            parent.setLeft(replace);
        } else {
            parent.setRight(replace);
        }
        updateParent(replace, parent);
    }

    /**
     * Removes the first node found whose key compares equal to {@code key}.
     * Prints a message and returns without invoking the callback when no such
     * node exists.
     *
     * <p>The detached node is passed to the callback with its child links
     * cleared but its former parent reference intact, so the callback can tell
     * where in the tree the removal happened.
     *
     * @param key      key to remove
     * @param callback optional; invoked with the detached node
     */
    public void delete(T key, Callback<TreeNode<T>> callback) {
        TreeNode<T> node = findNode(root, key);
        if (node == null) {
            System.out.println("Key does not exist");
            return;
        }
        if (node.getLeft() == null && node.getRight() == null) {
            replaceNode(node, null);
        } else if (node.getLeft() == null) {
            replaceNode(node, node.getRight());
            node.setRight(null);
        } else if (node.getRight() == null) {
            replaceNode(node, node.getLeft());
            node.setLeft(null);
        } else {
            // Two children: unlink the in-order successor from its current
            // position first, then let it adopt node's children and slot.
            TreeNode<T> successor = getSuccessor(node);
            replaceNode(successor, successor.getRight());
            successor.setLeft(node.getLeft());
            successor.setRight(node.getRight());
            replaceNode(node, successor);
            updateParent(successor.getLeft(), successor);
            updateParent(successor.getRight(), successor);
            node.setLeft(null);
            node.setRight(null);
        }
        report(callback, node);
    }

    /** Invokes {@code callback} with {@code node} unless the callback is {@code null}. */
    private void report(Callback<TreeNode<T>> callback, TreeNode<T> node) {
        if (callback != null) {
            callback.success(node);
        }
    }

    /** Sets {@code node}'s parent reference; does nothing when {@code node} is {@code null}. */
    private void updateParent(TreeNode<T> node, TreeNode<T> parent) {
        if (node != null) {
            node.setParent(parent);
        }
    }

    /**
     * @return the leftmost node of {@code node}'s right subtree, or {@code null}
     *         when there is no right subtree
     */
    private TreeNode<T> getSuccessor(TreeNode<T> node) {
        var right = node.getRight();
        TreeNode<T> ret = null;
        while (right != null) {
            ret = right;
            right = right.getLeft();
        }
        return ret;
    }

    /**
     * @return the key of the rightmost node of the subtree under {@code node},
     *         or {@code null} when {@code node} is {@code null}
     */
    public T max(TreeNode<T> node) {
        T value = null;
        TreeNode<T> item = node;
        while (item != null) {
            // Read before advancing so the last non-null node supplies the result.
            value = item.getKey();
            item = item.getRight();
        }
        return value;
    }

    /**
     * @return the key of the leftmost node of the subtree under {@code node},
     *         or {@code null} when {@code node} is {@code null}
     */
    public T min(TreeNode<T> node) {
        T value = null;
        TreeNode<T> item = node;
        while (item != null) {
            // Read before advancing so the last non-null node supplies the result.
            value = item.getKey();
            item = item.getLeft();
        }
        return value;
    }
}
/**
 * Single-method callback used to hand a value back to the caller.
 *
 * <p>Declared as a functional interface so it can be implemented with a lambda
 * or method reference, and so the compiler rejects the accidental addition of a
 * second abstract method.
 *
 * @param <T> type of the value passed to the callback
 */
@FunctionalInterface
public interface Callback<T> {
    /**
     * Receives the value produced by the operation.
     *
     * @param value value being reported
     */
    void success(T value);
}
/**
 * Binary tree node with links to its parent and both children, plus a cached
 * height and balance factor that are recomputed on demand by
 * {@link #updateBalanceFactor()}.
 *
 * @param <T> key type
 */
public class TreeNode<T extends Comparable<T>> {
    private T key;
    private TreeNode<T> parent;
    private TreeNode<T> left;
    private TreeNode<T> right;
    /** Cached height; stays 0 until {@link #updateBalanceFactor()} is called. */
    private int height;
    /** Cached left-height minus right-height. */
    private int balanceFactor;

    /** Creates a node without a key. */
    public TreeNode() {
    }

    /** Creates an unlinked node holding {@code key}. */
    public TreeNode(T key) {
        this.key = key;
    }

    // ---- links and key: accessors ----

    public TreeNode<T> getParent() {
        return this.parent;
    }

    public TreeNode<T> getLeft() {
        return this.left;
    }

    public T getKey() {
        return this.key;
    }

    public TreeNode<T> getRight() {
        return this.right;
    }

    // ---- links and key: mutators ----

    public void setKey(T key) {
        this.key = key;
    }

    public void setParent(TreeNode<T> head) {
        this.parent = head;
    }

    public void setLeft(TreeNode<T> left) {
        this.left = left;
    }

    public void setRight(TreeNode<T> right) {
        this.right = right;
    }

    // ---- cached balance information ----

    /** @return the height cached by the last {@link #updateBalanceFactor()} call */
    public int getHeight() {
        return this.height;
    }

    /** @return the balance factor cached by the last {@link #updateBalanceFactor()} call */
    public int getBalanceFactor() {
        return this.balanceFactor;
    }

    /**
     * Recomputes this node's height and balance factor from the heights
     * currently cached in its children (a missing child counts as 0).
     */
    public void updateBalanceFactor() {
        int leftHeight = cachedHeight(this.left);
        int rightHeight = cachedHeight(this.right);
        this.height = Math.max(leftHeight, rightHeight) + 1;
        this.balanceFactor = leftHeight - rightHeight;
    }

    /** Cached height of {@code child}, or 0 when there is no child. */
    private static int cachedHeight(TreeNode<?> child) {
        return child == null ? 0 : child.height;
    }

    /** @return {@code true} when the cached balance factor is -1, 0 or 1 */
    public boolean isBalanced() {
        if (this.getBalanceFactor() > 1) {
            return false;
        }
        return this.getBalanceFactor() >= -1;
    }

    /** @return the key's string form ({@code "null"} when there is no key) */
    @Override
    public String toString() {
        return String.valueOf(getKey());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment