Created
June 11, 2015 05:11
-
-
Save anonymous/e0cf6d51f17c015d4f2a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public final class SplayTreeSet<K> { | |
private final class Node { | |
Node left = null; | |
Node right = null; | |
K key; | |
Node(K key) { | |
this.key = key; | |
} | |
} | |
private Node root; | |
private Node buffer = new Node(null); | |
private int size; | |
Comparator<K> comparator; | |
public SplayTreeSet(Comparator<K> comparator) { | |
this.comparator = comparator; | |
} | |
// <- (p,x) | |
private void rotateLeft(Node x, Node p) { | |
p.right = x.left; | |
x.left = p; | |
} | |
// (x,p) -> | |
private void rotateRight(Node x, Node p) { | |
p.left = x.right; | |
x.right = p; | |
} | |
private Node splay(Node x, K key) { | |
if (x == null) return null; | |
Node left = buffer; | |
Node right = buffer; | |
while (true) { | |
if (comparator.compare(key, x.key) < 0) { | |
Node y = x.left; | |
if (y == null) break; | |
if (comparator.compare(key, y.key) < 0) { // zig-zig | |
rotateRight(y, x); | |
x = y; | |
if (x.left == null) break; | |
} | |
// link right | |
right.left = x; | |
right = x; | |
// move left | |
x = x.left; | |
} else if (comparator.compare(key, x.key) > 0) { | |
Node y = x.right; | |
if (y == null) break; | |
if (comparator.compare(key, y.key) > 0) { // zig-zig | |
rotateLeft(y, x); | |
x = y; | |
if (x.right == null) break; | |
} | |
//link left | |
left.right = x; | |
left = x; | |
//move right | |
x = x.right; | |
} else break; | |
} | |
left.right = x.left; | |
right.left = x.right; | |
x.left = buffer.right; | |
x.right = buffer.left; | |
return x; | |
} | |
private Node join(Node left, Node right) { | |
if (left == null) return right; | |
if (right == null) return left; | |
left = splay(left, right.key); | |
left.right = right; | |
return left; | |
} | |
// left <= key, key < right | |
private Pair<Node, Node> split(Node x, K key) { | |
if (x == null) throw new RuntimeException(); | |
x = splay(x, key); | |
if (comparator.compare(key, x.key) >= 0) { // cut right | |
Node right = x.right; | |
x.right = null; | |
return new Pair<Node, Node>(x, right); | |
} else { // cut left | |
Node left = x.left; | |
x.left = null; | |
return new Pair<Node, Node>(left, x); | |
} | |
} | |
// API: | |
public int size() { | |
return size; | |
} | |
public boolean contains(K key) { | |
if (key == null) throw new IllegalArgumentException(); | |
if (root == null) return false; | |
root = splay(root, key); | |
return root.key.equals(key); | |
} | |
public void add(K key) { | |
if (key == null) throw new IllegalArgumentException(); | |
if (root == null) { | |
root = new Node(key); | |
++size; | |
} else { | |
root = splay(root, key); | |
if (!root.key.equals(key)) { | |
++size; | |
Pair<Node, Node> pair = split(root, key); | |
root = new Node(key); | |
root.left = pair.first; | |
root.right = pair.second; | |
} | |
} | |
} | |
public void remove(K key) { | |
if (key == null) throw new IllegalArgumentException(); | |
if (root == null) return; | |
root = splay(root, key); | |
if (root.key.equals(key)) { | |
root = join(root.left, root.right); | |
--size; | |
} | |
} | |
public K first() { | |
if (root == null) return null; | |
Node v = root; | |
while (v.left != null) v = v.left; | |
root = splay(root, v.key); | |
return root.key; | |
} | |
public K last() { | |
if (root == null) return null; | |
Node v = root; | |
while (v.right != null) v = v.right; | |
root = splay(root, v.key); | |
return root.key; | |
} | |
// x >= key | |
public K lowerBound(K key) { | |
if (key == null) throw new IllegalArgumentException(); | |
Pair<Node, Node> pair = split(root, key); | |
K res = null; | |
if (pair.first != null && key.equals(pair.first.key)) res = key; | |
else if (pair.second != null) { | |
pair.second = splay(pair.second, key); | |
res = pair.second.key; | |
} | |
root = join(pair.first, pair.second); | |
return res; | |
} | |
// x > key | |
public K upperBound(K key) { | |
if (key == null) throw new IllegalArgumentException(); | |
Pair<Node, Node> pair = split(root, key); | |
K res = null; | |
if (pair.second != null) { | |
pair.second = splay(pair.second, key); | |
res = pair.second.key; | |
} | |
root = join(pair.first, pair.second); | |
return res; | |
} | |
public void clear() { | |
root = null; | |
size = 0; | |
} | |
// DEBUG: | |
public ArrayList<K> getKeys() { | |
ArrayList<K> list = new ArrayList<K>(); | |
dfs(root, list); | |
return list; | |
} | |
private void dfs(Node v, ArrayList<K> list) { | |
if (v == null) return; | |
dfs(v.left, list); | |
list.add(v.key); | |
dfs(v.right, list); | |
} | |
@Override | |
public String toString() { | |
return getKeys().toString(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment