Skip to content

Instantly share code, notes, and snippets.

@Pliner
Created December 7, 2014 18:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Pliner/d1d60aac2836bbe19ac8 to your computer and use it in GitHub Desktop.
Save Pliner/d1d60aac2836bbe19ac8 to your computer and use it in GitHub Desktop.
import java.util.ArrayList;
import java.util.List;
/**
 * A self-balancing (AVL) binary search tree mapping keys to values.
 *
 * <p>Every node caches the height of its subtree (used to keep the tree balanced)
 * and the number of nodes in its subtree. The counts support positional queries:
 * fetching the value at a zero-based position in ascending key order
 * ({@link #findByIndex}) and finding the position of a key ({@link #indexAt}).
 *
 * <p>Not thread-safe.
 *
 * @param <K> key type, ordered by its natural ordering
 * @param <V> value type
 */
public class AvlTree<K extends Comparable<K>, V> {
    /** A tree node carrying a key/value pair plus cached subtree metadata. */
    private static class Node<K extends Comparable<K>, V> {
        public K key;
        public V value;
        public Node<K, V> left;
        public Node<K, V> right;
        /** Height of the subtree rooted at this node; a leaf has height 1. */
        public int height;
        /** Number of nodes in the subtree rooted at this node, including itself. */
        public int count;

        public Node(K key, V value) {
            this.key = key;
            this.value = value;
            this.height = 1;
            this.count = 1;
        }
    }

    private Node<K, V> root;

    /**
     * Inserts a mapping, or replaces the value if {@code key} is already present.
     *
     * @param key the key to insert
     * @param value the value to associate with {@code key}
     */
    public void add(K key, V value) {
        root = addInternal(root, key, value);
    }

    /**
     * Returns the value stored at the given zero-based position in ascending key order.
     *
     * @param index position in {@code [0, getCount())}
     * @return the value of the {@code index}-th smallest key
     * @throws IndexOutOfBoundsException if {@code index < 0} or {@code index >= getCount()}
     *     (in particular, always for an empty tree)
     */
    public V findByIndex(int index) {
        int size = getCount();
        // Valid positions are 0..size-1. The original check (index > count) let
        // index == count through and dereferenced a null root on an empty tree.
        if (index < 0 || index >= size) {
            throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
        }
        return findByIndexInternal(root, index).value;
    }

    /**
     * Returns the zero-based position of {@code key} in ascending key order.
     *
     * @param key the key to locate
     * @return the number of keys strictly smaller than {@code key}, or -1 if absent
     */
    public int indexAt(K key) {
        int index = 0;
        Node<K, V> node = root;
        while (node != null) {
            int compareResult = node.key.compareTo(key);
            if (compareResult > 0) {
                node = node.left;
            } else {
                // Everything in the left subtree is smaller than key.
                index += safeGetCount(node.left);
                if (compareResult == 0) {
                    return index;
                }
                // The current node is smaller too; continue to the right.
                index++;
                node = node.right;
            }
        }
        return -1;
    }

    /**
     * Walks down from {@code node} to the node at relative position {@code index}.
     * The caller is expected to have validated the range. The trailing throw is
     * only a safety net.
     */
    private Node<K, V> findByIndexInternal(Node<K, V> node, int index) {
        while (node != null) {
            int leftSize = safeGetCount(node.left);
            if (index == leftSize) {
                return node;
            }
            if (index < leftSize) {
                node = node.left;
            } else {
                // Skip the left subtree and the current node.
                index -= leftSize + 1;
                node = node.right;
            }
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Looks up the value associated with {@code key}.
     *
     * @return the value, or {@code null} if {@code key} is not present
     */
    public V find(K key) {
        Node<K, V> node = findInternal(root, key);
        return node == null ? null : node.value;
    }

    /** Returns the height of the tree: 0 when empty, 1 for a single node. */
    public int getHeight() {
        return safeGetHeight(root);
    }

    /** Returns the number of keys stored in the tree. */
    public int getCount() {
        return safeGetCount(root);
    }

    private Node<K, V> findInternal(Node<K, V> node, K key) {
        while (node != null) {
            int compareResult = node.key.compareTo(key);
            if (compareResult > 0) {
                node = node.left;
            } else if (compareResult < 0) {
                node = node.right;
            } else {
                return node;
            }
        }
        return null;
    }

    /** Recursively inserts into the subtree at {@code node} and returns its (possibly new) root. */
    private Node<K, V> addInternal(Node<K, V> node, K key, V value) {
        if (node == null) {
            return new Node<K, V>(key, value);
        }
        int compareResult = node.key.compareTo(key);
        if (compareResult > 0) {
            node.left = addInternal(node.left, key, value);
        } else if (compareResult < 0) {
            node.right = addInternal(node.right, key, value);
        } else {
            // Existing key: replace the value; the shape and metadata are unchanged.
            node.value = value;
            return node;
        }
        fix(node);
        return balance(node);
    }

    /** Recomputes the cached height and count of {@code node} from its children. */
    private void fix(Node<K, V> node) {
        node.height = Math.max(safeGetHeight(node.right), safeGetHeight(node.left)) + 1;
        node.count = safeGetCount(node.right) + safeGetCount(node.left) + 1;
    }

    private int safeGetCount(Node<K, V> node) {
        return node == null ? 0 : node.count;
    }

    private int safeGetHeight(Node<K, V> node) {
        return node == null ? 0 : node.height;
    }

    /**
     * Restores the AVL invariant at {@code node}, whose children are already balanced,
     * and returns the new subtree root.
     */
    private Node<K, V> balance(Node<K, V> node) {
        int factor = getFactor(node);
        switch (factor) {
            case -2: {
                // Left-heavy. If the left child leans right, this is the left-right case.
                if (getFactor(node.left) > 0) {
                    node.left = rotateLeft(node.left);
                }
                return rotateRight(node);
            }
            case 2: {
                // Right-heavy. If the right child leans left, this is the right-left case.
                if (getFactor(node.right) < 0) {
                    node.right = rotateRight(node.right);
                }
                return rotateLeft(node);
            }
            default:
                return node;
        }
    }

    /** Prints the keys sideways to stdout: right subtree on top, one space of indent per level. */
    public void print() {
        printInternal(root, "");
    }

    private void printInternal(Node<K, V> node, String prefix) {
        if (node == null) {
            return;
        }
        printInternal(node.right, prefix + " ");
        System.out.println(prefix + node.key);
        printInternal(node.left, prefix + " ");
    }

    /** Rotates the subtree right; its left child becomes the new root. */
    private Node<K, V> rotateRight(Node<K, V> node) {
        Node<K, V> left = node.left;
        node.left = left.right;
        left.right = node;
        // Fix the lower node first: the new root's metadata depends on it.
        fix(node);
        fix(left);
        return left;
    }

    /** Rotates the subtree left; its right child becomes the new root. */
    private Node<K, V> rotateLeft(Node<K, V> node) {
        Node<K, V> right = node.right;
        node.right = right.left;
        right.left = node;
        fix(node);
        fix(right);
        return right;
    }

    /** Balance factor: right height minus left height (positive means right-heavy). */
    private int getFactor(Node<K, V> node) {
        return safeGetHeight(node.right) - safeGetHeight(node.left);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment