Skip to content

Instantly share code, notes, and snippets.

@geakstr
Last active December 26, 2015 21:09
Show Gist options
  • Save geakstr/7214038 to your computer and use it in GitHub Desktop.
Save geakstr/7214038 to your computer and use it in GitHub Desktop.
Set implementation backed by an AVL tree
/**
 * A sorted set of distinct elements backed by an AVL tree (a self-balancing binary search tree).
 *
 * <p>Elements are ordered by their natural ordering ({@link Comparable#compareTo}). Adding an
 * element that compares equal to one already present is a no-op. {@code null} elements are
 * rejected by {@link #add}.
 *
 * <p>Iteration order is ascending by default. It can be switched with {@link #set_desc_order()} /
 * {@link #set_asc_order()}; the setting is captured when an iterator is created. The class
 * performs no internal synchronization.
 *
 * @param <T> element type
 */
public class AVLSet<T extends Comparable<T>> implements Iterable<T> {
    /** Root of the tree; {@code null} when the set is empty. */
    private Node<T> root;
    /** Number of distinct elements currently stored. */
    private int size;
    /** When {@code true}, iterators created afterwards traverse in descending order. */
    private boolean desc_order;

    /** Creates an empty set that iterates in ascending order. */
    public AVLSet() {
        this(false);
    }

    /**
     * Creates an empty set.
     *
     * @param desc_order {@code true} for descending iteration, {@code false} for ascending
     */
    public AVLSet(boolean desc_order) {
        this.root = null;
        this.size = 0;
        this.desc_order = desc_order;
    }

    /**
     * A tree node. Declared static so that nodes do not carry a hidden reference to the
     * enclosing set.
     *
     * @param <K> key type
     */
    private static class Node<K> {
        private K key;
        /** Height of the subtree rooted at this node; a leaf has height 1. */
        private int height;
        private Node<K> left, right;

        public Node(K k) {
            this.key = k;
            this.left = null;
            this.right = null;
            this.height = 1;
        }
    }

    /**
     * In-order iterator that walks the tree with an explicit stack of pending ancestors.
     * Removal is not supported.
     *
     * @param <K> element type
     */
    public class Iterator<K> implements java.util.Iterator<K> {
        /** Ancestors whose keys have not been returned yet. */
        private final java.util.Deque<Node<K>> stack = new java.util.ArrayDeque<Node<K>>();
        /** Root of the next subtree still to be descended into, or {@code null}. */
        private Node<K> cur;
        /** {@code true} to visit right subtrees first (descending order). */
        private final boolean desc_order;

        public Iterator(Node<K> root, boolean desc_order) {
            this.cur = root;
            this.desc_order = desc_order;
        }

        /**
         * Returns the next element in iteration order.
         *
         * @return the next element
         * @throws java.util.NoSuchElementException if the iteration has no more elements
         */
        @Override
        public K next() {
            if (!hasNext()) {
                throw new java.util.NoSuchElementException();
            }
            // Descend to the first node (in iteration order) of the pending subtree,
            // remembering every ancestor on the way down.
            while (cur != null) {
                stack.push(cur);
                cur = desc_order ? cur.right : cur.left;
            }
            Node<K> node = stack.pop();
            // Elements that follow this node are in its opposite-side subtree.
            cur = desc_order ? node.left : node.right;
            return node.key;
        }

        /** @return {@code true} if {@link #next()} would return an element */
        @Override
        public boolean hasNext() {
            return (!stack.isEmpty() || cur != null);
        }

        /** Always throws: removal through the iterator is not supported. */
        @Override
        public void remove() {
            throw new UnsupportedOperationException("Not supported.");
        }
    }

    /** @return an iterator over the elements, in the order currently configured for this set */
    @Override
    public Iterator<T> iterator() {
        return new Iterator<T>(root, desc_order);
    }

    /** Makes iterators created from now on traverse in descending order. */
    public void set_desc_order() {
        desc_order = true;
    }

    /** Makes iterators created from now on traverse in ascending order. */
    public void set_asc_order() {
        desc_order = false;
    }

    /**
     * Adds {@code k} if no equal element is present; otherwise does nothing.
     *
     * @param k element to add
     * @throws NullPointerException if {@code k} is {@code null}
     */
    public void add(T k) {
        // A null key could be stored in an empty tree, but every later compareTo against it
        // would fail, so reject it up front.
        if (k == null) {
            throw new NullPointerException("AVLSet does not support null elements");
        }
        root = insert(root, k);
        // insert() decrements size when k was already present, cancelling this increment.
        size++;
    }

    /**
     * Removes the element equal to {@code k}, if present.
     *
     * @param k element to remove
     */
    public void remove(T k) {
        root = remove(root, k);
        // remove() increments size when k was not found, cancelling this decrement.
        size--;
    }

    /**
     * @param k element to look up
     * @return {@code true} if an element equal to {@code k} is present
     */
    public boolean contains(T k) {
        return contains(root, k);
    }

    /** @return number of distinct elements in the set */
    public int size() {
        return size;
    }

    /** @return height of the tree (0 when empty, 1 for a single element) */
    public int height() {
        return height(root);
    }

    /** Prints a rough sideways rendering of the tree to {@code System.out}, for debugging. */
    public void print() {
        print(root, 0);
    }

    /** @return the cached height of {@code p}, or 0 for an empty subtree */
    private int height(Node<T> p) {
        return p != null ? p.height : 0;
    }

    /**
     * Returns the elements in ascending order, formatted like {@code "[a, b, c]"}; {@code "[]"}
     * when empty. The configured iteration order is not consulted.
     */
    @Override
    public String toString() {
        StringBuilder ret = new StringBuilder();
        toString(root, ret);
        // Each element is followed by ", "; drop the trailing separator. An empty set appends
        // nothing, so the length check avoids a negative substring bound.
        if (ret.length() >= 2) {
            ret.setLength(ret.length() - 2);
        }
        return "[" + ret + "]";
    }

    /** Appends the keys of subtree {@code p} in ascending order, each followed by ", ". */
    private void toString(Node<T> p, StringBuilder ret) {
        if (p != null) {
            toString(p.left, ret);
            ret.append(p.key + ", ");
            toString(p.right, ret);
        }
    }

    /** Recursive helper for {@link #print()}; {@code l} is the depth used for indentation. */
    private void print(Node<T> p, int l) {
        if (p != null) {
            print(p.left, l + 1);
            for (int i = 0; i < l; ++i)
                System.out.print("\t ");
            System.out.print(p.key);
            print(p.right, l + 1);
        } else {
            System.out.println("\n\n");
        }
    }

    /**
     * Inserts {@code k} into subtree {@code p}, rebalancing on the way back up.
     * Decrements {@link #size} when {@code k} is already present, so that the caller's
     * increment leaves the count unchanged.
     *
     * @return the new root of the subtree
     */
    private Node<T> insert(Node<T> p, T k) {
        if (p == null)
            return new Node<T>(k);
        int cmp = k.compareTo(p.key);
        if (cmp == 0) {
            size--;
            return p;
        }
        if (cmp < 0)
            p.left = insert(p.left, k);
        else
            p.right = insert(p.right, k);
        return balance(p);
    }

    /** @return {@code true} if subtree {@code p} holds a key equal to {@code k} */
    private boolean contains(Node<T> p, T k) {
        if (p == null)
            return false;
        int cmp = k.compareTo(p.key);
        if (cmp == 0)
            return true;
        if (cmp < 0)
            return contains(p.left, k);
        else
            return contains(p.right, k);
    }

    /**
     * Removes {@code k} from subtree {@code p}, rebalancing on the way back up.
     * Increments {@link #size} when {@code k} is not found, so that the caller's decrement
     * leaves the count unchanged.
     *
     * @return the new root of the subtree
     */
    private Node<T> remove(Node<T> p, T k) {
        if (p == null) {
            size++;
            return null;
        }
        int cmp = k.compareTo(p.key);
        if (cmp < 0)
            p.left = remove(p.left, k);
        else if (cmp > 0)
            p.right = remove(p.right, k);
        else {
            Node<T> q = p.left;
            Node<T> r = p.right;
            // No right subtree: the (already balanced) left subtree takes p's place.
            if (r == null)
                return q;
            // Otherwise replace p with its in-order successor (minimum of the right subtree).
            Node<T> min = find_min(r);
            min.right = remove_min(r);
            min.left = q;
            return balance(min);
        }
        return balance(p);
    }

    /** @return balance factor of {@code p}: right height minus left height */
    private int bfactor(Node<T> p) {
        return height(p.right) - height(p.left);
    }

    /** Recomputes {@code p.height} from its children's cached heights. */
    private void fix_height(Node<T> p) {
        int hl = height(p.left), hr = height(p.right);
        p.height = Math.max(hl, hr) + 1;
    }

    /** Right rotation around {@code p}; returns the new subtree root (p's former left child). */
    private Node<T> rotate_right(Node<T> p) {
        Node<T> q = p.left;
        p.left = q.right;
        q.right = p;
        fix_height(p);
        fix_height(q);
        return q;
    }

    /** Left rotation around {@code q}; returns the new subtree root (q's former right child). */
    private Node<T> rotate_left(Node<T> q) {
        Node<T> p = q.right;
        q.right = p.left;
        p.left = q;
        fix_height(q);
        fix_height(p);
        return p;
    }

    /**
     * Restores the AVL invariant at {@code p} after one of its subtrees changed height by at
     * most one, using a single or double rotation.
     *
     * @return the new root of the subtree
     */
    private Node<T> balance(Node<T> p) {
        fix_height(p);
        if (bfactor(p) == 2) {
            // Right-left case: first rotate the right child so the heavy side is outermost.
            if (bfactor(p.right) < 0)
                p.right = rotate_right(p.right);
            return rotate_left(p);
        }
        if (bfactor(p) == -2) {
            // Left-right case: first rotate the left child so the heavy side is outermost.
            if (bfactor(p.left) > 0)
                p.left = rotate_left(p.left);
            return rotate_right(p);
        }
        return p;
    }

    /** @return the node with the smallest key in non-empty subtree {@code p} */
    private Node<T> find_min(Node<T> p) {
        return p.left != null ? find_min(p.left) : p;
    }

    /** Detaches the minimum node of non-empty subtree {@code p}; returns the new subtree root. */
    private Node<T> remove_min(Node<T> p) {
        if (p.left == null)
            return p.right;
        p.left = remove_min(p.left);
        return balance(p);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment