Created December 7, 2014 18:05
-
-
Save Pliner/d1d60aac2836bbe19ac8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
import java.util.List; | |
/**
 * A self-balancing AVL binary search tree mapping keys to values.
 *
 * <p>Every node caches the height and the number of nodes of its subtree. The sizes let the
 * tree answer rank queries:
 * <ul>
 *   <li>{@link #findByIndex(int)} returns the value at a position in ascending key order.
 *   <li>{@link #indexAt(Comparable)} returns the position of a key.
 * </ul>
 * Lookups and insertions each follow a single root-to-leaf path.
 *
 * <p>Keys are ordered by {@link Comparable#compareTo}. Adding a key that is already present
 * replaces its value. Keys must be non-null. This class performs no synchronization.
 *
 * @param <K> key type
 * @param <V> value type
 */
public class AvlTree<K extends Comparable<K>, V> {

    /** A tree node carrying its cached subtree height and subtree size. */
    private static final class Node<K extends Comparable<K>, V> {
        final K key;
        V value;
        Node<K, V> left;
        Node<K, V> right;
        /** Nodes on the longest downward path starting here; a leaf has height 1. */
        int height;
        /** Nodes in the subtree rooted here, including this node. */
        int count;

        Node(K key, V value) {
            this.key = key;
            this.value = value;
            this.height = 1;
            this.count = 1;
        }
    }

    private Node<K, V> root;

    /**
     * Inserts {@code key} with {@code value}, or replaces the value if the key is already present.
     *
     * @param key non-null key
     * @param value value to associate with the key (may be null)
     * @throws NullPointerException if {@code key} is null; a null key could never be compared
     *     and would break every later operation
     */
    public void add(K key, V value) {
        if (key == null) {
            throw new NullPointerException("key");
        }
        root = addInternal(root, key, value);
    }

    /**
     * Returns the value whose key is at position {@code index} in ascending key order.
     *
     * @param index zero-based rank, {@code 0 <= index < getCount()}
     * @return the value stored at that rank
     * @throws IndexOutOfBoundsException if {@code index} is negative or not less than
     *     {@link #getCount()}, which includes every index when the tree is empty
     */
    public V findByIndex(int index) {
        int size = getCount();
        if (index < 0 || index >= size) {
            throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
        }
        return findByIndexInternal(root, index).value;
    }

    /**
     * Returns the zero-based position of {@code key} in ascending key order.
     *
     * @param key key to locate
     * @return the key's rank, or -1 if the key is not in the tree
     */
    public int indexAt(K key) {
        int index = 0;
        Node<K, V> node = root;
        while (node != null) {
            int cmp = node.key.compareTo(key);
            if (cmp > 0) {
                node = node.left;
            } else {
                // Everything in the left subtree sorts before this node.
                index += safeGetCount(node.left);
                if (cmp == 0) {
                    return index;
                }
                // Skip this node itself, then continue among the larger keys.
                index++;
                node = node.right;
            }
        }
        return -1;
    }

    /**
     * Returns the node at rank {@code index} within the subtree rooted at {@code node}.
     *
     * @throws IndexOutOfBoundsException if the subtree has no node at that rank
     */
    private Node<K, V> findByIndexInternal(Node<K, V> node, int index) {
        while (node != null) {
            int leftSize = safeGetCount(node.left);
            if (index < leftSize) {
                node = node.left;
            } else if (index > leftSize) {
                // Discount the left subtree and this node, then rank within the right subtree.
                index -= leftSize + 1;
                node = node.right;
            } else {
                return node;
            }
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Returns the value associated with {@code key}.
     *
     * @return the value, or null if the key is absent (or was stored with a null value)
     */
    public V find(K key) {
        Node<K, V> node = findInternal(root, key);
        return node == null ? null : node.value;
    }

    /** Returns the tree height: 0 when empty, 1 for a single node. */
    public int getHeight() {
        return safeGetHeight(root);
    }

    /** Returns the number of distinct keys stored. */
    public int getCount() {
        return safeGetCount(root);
    }

    /** Returns the node holding {@code key} in the subtree at {@code node}, or null if absent. */
    private Node<K, V> findInternal(Node<K, V> node, K key) {
        while (node != null) {
            int cmp = node.key.compareTo(key);
            if (cmp == 0) {
                return node;
            }
            node = cmp > 0 ? node.left : node.right;
        }
        return null;
    }

    /**
     * Inserts or updates {@code key} in the subtree at {@code node}.
     *
     * @return the new root of that subtree, which may differ from {@code node} after rebalancing
     */
    private Node<K, V> addInternal(Node<K, V> node, K key, V value) {
        if (node == null) {
            return new Node<>(key, value);
        }
        int cmp = node.key.compareTo(key);
        if (cmp == 0) {
            // Existing key: only the value changes, so heights and sizes stay valid.
            node.value = value;
            return node;
        }
        if (cmp > 0) {
            node.left = addInternal(node.left, key, value);
        } else {
            node.right = addInternal(node.right, key, value);
        }
        fix(node);
        return balance(node);
    }

    /** Recomputes the cached height and size of {@code node} from its children. */
    private void fix(Node<K, V> node) {
        node.height = Math.max(safeGetHeight(node.right), safeGetHeight(node.left)) + 1;
        node.count = safeGetCount(node.right) + safeGetCount(node.left) + 1;
    }

    private int safeGetCount(Node<K, V> node) {
        return node == null ? 0 : node.count;
    }

    private int safeGetHeight(Node<K, V> node) {
        return node == null ? 0 : node.height;
    }

    /**
     * Restores the AVL invariant at {@code node}, whose children are assumed already balanced.
     *
     * <p>Only the factors -2 and 2 are handled. A single insertion below a balanced node changes
     * one child's height by at most one, so larger imbalances do not arise here.
     *
     * @return the root of the rebalanced subtree
     */
    private Node<K, V> balance(Node<K, V> node) {
        int factor = getFactor(node);
        if (factor == -2) {
            // Left-heavy. If the left child leans right (left-right case), straighten it first.
            if (getFactor(node.left) > 0) {
                node.left = rotateLeft(node.left);
            }
            return rotateRight(node);
        }
        if (factor == 2) {
            // Right-heavy. If the right child leans left (right-left case), straighten it first.
            if (getFactor(node.right) < 0) {
                node.right = rotateRight(node.right);
            }
            return rotateLeft(node);
        }
        return node;
    }

    /**
     * Prints the keys to standard output, one per line, as a sideways tree.
     *
     * <p>The right subtree is printed first, and each deeper level gets a longer indentation
     * prefix.
     */
    public void print() {
        printInternal(root, "");
    }

    private void printInternal(Node<K, V> node, String prefix) {
        if (node == null) {
            return;
        }
        printInternal(node.right, prefix + " ");
        System.out.println(prefix + node.key);
        printInternal(node.left, prefix + " ");
    }

    /**
     * Rotates {@code node} down to the right, promoting its left child.
     *
     * @return the promoted child, now the subtree root
     */
    private Node<K, V> rotateRight(Node<K, V> node) {
        Node<K, V> left = node.left;
        node.left = left.right;
        left.right = node;
        // Fix the demoted node first: the new root's cached values depend on it.
        fix(node);
        fix(left);
        return left;
    }

    /**
     * Rotates {@code node} down to the left, promoting its right child.
     *
     * @return the promoted child, now the subtree root
     */
    private Node<K, V> rotateLeft(Node<K, V> node) {
        Node<K, V> right = node.right;
        node.right = right.left;
        right.left = node;
        fix(node);
        fix(right);
        return right;
    }

    /** Returns right height minus left height; negative means left-heavy. */
    private int getFactor(Node<K, V> node) {
        return safeGetHeight(node.right) - safeGetHeight(node.left);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.