Skip to content

Instantly share code, notes, and snippets.

Last active March 13, 2024 06:48
Show Gist options
  • Save tipabu/6999625 to your computer and use it in GitHub Desktop.
Save tipabu/6999625 to your computer and use it in GitHub Desktop.
Order statistic tree based on AVL trees. Because the tree is self-balancing, insertion, removal, search, ranking, and selection are all O(log N). Duplicate elements are supported: by default each element is inserted after smaller elements and after existing elements of the same value, but before larger elements. See also: http://en.wikipedia.org/wiki/Order_statistic_tree
import java.util.Iterator;
import java.util.AbstractQueue;
import java.util.NoSuchElementException;
/**
 * An order statistic tree backed by a self-balancing AVL tree.
 *
 * <p>Insertion, removal, search, ranking and selection all take O(log N) time. Duplicate
 * elements are allowed. The {@code *Left} operations act before any equal elements and the
 * {@code *Right} operations act after them. {@link #offer}/{@link #add} insert to the right,
 * so equal elements leave the queue ({@link #poll}) in the order they were added.
 *
 * <p>Null elements are rejected with {@link NullPointerException}. No synchronization is
 * performed, and iterators are not fail-fast.
 *
 * @param <T> the element type, ordered by its natural ordering
 */
public class BalancedOrderStatisticTree<T extends Comparable<T>> extends AbstractQueue<T> {

    /** A tree node, which also caches the sizes of its subtrees and its height. */
    private class Node {
        /** The value at this node. */
        private final T val;

        /** The parent of this node, or null if this node is the root of the tree. */
        public Node parent = null;

        /** The left child of this node. */
        public Node lChild = null;

        /** The right child of this node. */
        public Node rChild = null;

        /** Number of nodes in the left subtree (used for order statistics). */
        public int left = 0;

        /** Number of nodes in the right subtree (used for order statistics). */
        public int right = 0;

        /** Height of the subtree rooted here; a leaf has height 1. */
        public int height = 1;

        /**
         * Create a new, unlinked node with the given value.
         *
         * @param item the value at this node
         */
        public Node(T item) {
            val = item;
        }

        /**
         * Get the value at this node.
         *
         * @return the value at this node
         */
        public T value() {
            return val;
        }

        /**
         * Find the left-most (smallest) node under this one.
         *
         * @return the left-most node of the subtree rooted here
         */
        public Node head() {
            Node tmp = this;
            while (tmp.lChild != null) {
                tmp = tmp.lChild;
            }
            return tmp;
        }

        /**
         * Find the right-most (largest) node under this one.
         *
         * @return the right-most node of the subtree rooted here
         */
        public Node tail() {
            Node tmp = this;
            while (tmp.rChild != null) {
                tmp = tmp.rChild;
            }
            return tmp;
        }

        /**
         * Get the tree depth for this node. The root node has depth zero, its children depth
         * one, etc.
         *
         * @return the depth of this node
         */
        public int depth() {
            int c = 0;
            for (Node tmp = this; tmp.parent != null; tmp = tmp.parent) {
                c++;
            }
            return c;
        }

        /** @return the height of the left subtree, 0 if there is none */
        private int heightLeft() {
            return lChild == null ? 0 : lChild.height;
        }

        /** @return the height of the right subtree, 0 if there is none */
        private int heightRight() {
            return rChild == null ? 0 : rChild.height;
        }

        /**
         * Rotate this node's left child up into this node's position; this node becomes the
         * child's right child.
         *
         * @param rightHeight the height of this node's right subtree
         */
        private void rotateLeftUp(int rightHeight) {
            assert lChild != null;
            Node child = lChild;
            int hLeft = child.heightLeft(), hRight = child.heightRight();
            // The child's right subtree becomes our left subtree; we become the child's right child.
            lChild = child.rChild;
            child.rChild = this;
            child.parent = this.parent;
            this.parent = child;
            if (lChild != null) {
                lChild.parent = this;
            }
            replaceChild(child.parent, this, child);
            // We lost the child and its left subtree; the child gained us and our right subtree.
            left = child.right;
            child.right = left + 1 + right;
            height = Math.max(hRight, rightHeight) + 1;
            child.height = Math.max(hLeft, height) + 1;
            assert child.verify();
        }

        /**
         * Rotate this node's right child up into this node's position; this node becomes the
         * child's left child.
         *
         * @param leftHeight the height of this node's left subtree
         */
        private void rotateRightUp(int leftHeight) {
            assert rChild != null;
            Node child = rChild;
            int hLeft = child.heightLeft(), hRight = child.heightRight();
            // The child's left subtree becomes our right subtree; we become the child's left child.
            rChild = child.lChild;
            child.lChild = this;
            child.parent = this.parent;
            this.parent = child;
            if (rChild != null) {
                rChild.parent = this;
            }
            replaceChild(child.parent, this, child);
            right = child.left;
            child.left = left + 1 + right;
            height = Math.max(hLeft, leftHeight) + 1;
            child.height = Math.max(hRight, height) + 1;
            assert child.verify();
        }

        /**
         * Recompute this node's height from its children. If the subtree heights differ by
         * more than one, rebalance with a single or double rotation.
         *
         * @return true if ancestors may need updating (the height changed or a rotation happened)
         */
        private boolean updateHeight() {
            int hLeft = heightLeft(), hRight = heightRight();
            if (hLeft > hRight + 1) {
                assert lChild != null;
                // Left-right case: straighten the left child first.
                if (lChild.heightLeft() < lChild.heightRight()) {
                    lChild.rotateRightUp(lChild.heightLeft());
                }
                rotateLeftUp(hRight);
                // Force the caller's next updateHeight() (on the node just rotated above us) to
                // report a change, so that propagation toward the root continues.
                parent.height = 0;
                return true;
            }
            if (hRight > hLeft + 1) {
                assert rChild != null;
                // Right-left case: straighten the right child first.
                if (rChild.heightLeft() > rChild.heightRight()) {
                    rChild.rotateLeftUp(rChild.heightRight());
                }
                rotateRightUp(hLeft);
                parent.height = 0; // see the comment in the left-heavy branch
                return true;
            }
            int newHeight = Math.max(hLeft, hRight) + 1;
            if (height == newHeight) {
                return false;
            }
            height = newHeight;
            return true;
        }

        /**
         * Get the position of this node in the whole tree. The left-most (smallest) node has
         * rank zero, the next has rank one, etc.
         *
         * @return the position of this node
         * @see #select(int)
         */
        public int rank() {
            int c = left;
            for (Node tmp = this; tmp.parent != null; tmp = tmp.parent) {
                // Coming up from a right child: the parent and its whole left subtree precede us.
                if (tmp.parent.lChild != tmp) {
                    c += tmp.parent.left + 1;
                }
            }
            return c;
        }

        /**
         * Get the node at the given position within the subtree rooted here. The caller must
         * ensure {@code 0 <= index < left + 1 + right}.
         *
         * @param index the position of the node we want
         * @return the node at the given position
         * @see #rank()
         */
        public Node select(int index) {
            Node n = this;
            while (true) {
                if (index == n.left) {
                    return n;
                } else if (index < n.left) {
                    n = n.lChild;
                } else {
                    index -= n.left + 1;
                    n = n.rChild;
                }
            }
        }

        /**
         * Insert a new node into the subtree rooted here, then rebalance upward.
         *
         * @param n the new, unlinked node
         * @param useRight whether equal values go to the right (true) or the left (false)
         */
        private void insert(Node n, boolean useRight) {
            Node cur = this;
            while (true) {
                int cmp = n.value().compareTo(cur.value());
                if (cmp == 0) {
                    cmp = useRight ? 1 : -1;
                }
                if (cmp < 0) {
                    cur.left++; // n will end up somewhere in cur's left subtree
                    if (cur.lChild == null) {
                        cur.lChild = n;
                        break;
                    }
                    cur = cur.lChild;
                } else {
                    cur.right++;
                    if (cur.rChild == null) {
                        cur.rChild = n;
                        break;
                    }
                    cur = cur.rChild;
                }
            }
            n.parent = cur;
            rebalanceFrom(cur);
        }

        /**
         * Insert a new node into the tree to the right of any equivalent nodes.
         *
         * @param n the new node
         */
        public void insertRight(Node n) {
            insert(n, true);
        }

        /**
         * Insert a new node into the tree to the left of any equivalent nodes.
         *
         * @param n the new node
         */
        public void insertLeft(Node n) {
            insert(n, false);
        }

        /**
         * Get the next node in the tree, whose rank is one greater than this node's.
         *
         * @return the next-largest node, or null if there are no larger nodes
         * @see #prev()
         */
        public Node next() {
            if (rChild != null) {
                return rChild.head();
            }
            // Climb while we are a right child; the first ancestor reached from the left is next.
            Node tmp = this;
            while (tmp.parent != null && tmp.parent.rChild == tmp) {
                tmp = tmp.parent;
            }
            return tmp.parent;
        }

        /**
         * Get the previous node in the tree, whose rank is one less than this node's.
         *
         * @return the next-smallest node, or null if there are no smaller nodes
         * @see #next()
         */
        public Node prev() {
            if (lChild != null) {
                return lChild.tail();
            }
            Node tmp = this;
            while (tmp.parent != null && tmp.parent.lChild == tmp) {
                tmp = tmp.parent;
            }
            return tmp.parent;
        }

        /**
         * Find a node under this one with the given value.
         *
         * @param target the value we're searching for
         * @param useRight whether to return the right-most matching node, or the left-most
         * @return the matching node, or null if none is found
         */
        public Node find(T target, boolean useRight) {
            Node match = null;
            Node cur = this;
            while (cur != null) {
                int cmp = cur.value().compareTo(target);
                if (cmp == 0) {
                    match = cur; // deeper matches along the path are further left/right
                }
                if (cmp < 0 || (cmp == 0 && useRight)) {
                    cur = cur.rChild; // too early (or seeking right-most): search right side
                } else {
                    cur = cur.lChild; // too late (or seeking left-most): search left side
                }
            }
            return match;
        }

        public Node findLeft(T target) {
            return find(target, false);
        }

        public Node findRight(T target) {
            return find(target, true);
        }

        /**
         * Check this node's local invariants with assertions: parent/child links, subtree
         * counts and height.
         *
         * @return true (a violated invariant raises an AssertionError when assertions are enabled)
         */
        public boolean verify() {
            int hLeft = 0, hRight = 0;
            if (parent != null) {
                assert parent.lChild == this || parent.rChild == this : "Parent of " + value() + " only has children " + (parent.lChild == null ? "(null)" : parent.lChild.value()) + " and " + (parent.rChild == null ? "(null)" : parent.rChild.value()) + " (parent = " + parent.value() + ")";
            }
            if (lChild != null) {
                assert lChild.parent == this : "Left child (" + lChild.value() + ") has parent " + lChild.parent.value() + ", not " + value();
                assert left == lChild.left + 1 + lChild.right : "Left count at node " + value() + " should be " + (lChild.left + 1 + lChild.right) + ", not " + left;
                hLeft = lChild.height;
            } else {
                assert left == 0;
            }
            if (rChild != null) {
                assert rChild.parent == this : "Right child (" + rChild.value() + ") has parent " + rChild.parent.value() + ", not " + value();
                assert right == rChild.left + 1 + rChild.right : "Right count at node " + value() + " should be " + (rChild.left + 1 + rChild.right) + ", not " + right;
                hRight = rChild.height;
            } else {
                assert right == 0;
            }
            if (hLeft > hRight) {
                assert height == hLeft + 1 : "Height at node " + value() + " should be " + (hLeft + 1) + ", not " + height;
            } else {
                assert height == hRight + 1 : "Height at node " + value() + " should be " + (hRight + 1) + ", not " + height;
            }
            return true;
        }
    }

    /** Iterator over the elements in ascending order; equal elements come in tree order. */
    public class TreeIterator implements Iterator<T> {
        /** The node whose value next() returns, or null when the iteration is exhausted. */
        private Node curr;
        /** The node most recently returned by next(), or null if remove() is not allowed. */
        private Node last;

        private TreeIterator() {
            last = null;
            curr = head;
        }

        @Override
        public boolean hasNext() {
            return curr != null;
        }

        @Override
        public T next() {
            if (curr == null) {
                throw new NoSuchElementException();
            }
            last = curr;
            curr = curr.next();
            return last.value();
        }

        /**
         * Remove the element last returned by {@link #next()}. The pending successor stays valid
         * because removal only relocates the removed node's predecessor.
         *
         * @throws IllegalStateException if next() has not been called since the last remove()
         */
        @Override
        public void remove() {
            if (last == null) {
                throw new IllegalStateException();
            }
            removeNode(last);
            last = null;
        }
    }

    /** Root of the tree plus cached smallest (head) and largest (tail) nodes; all null when empty. */
    private Node root = null, head = null, tail = null;

    /**
     * Add an element after any equal elements. Always succeeds.
     *
     * @throws NullPointerException if item is null
     */
    @Override
    public boolean offer(T item) {
        addRight(item);
        return true;
    }

    /** Add an element before any equal elements. */
    public void addLeft(T item) {
        insert(item, false);
    }

    /** Add an element after any equal elements. */
    public void addRight(T item) {
        insert(item, true);
    }

    /**
     * Insert an element, keeping the cached head/tail up to date.
     *
     * @param item the element to insert
     * @param useRight whether to insert after (true) or before (false) equal elements
     * @throws NullPointerException if item is null
     */
    private void insert(T item, boolean useRight) {
        if (item == null) {
            throw new NullPointerException();
        }
        Node n = new Node(item);
        if (root == null) {
            root = head = tail = n;
        } else {
            // An equal element becomes the new head only when inserted to the left, and the
            // new tail only when inserted to the right.
            if (useRight) {
                if (head.value().compareTo(item) > 0) head = n;
                if (tail.value().compareTo(item) <= 0) tail = n;
            } else {
                if (head.value().compareTo(item) >= 0) head = n;
                if (tail.value().compareTo(item) < 0) tail = n;
            }
            root.insert(n, useRight);
        }
        assert root.verify();
    }

    /** @return the smallest (left-most) element, or null if the tree is empty */
    @Override
    public T peek() {
        if (head == null) {
            return null;
        }
        return head.value();
    }

    /** Remove and return the smallest (left-most) element, or null if the tree is empty. */
    @Override
    public T poll() {
        if (head == null) {
            return null;
        }
        Node n = head;
        removeNode(n);
        return n.value();
    }

    /**
     * Remove the right-most element equal (by {@code compareTo}) to {@code o}.
     *
     * @throws NullPointerException if o is null
     * @throws ClassCastException if o cannot be compared with the elements of this tree
     */
    @Override
    @SuppressWarnings("unchecked")
    public boolean remove(Object o) {
        if (o == null) {
            throw new NullPointerException();
        }
        // T is erased at runtime, so this cast cannot be checked here. An incompatible object
        // makes compareTo() throw ClassCastException during the search instead.
        return remove((T) o);
    }

    /** Remove the right-most element equal to item; see {@link #removeRight}. */
    public boolean remove(T item) {
        return removeRight(item);
    }

    private boolean remove(T item, boolean useRight) {
        if (item == null) {
            throw new NullPointerException();
        }
        if (root == null) {
            return false;
        }
        Node n = root.find(item, useRight);
        if (n == null) {
            return false;
        }
        assert n.value().compareTo(item) == 0;
        return removeNode(n);
    }

    /** Remove the left-most element equal to item. @return true if an element was removed */
    public boolean removeLeft(T item) {
        return remove(item, false);
    }

    /** Remove the right-most element equal to item. @return true if an element was removed */
    public boolean removeRight(T item) {
        return remove(item, true);
    }

    /**
     * Unlink a node from the tree, keeping counts, heights, balance, head and tail consistent.
     *
     * @param n the node to remove (must belong to this tree); null is ignored
     * @return true if a node was removed, false if n was null
     */
    private boolean removeNode(Node n) {
        if (n == null) {
            return false;
        }
        // Update the cached extremes while n's neighbours are still reachable through n.
        if (n == head) {
            head = head.next();
        }
        if (n == tail) {
            tail = tail.prev();
        }
        if (n.lChild == null || n.rChild == null) {
            spliceOut(n);
        } else {
            replaceWithPredecessor(n);
        }
        // Drop stale links so the removed node cannot be used to walk the tree.
        n.parent = n.lChild = n.rChild = null;
        assert root == null || root.verify();
        return true;
    }

    /** Remove a node with at most one child by moving that child (if any) into its place. */
    private void spliceOut(Node n) {
        Node child = (n.lChild != null) ? n.lChild : n.rChild;
        Node parent = n.parent;
        decrementAncestorCounts(n);
        if (child != null) {
            child.parent = parent;
        }
        replaceChild(parent, n, child);
        rebalanceFrom(parent);
    }

    /**
     * Remove a node with two children: splice out its in-order predecessor (which has no right
     * child), then move that node into n's position.
     */
    private void replaceWithPredecessor(Node n) {
        Node p = n.prev();
        assert p.rChild == null;
        spliceOut(p);
        // Rebalancing inside spliceOut may have rotated n, so read n's links only now.
        p.parent = n.parent;
        p.lChild = n.lChild;
        p.rChild = n.rChild;
        p.left = n.left;
        p.right = n.right;
        p.height = n.height;
        if (p.lChild != null) {
            p.lChild.parent = p;
        }
        if (p.rChild != null) {
            p.rChild.parent = p;
        }
        replaceChild(n.parent, n, p);
    }

    /** Decrement the subtree-size counters of every ancestor of n (n must still be linked). */
    private void decrementAncestorCounts(Node n) {
        for (Node tmp = n; tmp.parent != null; tmp = tmp.parent) {
            if (tmp.parent.lChild == tmp) {
                tmp.parent.left--;
            } else {
                assert tmp.parent.rChild == tmp;
                tmp.parent.right--;
            }
        }
    }

    /** Make replacement occupy old's child slot under parent, or become the root if parent is null. */
    private void replaceChild(Node parent, Node old, Node replacement) {
        if (parent == null) {
            root = replacement;
        } else if (parent.lChild == old) {
            parent.lChild = replacement;
        } else {
            assert parent.rChild == old;
            parent.rChild = replacement;
        }
    }

    /**
     * Walk from start toward the root, refreshing heights and rebalancing, until a node reports
     * that nothing above it needs updating.
     */
    private void rebalanceFrom(Node start) {
        for (Node p = start; p != null; p = p.parent) {
            if (!p.updateHeight()) {
                break;
            }
        }
    }

    private int rank(T item, boolean useRight) {
        if (item == null) {
            throw new NullPointerException();
        }
        if (root == null) {
            return -1;
        }
        Node n = root.find(item, useRight);
        if (n == null) {
            return -1;
        }
        assert n.value().compareTo(item) == 0;
        return n.rank();
    }

    /** @return the zero-based position of the left-most element equal to item, or -1 if absent */
    public int rankLeft(T item) {
        return rank(item, false);
    }

    /** @return the zero-based position of the right-most element equal to item, or -1 if absent */
    public int rankRight(T item) {
        return rank(item, true);
    }

    /**
     * Get the element at the given zero-based position in sorted order.
     *
     * @param index the position, {@code 0 <= index < size()}
     * @return the element at that position
     * @throws NoSuchElementException if index is out of range
     */
    public T select(int index) {
        if (index < 0 || index >= size()) {
            throw new NoSuchElementException("index " + index + " out of range for size " + size());
        }
        return root.select(index).value();
    }

    @Override
    public int size() {
        if (root == null) {
            return 0;
        }
        return root.left + 1 + root.right;
    }

    /** @return the height of the tree (0 when empty, 1 for a single element) */
    public int height() {
        if (root == null) {
            return 0;
        }
        return root.height;
    }

    /** Remove all elements in O(1) by dropping the whole tree. */
    @Override
    public void clear() {
        root = head = tail = null;
    }

    /** @return an iterator over the elements in ascending order */
    @Override
    public Iterator<T> iterator() {
        return new TreeIterator();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment