Skip to content

Instantly share code, notes, and snippets.

@giuniu
Created December 14, 2012 10:31
Show Gist options
  • Save giuniu/4284371 to your computer and use it in GitHub Desktop.
Save giuniu/4284371 to your computer and use it in GitHub Desktop.
「アルゴリズムを学ぼう」よりAVL木の実装。
import static java.lang.Math.max;
import static java.lang.Math.min;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * An AVL tree: a binary search tree that rebalances itself with rotations so that, for every node,
 * the heights of its two subtrees differ by at most one.
 *
 * <p>Duplicates are allowed; a value comparing equal to a node's value is inserted into that node's
 * left subtree. All searches use {@link Comparable#compareTo} only, never {@link Object#equals}, so
 * element types whose natural ordering is inconsistent with {@code equals} (for example
 * {@link java.math.BigDecimal}) are handled correctly. {@code null} values are rejected. No
 * synchronization is performed.
 *
 * @param <T> element type, ordered by its natural ordering
 */
public class AVLTree<T extends Comparable<? super T>> {
    /** Root of the tree, or {@code null} when the tree is empty. */
    private Node root;

    /**
     * A tree node. {@code _balance} is the balance factor
     * {@code height(right subtree) - height(left subtree)}; it is maintained by the rebalancing code
     * and verified by {@code selfCheck()}.
     */
    private class Node {
        private T _value;
        private Node _left;
        private Node _right;
        private int _balance;

        T value() {
            return _value;
        }

        void value(T v) {
            _value = v;
        }

        Node left() {
            return _left;
        }

        void setLeft(Node node) {
            _left = node;
        }

        void setRight(Node node) {
            _right = node;
        }

        Node right() {
            return _right;
        }

        int balance() {
            return _balance;
        }

        void addBalance(int diff) {
            _balance += diff;
        }

        private Node(T value) {
            this._value = value;
        }

        /**
         * Renders this subtree as {@code value<balance>(left,right)}; a leaf renders as just its value
         * and an empty child renders as {@code null}.
         */
        @Override
        public String toString() {
            return _value + (_left == null && _right == null
                    ? ""
                    : "<" + _balance + ">(" + _left + "," + _right + ")");
        }
    }

    /**
     * Returns whether the tree holds a value that compares equal to {@code v}.
     *
     * @param v value to look for
     * @return {@code true} if some stored value {@code x} satisfies {@code x.compareTo(v) == 0}
     * @throws IllegalArgumentException if {@code v} is {@code null}
     */
    public boolean contains(T v) {
        if (v == null) {
            throw new IllegalArgumentException("値がnull");
        }
        return contains(root, v);
    }

    /** Recursive binary-search lookup of {@code v} in the subtree rooted at {@code node}. */
    private boolean contains(Node node, T v) {
        if (node == null) {
            return false;
        }
        int comp = node.value().compareTo(v);
        if (comp == 0) {
            return true;
        }
        return comp > 0 ? contains(node.left(), v) : contains(node.right(), v);
    }

    /**
     * Inserts {@code v} and rebalances the tree. Equal values are kept, so the tree may hold
     * duplicates.
     *
     * @param v value to insert
     * @throws IllegalArgumentException if {@code v} is {@code null}
     */
    public void insert(T v) {
        if (v == null) {
            throw new IllegalArgumentException("値がnull");
        }
        if (root == null) {
            root = new Node(v);
            return;
        }
        insert(null, root, v);
    }

    /**
     * Inserts {@code v} somewhere below the non-null {@code node} and rebalances on the way back up.
     *
     * @param parent parent of {@code node}, or {@code null} if {@code node} is the root
     * @param node   root of the subtree to insert into; never {@code null}
     * @param v      value to insert
     * @return the change in height of the subtree at {@code node}'s position (0 or +1)
     */
    private int insert(Node parent, Node node, T v) {
        assert node != null;
        if (node.value().compareTo(v) >= 0) {
            // Values equal to the node's value go left.
            int diff;
            if (node.left() == null) {
                node.setLeft(new Node(v));
                diff = 1;
            } else {
                diff = insert(node, node.left(), v);
            }
            return achieveBalance(parent, node, diff, 0);
        } else {
            int diff;
            if (node.right() == null) {
                node.setRight(new Node(v));
                diff = 1;
            } else {
                diff = insert(node, node.right(), v);
            }
            return achieveBalance(parent, node, 0, diff);
        }
    }

    /**
     * Removes one value that compares equal to {@code v}, if any, and rebalances the tree. Does
     * nothing when no such value is stored.
     *
     * @param v value to remove
     * @throws IllegalArgumentException if {@code v} is {@code null}
     */
    public void remove(T v) {
        if (v == null) {
            throw new IllegalArgumentException("値がnull");
        }
        remove(null, root, v);
    }

    /**
     * Removes one occurrence of {@code v} from the subtree rooted at {@code node} and rebalances on
     * the way back up.
     *
     * @param parent parent of {@code node}, or {@code null} if {@code node} is the root
     * @param node   root of the subtree to search; may be {@code null}
     * @param v      value to remove
     * @return the change in height of the subtree at {@code node}'s position (0 or -1)
     */
    private int remove(Node parent, Node node, T v) {
        if (node == null) {
            return 0; // not found: nothing removed, height unchanged
        }
        int comp = node.value().compareTo(v);
        if (comp > 0) {
            int diff = remove(node, node.left(), v);
            return achieveBalance(parent, node, diff, 0);
        }
        if (comp < 0) {
            int diff = remove(node, node.right(), v);
            return achieveBalance(parent, node, 0, diff);
        }
        // Found by compareTo. Deliberately no equals() check: a natural ordering may be
        // inconsistent with equals (e.g. BigDecimal "1.0" vs "1.00"), and asserting equals here
        // made removal fail under -ea for such types.
        if (node.left() != null) {
            int diff = findMaxAndRemove(node, v);
            return achieveBalance(parent, node, diff, 0);
        }
        if (node.right() != null) {
            int diff = findMinAndRemove(node, v);
            return achieveBalance(parent, node, 0, diff);
        }
        // Leaf: unlink it directly.
        replace(parent, node, null);
        return -1;
    }

    /**
     * Removes {@code node}'s value, where {@code node} has a left child. The value is swapped with
     * the largest value of the left subtree, and {@code v} is then removed from that subtree.
     *
     * @param node node whose value is being removed; must have a left child
     * @param v    the value being removed
     * @return the change in height of {@code node}'s left subtree (0 or -1)
     */
    private int findMaxAndRemove(Node node, T v) {
        Node maxParent = node;
        Node max = node.left();
        while (max.right() != null) {
            maxParent = max;
            max = max.right();
        }
        // If the maximum still has a left child (a leaf in a valid AVL tree), rotate right so the
        // maximum becomes a leaf. The rotation leaves this subtree's height unchanged.
        if (max.left() != null) {
            replace(maxParent, max, rotateRight(max));
        }
        swapValue(node, max);
        return remove(node, node.left(), v);
    }

    /**
     * Mirror of {@link #findMaxAndRemove}: swaps {@code node}'s value with the smallest value of its
     * right subtree, then removes {@code v} from that subtree.
     *
     * @param node node whose value is being removed; must have a right child
     * @param v    the value being removed
     * @return the change in height of {@code node}'s right subtree (0 or -1)
     */
    private int findMinAndRemove(Node node, T v) {
        Node minParent = node;
        Node min = node.right();
        while (min.left() != null) {
            minParent = min;
            min = min.left();
        }
        if (min.right() != null) {
            replace(minParent, min, rotateLeft(min));
        }
        swapValue(node, min);
        return remove(node, node.right(), v);
    }

    /** Exchanges the values stored in two non-null nodes. */
    private void swapValue(Node n1, Node n2) {
        assert n1 != null && n2 != null;
        T v = n1.value();
        n1.value(n2.value());
        n2.value(v);
    }

    /**
     * Updates {@code node}'s balance factor after one of its subtrees changed height, and rotates if
     * the balance reaches ±2.
     *
     * <p>Callers pass at most one non-zero diff: +1 after an insertion grew a subtree, -1 after a
     * removal shrank it.
     *
     * @param parent    parent of {@code node}, or {@code null} if {@code node} is the root; used to
     *                  re-link the new subtree root after a rotation
     * @param node      node whose child subtree changed height
     * @param leftDiff  height change of the left subtree
     * @param rightDiff height change of the right subtree
     * @return the change in height of the subtree that was rooted at {@code node}, after any rotation
     */
    private int achieveBalance(Node parent, Node node, int leftDiff, int rightDiff) {
        assert (-1 <= node.balance() && node.balance() <= 1);
        if (leftDiff == 0 && rightDiff == 0) {
            return 0;
        }
        int diff = 0;
        // A growing side raises this node's height unless that side was already the shorter one.
        if ((leftDiff > 0 && node.balance() <= 0) || (rightDiff > 0 && node.balance() >= 0)) {
            diff++;
        }
        // A shrinking side lowers this node's height only if that side was strictly taller.
        if ((leftDiff < 0 && node.balance() < 0) || (rightDiff < 0 && node.balance() > 0)) {
            diff--;
        }
        node.addBalance(rightDiff - leftDiff);
        assert (-2 <= node.balance() && node.balance() <= 2);
        if (node.balance() == -2) {
            // Left-heavy. Rebalancing removes one level unless the left child is perfectly
            // balanced, in which case the single rotation keeps the height.
            if (node.left().balance() != 0) {
                diff--;
            }
            if (node.left().balance() == 1) {
                // Left-right case: first reduce it to the left-left case.
                replace(node, node.left(), rotateLeft(node.left()));
            }
            replace(parent, node, rotateRight(node));
        } else if (node.balance() == 2) {
            // Right-heavy: mirror image of the branch above.
            if (node.right().balance() != 0) {
                diff--;
            }
            if (node.right().balance() == -1) {
                // Right-left case: first reduce it to the right-right case.
                replace(node, node.right(), rotateRight(node.right()));
            }
            replace(parent, node, rotateLeft(node));
        }
        return diff;
    }

    /**
     * Rotates the subtree rooted at {@code node} to the left. {@code node}'s right child becomes the
     * new subtree root and {@code node} becomes its left child. Both nodes' balance factors are
     * updated. The caller must link the returned node into the tree.
     *
     * @param node root of the subtree to rotate; must have a right child
     * @return the new root of the subtree
     */
    private Node rotateLeft(Node node) {
        assert node != null;
        Node result = node.right();
        assert result != null;
        node.setRight(result.left());
        result.setLeft(node);
        // With a = old balance of node and b = old balance of result:
        // node' = a - 1 - max(b, 0);  result' = b - 1 + min(node', 0).
        int a = node.balance();
        int b = result.balance();
        int aa = a - max(b, 0) - 1;
        node.addBalance(aa - a);
        result.addBalance(min(aa, 0) - 1);
        return result;
    }

    /**
     * Rotates the subtree rooted at {@code node} to the right. {@code node}'s left child becomes the
     * new subtree root and {@code node} becomes its right child. Both nodes' balance factors are
     * updated. The caller must link the returned node into the tree.
     *
     * @param node root of the subtree to rotate; must have a left child
     * @return the new root of the subtree
     */
    private Node rotateRight(Node node) {
        assert node != null;
        Node result = node.left();
        assert result != null;
        node.setLeft(result.right());
        result.setRight(node);
        // With a = old balance of node and b = old balance of result:
        // node' = a + 1 - min(b, 0);  result' = b + 1 + max(node', 0).
        int a = node.balance();
        int b = result.balance();
        int aa = a - min(b, 0) + 1;
        node.addBalance(aa - a);
        result.addBalance(max(aa, 0) + 1);
        return result;
    }

    /**
     * Replaces {@code parent}'s child {@code oldNode} with {@code newNode}. A {@code null}
     * {@code parent} replaces the tree's root instead.
     *
     * @param parent  node whose child pointer is updated, or {@code null} for the root
     * @param oldNode current child of {@code parent}
     * @param newNode node to put in its place; may be {@code null}
     * @throws IllegalArgumentException if {@code oldNode} is not a child of {@code parent}
     */
    private void replace(Node parent, Node oldNode, Node newNode) {
        if (parent == null) {
            root = newNode;
            return;
        }
        if (parent.left() == oldNode) {
            parent.setLeft(newNode);
        } else if (parent.right() == oldNode) {
            parent.setRight(newNode);
        } else {
            throw new IllegalArgumentException("子供が見つからなかった");
        }
    }

    /**
     * Debugging aid. Verifies that every stored balance factor matches the real subtree heights and
     * lies within [-1, 1].
     *
     * @throws IllegalStateException at the first inconsistent node found (post-order)
     */
    void selfCheck() {
        check(root);
    }

    /** Returns the height of the subtree at {@code node} (0 when empty), validating its balances. */
    private int check(Node node) {
        if (node == null) {
            return 0;
        }
        int lRank = check(node.left());
        int rRank = check(node.right());
        if (node.balance() != rRank - lRank) {
            // Stored balance disagrees with the actual subtree heights.
            throw new IllegalStateException("rRank:" + rRank + " lRank:" + lRank + " node:" + node);
        }
        if (Math.abs(node.balance()) > 1) {
            // AVL invariant violated.
            throw new IllegalStateException("unbalanced node:" + node);
        }
        return max(lRank, rRank) + 1;
    }

    /**
     * Renders the whole tree in the {@code value<balance>(left,right)} format. An empty tree renders
     * as {@code "null"}, the same text used for empty subtrees inside the rendering. Printing the
     * tree therefore produces the same output as before, but the method no longer returns
     * {@code null}.
     */
    @Override
    public String toString() {
        return root == null ? "null" : root.toString();
    }

    /**
     * Smoke test. Inserts and then removes 1..30000 in random order, validating the tree after every
     * operation and printing it along the way.
     */
    public static void main(String[] args) {
        final int size = 30000;
        List<Integer> list = new ArrayList<>();
        for (int i = 1; i <= size; i++) {
            list.add(i);
        }
        AVLTree<Integer> tree = new AVLTree<>();
        Random rand = new Random();
        List<Integer> tmp = new ArrayList<>(list);
        for (int i = size; i >= 1; i--) {
            tree.insert(tmp.remove(rand.nextInt(i)));
            tree.selfCheck();
        }
        System.out.println(tree);
        System.out.println(tree.contains(18));
        System.out.println(tree.contains(0));
        tmp = new ArrayList<>(list);
        for (int i = size; i >= 1; i--) {
            tree.remove(tmp.remove(rand.nextInt(i)));
            tree.selfCheck();
        }
        System.out.println(tree);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment