Skip to content

Instantly share code, notes, and snippets.

@tanzaku

tanzaku/AVL.java Secret

Created November 29, 2016 15:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tanzaku/70d7d0199a9f871182546424efa69fb0 to your computer and use it in GitHub Desktop.
Save tanzaku/70d7d0199a9f871182546424efa69fb0 to your computer and use it in GitHub Desktop.
/**
 * Node of a size-augmented, height-balanced binary tree used as an implicit-key sequence:
 * elements are addressed by their 0-based position in in-order traversal, not by value.
 *
 * <p>Balance invariant: for every node, the heights of its two subtrees differ by at most 2
 * (a relaxed AVL condition, enforced by {@link #balance}). The empty tree is {@code null}.
 *
 * <p>The static operations are destructive: they relink the nodes of the trees passed in, so
 * callers must continue with the returned root and must not reuse the argument trees.
 *
 * <p>NOTE(review): {@code Pair} is declared elsewhere in the project; this class assumes it has a
 * constructor {@code Pair(first, second)} and mutable public fields {@code first}/{@code second}.
 */
class Node {
    /** Number of nodes in the subtree rooted here, including this node. */
    int size;
    /** Height of the subtree rooted here; a single node has height 1, the empty tree 0. */
    int height;
    /** Left and right children; {@code null} denotes an empty subtree. */
    Node l, r;
    /** Value stored at this position of the sequence. */
    int value;

    /** Creates a single-node tree holding {@code value}. */
    public Node(int value) {
        this(value, null, null);
    }

    private Node(int value, Node l, Node r) {
        this.value = value;
        this.l = l;
        this.r = r;
        update();
    }

    /** Returns the number of nodes in {@code node}'s subtree, or 0 for the empty tree. */
    public static int size(Node node) { return node == null ? 0 : node.size; }

    /** Returns the height of {@code node}'s subtree, or 0 for the empty tree. */
    public static int height(Node node) { return node == null ? 0 : node.height; }

    /** Recomputes {@code size} and {@code height} from the (already up-to-date) children. */
    private void update() {
        this.size = 1;
        this.size += size(this.l) + size(this.r);
        this.height = Math.max(height(this.l), height(this.r)) + 1;
    }

    /** Left rotation: the right child becomes the subtree root, which is returned. */
    private Node rotL() {
        assert(this.r != null);
        final Node r = this.r;
        this.r = r.l;
        r.l = this;
        // Order matters: this node is now r's child, so refresh it before r.
        this.update();
        r.update();
        return r;
    }

    /** Right rotation: the left child becomes the subtree root, which is returned. */
    private Node rotR() {
        assert(this.l != null);
        final Node l = this.l;
        this.l = l.r;
        l.r = this;
        // Order matters: this node is now l's child, so refresh it before l.
        this.update();
        l.update();
        return l;
    }

    /**
     * Restores the balance invariant at {@code n} and returns the (possibly new) subtree root.
     *
     * <p>Expects both children of {@code n} to be balanced and their heights to differ by at most
     * 3, which is what the single-step insert/remove/join paths in this class produce. When no
     * rotation is needed, this only refreshes {@code n}'s {@code size}/{@code height}.
     */
    private static Node balance(Node n) {
        int hl = height(n.l);
        int hr = height(n.r);
        if (hl > hr + 2) {
            final int hll = height(n.l.l);
            final int hlr = height(n.l.r);
            if (hll < hlr) {
                // Left-right case: lift the inner grandchild first (double rotation).
                n.l = n.l.rotL();
                n = n.rotR();
            } else {
                // Left-left case: a single right rotation suffices.
                n = n.rotR();
            }
        } else if (hr > hl + 2) {
            final int hrl = height(n.r.l);
            final int hrr = height(n.r.r);
            if (hrl > hrr) {
                // Right-left case: lift the inner grandchild first (double rotation).
                n.r = n.r.rotR();
                n = n.rotL();
            } else {
                // Right-right case: a single left rotation suffices.
                n = n.rotL();
            }
        } else {
            n.update();
        }
        return n;
    }

    /**
     * Concatenates {@code l}, the single node {@code v}, and {@code r} (in that order).
     *
     * <p>Runs in O(|h(l) - h(r)| + 1). The result height is at most max(h(l), h(r)) + 1.
     *
     * @param l left sequence (may be {@code null})
     * @param v a detached single node (no children); it is relinked into the result
     * @param r right sequence (may be {@code null})
     * @return root of the joined tree
     */
    private static Node join(Node l, Node v, Node r) {
        if (l == null) { return addFirst(r, v); } // O(h(r)) = O(|h(l) - h(r)|)
        if (r == null) { return addLast(l, v); }  // O(h(l)) = O(|h(l) - h(r)|)
        int hl = height(l);
        int hr = height(r);
        // Invariant maintained here: h(join(l, v, r)) <= max(h(l), h(r)) + 1.
        if (hl > hr + 2) {
            // Walk down l's right spine; the recursion stops once the height gap is <= 2,
            // so this costs O(h(l) - h(r)).
            l.r = join(l.r, v, r);
            return balance(l); // O(1)
        }
        if (hr > hl + 2) {
            // Walk down r's left spine; the recursion stops once the height gap is <= 2,
            // so this costs O(h(r) - h(l)).
            r.l = join(l, v, r.l);
            return balance(r); // O(1)
        }
        // Heights are within the allowed slack of 2, so v can be the root directly. O(1).
        v.l = l;
        v.r = r;
        v.update();
        return v;
    }

    /**
     * Concatenates the sequences {@code l} and {@code r}; either may be {@code null}.
     * Runs in O(h(l) + h(r)).
     */
    public static Node merge(Node l, Node r) {
        if (l == null) return r;
        if (r == null) return l;
        final Node v = peekFirst(r);      // O(h(r))
        final Node rest = removeFirst(r); // O(h(r))
        // v may still reference its old right subtree (now re-attached inside rest) and carries
        // stale size/height; detach it so join receives a clean single node.
        v.l = v.r = null;
        v.update();
        return join(l, v, rest);          // O(|h(l) - h(rest)|)
    }

    /**
     * Splits {@code n} into the positions [0, k) and [k, size).
     *
     * <p>Out-of-range {@code k} is clamped: {@code k <= 0} yields (null, n) and
     * {@code k >= size(n)} yields (n, null).
     *
     * <p>Complexity O(h): split recurses O(h) times, and the costs of the joins performed on the
     * way back up telescope as O(h1) + (O(h2) - O(h1)) + (O(h3) - O(h2)) + ... = O(h) for the
     * increasing subtree heights h1 < h2 < h3 < ...
     */
    public static Pair<Node, Node> split(Node n, int k) {
        if (n == null) return new Pair<>(null, null);
        Node l = n.l;
        Node r = n.r;
        // Detach n so it can be reused as the middle node of a join.
        n.l = n.r = null;
        n.update();
        int sizeL = size(l);
        if (k < sizeL) {
            Pair<Node, Node> p = split(l, k);
            p.second = join(p.second, n, r);
            return p;
        } else if (k > sizeL) {
            Pair<Node, Node> p = split(r, k - sizeL - 1);
            p.first = join(l, n, p.first);
            return p;
        } else {
            return new Pair<>(l, addFirst(r, n));
        }
    }

    /**
     * Inserts {@code v} so that it ends up at position {@code k}; returns the new root.
     * Out-of-range {@code k} is clamped (negative inserts at the front, past the end appends).
     */
    public static Node insert(Node t, int k, int v) {
        Pair<Node, Node> p = split(t, k);
        return join(p.first, new Node(v), p.second);
    }

    /**
     * Removes the element at position {@code k}; returns the new root.
     *
     * @throws IndexOutOfBoundsException if {@code k < 0 || k >= size(t)}; the tree is left
     *     untouched in that case (checked before any destructive split)
     */
    public static Node erase(Node t, int k) {
        final int n = size(t);
        if (k < 0 || k >= n) {
            // Without this check a negative k silently erased element 0, and k >= size hit an
            // NPE in removeFirst after split had already relinked the tree.
            throw new IndexOutOfBoundsException("Index: " + k + ", Size: " + n);
        }
        Pair<Node, Node> p = split(t, k);
        return merge(p.first, removeFirst(p.second));
    }

    // ---- Small helpers ----

    /** Inserts the detached leaf {@code v} as the first element of {@code t}; returns the new root. */
    private static Node addFirst(Node t, Node v) {
        if (t == null) return v;
        t.l = addFirst(t.l, v);
        return balance(t);
    }

    /** Inserts the detached leaf {@code v} as the last element of {@code t}; returns the new root. */
    private static Node addLast(Node t, Node v) {
        if (t == null) return v;
        t.r = addLast(t.r, v);
        return balance(t);
    }

    /** Returns the leftmost (first) node of the non-null tree {@code t} without modifying it. */
    private static Node peekFirst(Node t) {
        while (t.l != null) t = t.l;
        return t;
    }

    /**
     * Unlinks the first node of the non-null tree {@code t} and returns the new root. The removed
     * node itself is not modified.
     */
    private static Node removeFirst(Node t) {
        if (t.l == null) return t.r;
        t.l = removeFirst(t.l);
        return balance(t);
    }

    /** Returns the rightmost (last) node of the non-null tree {@code t} without modifying it. */
    private static Node peekLast(Node t) {
        while (t.r != null) t = t.r;
        return t;
    }

    /**
     * Unlinks the last node of the non-null tree {@code t} and returns the new root. The removed
     * node itself is not modified.
     */
    private static Node removeLast(Node t) {
        if (t.r == null) return t.l;
        t.r = removeLast(t.r);
        return balance(t);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment