Skip to content

Instantly share code, notes, and snippets.

Created June 11, 2015 05:11
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/e0cf6d51f17c015d4f2a to your computer and use it in GitHub Desktop.
Save anonymous/e0cf6d51f17c015d4f2a to your computer and use it in GitHub Desktop.
/**
 * A sorted set backed by a top-down splay tree.
 *
 * <p>Both ordering and element identity come from the supplied {@link Comparator}: two keys are
 * the same element when {@code comparator.compare(a, b) == 0}. {@code null} keys are rejected
 * with {@link IllegalArgumentException}.
 *
 * <p>Every lookup ({@link #contains}, {@link #first}, {@link #last}, {@link #lowerBound},
 * {@link #upperBound}) restructures the tree. There is no synchronization, so instances are not
 * safe for concurrent use, even by threads that only read.
 *
 * @param <K> element type
 */
public final class SplayTreeSet<K> {
    /** Tree node. Also used as the scratch header in {@link #splay}. */
    private final class Node {
        Node left = null;
        Node right = null;
        K key;

        Node(K key) {
            this.key = key;
        }
    }

    private Node root;
    /**
     * Scratch header node reused by {@link #splay}. Its {@code right} collects the left assembled
     * tree and its {@code left} collects the right one. Both children are cleared after every
     * splay so the set never keeps detached nodes reachable.
     */
    private final Node buffer = new Node(null);
    private int size;
    Comparator<K> comparator;

    public SplayTreeSet(Comparator<K> comparator) {
        this.comparator = comparator;
    }

    /**
     * Rotates right child {@code x} above its parent {@code p}: (p, x) becomes (p) <- x. The
     * caller is responsible for re-linking {@code x} into p's former parent.
     */
    private void rotateLeft(Node x, Node p) {
        p.right = x.left;
        x.left = p;
    }

    /**
     * Rotates left child {@code x} above its parent {@code p}: (x, p) becomes x -> (p). The
     * caller is responsible for re-linking {@code x} into p's former parent.
     */
    private void rotateRight(Node x, Node p) {
        p.left = x.right;
        x.right = p;
    }

    /**
     * Top-down splay of the subtree rooted at {@code x} around {@code key}.
     *
     * @return the new subtree root. This is the node equal to {@code key} if one exists,
     *     otherwise the last node on the search path (the in-subtree predecessor or successor of
     *     {@code key}). Returns {@code null} if {@code x} is {@code null}.
     */
    private Node splay(Node x, K key) {
        if (x == null) {
            return null;
        }
        // 'left' is the max node of the left assembled tree; 'right' is the min node of the
        // right assembled tree. Both start at the shared header.
        Node left = buffer;
        Node right = buffer;
        while (true) {
            int c = comparator.compare(key, x.key);
            if (c < 0) {
                Node y = x.left;
                if (y == null) {
                    break;
                }
                if (comparator.compare(key, y.key) < 0) { // zig-zig
                    rotateRight(y, x);
                    x = y;
                    if (x.left == null) {
                        break;
                    }
                }
                // link right, then move left
                right.left = x;
                right = x;
                x = x.left;
            } else if (c > 0) {
                Node y = x.right;
                if (y == null) {
                    break;
                }
                if (comparator.compare(key, y.key) > 0) { // zig-zig
                    rotateLeft(y, x);
                    x = y;
                    if (x.right == null) {
                        break;
                    }
                }
                // link left, then move right
                left.right = x;
                left = x;
                x = x.right;
            } else {
                break;
            }
        }
        // Reassemble: x's children go to the outer edges of the side trees, and the side
        // trees (hanging off the header) become x's children.
        left.right = x.left;
        right.left = x.right;
        x.left = buffer.right;
        x.right = buffer.left;
        // Drop the header's links so it does not retain nodes after this call.
        buffer.left = null;
        buffer.right = null;
        return x;
    }

    /**
     * Joins two trees where every key in {@code left} is smaller than every key in
     * {@code right}. Splaying {@code left} around a key larger than all of its elements brings
     * its maximum to the root; that root has no right child, so {@code right} is attached there.
     */
    private Node join(Node left, Node right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        left = splay(left, right.key);
        left.right = right;
        return left;
    }

    // API:

    /** Returns the number of elements in the set. */
    public int size() {
        return size;
    }

    /**
     * Returns whether an element comparing equal to {@code key} is present.
     *
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public boolean contains(K key) {
        if (key == null) {
            throw new IllegalArgumentException();
        }
        if (root == null) {
            return false;
        }
        root = splay(root, key);
        return comparator.compare(root.key, key) == 0;
    }

    /**
     * Adds {@code key} unless an element comparing equal to it is already present. The new node
     * becomes the root.
     *
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void add(K key) {
        if (key == null) {
            throw new IllegalArgumentException();
        }
        if (root == null) {
            root = new Node(key);
            ++size;
            return;
        }
        root = splay(root, key);
        int c = comparator.compare(key, root.key);
        if (c == 0) {
            return;
        }
        Node node = new Node(key);
        if (c < 0) {
            // root is the successor of key, so everything in root.left is smaller than key.
            node.left = root.left;
            node.right = root;
            root.left = null;
        } else {
            // root is the predecessor of key, so everything in root.right is larger than key.
            node.right = root.right;
            node.left = root;
            root.right = null;
        }
        root = node;
        ++size;
    }

    /**
     * Removes the element comparing equal to {@code key}, if present.
     *
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void remove(K key) {
        if (key == null) {
            throw new IllegalArgumentException();
        }
        if (root == null) {
            return;
        }
        root = splay(root, key);
        if (comparator.compare(root.key, key) == 0) {
            root = join(root.left, root.right);
            --size;
        }
    }

    /** Returns the smallest element (splayed to the root), or {@code null} if the set is empty. */
    public K first() {
        if (root == null) {
            return null;
        }
        Node v = root;
        while (v.left != null) {
            v = v.left;
        }
        root = splay(root, v.key);
        return root.key;
    }

    /** Returns the largest element (splayed to the root), or {@code null} if the set is empty. */
    public K last() {
        if (root == null) {
            return null;
        }
        Node v = root;
        while (v.right != null) {
            v = v.right;
        }
        root = splay(root, v.key);
        return root.key;
    }

    /**
     * Returns the smallest stored element {@code x} with {@code x >= key}, or {@code null} if
     * there is none (including when the set is empty).
     *
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public K lowerBound(K key) {
        if (key == null) {
            throw new IllegalArgumentException();
        }
        if (root == null) {
            return null;
        }
        root = splay(root, key);
        if (comparator.compare(root.key, key) >= 0) {
            return root.key;
        }
        // root is the greatest element < key, so every element of root.right is > key.
        // Splaying that subtree around key brings its minimum up.
        if (root.right == null) {
            return null;
        }
        root.right = splay(root.right, key);
        return root.right.key;
    }

    /**
     * Returns the smallest stored element {@code x} with {@code x > key}, or {@code null} if
     * there is none (including when the set is empty).
     *
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public K upperBound(K key) {
        if (key == null) {
            throw new IllegalArgumentException();
        }
        if (root == null) {
            return null;
        }
        root = splay(root, key);
        if (comparator.compare(root.key, key) > 0) {
            return root.key;
        }
        // root is key itself or its predecessor, so every element of root.right is > key.
        if (root.right == null) {
            return null;
        }
        root.right = splay(root.right, key);
        return root.right.key;
    }

    /** Removes all elements. */
    public void clear() {
        root = null;
        size = 0;
    }

    // DEBUG:

    /** Returns the elements in ascending order. */
    public ArrayList<K> getKeys() {
        ArrayList<K> list = new ArrayList<K>(size);
        dfs(root, list);
        return list;
    }

    /**
     * Appends the keys of the subtree rooted at {@code v} to {@code list} in order. The walk is
     * iterative: a splay tree can degenerate into a chain (e.g. after ascending inserts), and
     * recursion would then overflow the call stack.
     */
    private void dfs(Node v, ArrayList<K> list) {
        ArrayList<Node> stack = new ArrayList<Node>();
        Node cur = v;
        while (cur != null || !stack.isEmpty()) {
            while (cur != null) {
                stack.add(cur);
                cur = cur.left;
            }
            cur = stack.remove(stack.size() - 1);
            list.add(cur.key);
            cur = cur.right;
        }
    }

    @Override
    public String toString() {
        return getKeys().toString();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment