Last active
August 29, 2015 14:07
-
-
Save zeroliu/fade29aecc1d6cbd32ac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
IntervalST.java | |
Below is the syntax highlighted version of IntervalST.java from §7.3 Geometric Intersection. | |
/************************************************************************* | |
* Compilation: javac IntervalST.java | |
* Execution: java IntervalST | |
* Dependencies: Interval1D.java | |
* | |
* Interval search tree implemented using a randomized BST. | |
* | |
* Duplicate policy: if an interval is inserted that already | |
* exists, the new value overwrites the old one | |
* | |
*************************************************************************/ | |
import java.util.LinkedList; | |
public class IntervalST<Value> { | |
private Node root; // root of the BST | |
// BST helper node data type | |
private class Node { | |
Interval1D interval; // key | |
Value value; // associated data | |
Node left, right; // left and right subtrees | |
int N; // size of subtree rooted at this node | |
int max; // max endpoint in subtree rooted at this node | |
Node(Interval1D interval, Value value) { | |
this.interval = interval; | |
this.value = value; | |
this.N = 1; | |
this.max = interval.high; | |
} | |
} | |
/************************************************************************* | |
* BST search | |
*************************************************************************/ | |
public boolean contains(Interval1D interval) { | |
return (get(interval) != null); | |
} | |
// return value associated with the given key | |
// if no such value, return null | |
public Value get(Interval1D interval) { | |
return get(root, interval); | |
} | |
private Value get(Node x, Interval1D interval) { | |
if (x == null) return null; | |
int cmp = interval.compareTo(x.interval); | |
if (cmp < 0) return get(x.left, interval); | |
else if (cmp > 0) return get(x.right, interval); | |
else return x.value; | |
} | |
/************************************************************************* | |
* randomized insertion | |
*************************************************************************/ | |
public void put(Interval1D interval, Value value) { | |
if (contains(interval)) { StdOut.println("duplicate"); remove(interval); } | |
root = randomizedInsert(root, interval, value); | |
} | |
// make new node the root with uniform probability | |
private Node randomizedInsert(Node x, Interval1D interval, Value value) { | |
if (x == null) return new Node(interval, value); | |
if (Math.random() * size(x) < 1.0) return rootInsert(x, interval, value); | |
int cmp = interval.compareTo(x.interval); | |
if (cmp < 0) x.left = randomizedInsert(x.left, interval, value); | |
else x.right = randomizedInsert(x.right, interval, value); | |
fix(x); | |
return x; | |
} | |
private Node rootInsert(Node x, Interval1D interval, Value value) { | |
if (x == null) return new Node(interval, value); | |
int cmp = interval.compareTo(x.interval); | |
if (cmp < 0) { x.left = rootInsert(x.left, interval, value); x = rotR(x); } | |
else { x.right = rootInsert(x.right, interval, value); x = rotL(x); } | |
return x; | |
} | |
/************************************************************************* | |
* deletion | |
*************************************************************************/ | |
private Node joinLR(Node a, Node b) { | |
if (a == null) return b; | |
if (b == null) return a; | |
if (Math.random() * (size(a) + size(b)) < size(a)) { | |
a.right = joinLR(a.right, b); | |
fix(a); | |
return a; | |
} | |
else { | |
b.left = joinLR(a, b.left); | |
fix(b); | |
return b; | |
} | |
} | |
// remove and return value associated with given interval; | |
// if no such interval exists return null | |
public Value remove(Interval1D interval) { | |
Value value = get(interval); | |
root = remove(root, interval); | |
return value; | |
} | |
private Node remove(Node h, Interval1D interval) { | |
if (h == null) return null; | |
int cmp = interval.compareTo(h.interval); | |
if (cmp < 0) h.left = remove(h.left, interval); | |
else if (cmp > 0) h.right = remove(h.right, interval); | |
else h = joinLR(h.left, h.right); | |
fix(h); | |
return h; | |
} | |
/************************************************************************* | |
* Interval searching | |
*************************************************************************/ | |
// return an interval in data structure that intersects the given inteval; | |
// return null if no such interval exists | |
// running time is proportional to log N | |
public Interval1D search(Interval1D interval) { | |
return search(root, interval); | |
} | |
// look in subtree rooted at x | |
public Interval1D search(Node x, Interval1D interval) { | |
while (x != null) { | |
if (interval.intersects(x.interval)) return x.interval; | |
else if (x.left == null) x = x.right; | |
else if (x.left.max < interval.low) x = x.right; | |
else x = x.left; | |
} | |
return null; | |
} | |
// return *all* intervals in data structure that intersect the given interval | |
// running time is proportional to R log N, where R is the number of intersections | |
public Iterable<Interval1D> searchAll(Interval1D interval) { | |
LinkedList<Interval1D> list = new LinkedList<Interval1D>(); | |
searchAll(root, interval, list); | |
return list; | |
} | |
// look in subtree rooted at x | |
public boolean searchAll(Node x, Interval1D interval, LinkedList<Interval1D> list) { | |
boolean found1 = false; | |
boolean found2 = false; | |
boolean found3 = false; | |
if (x == null) | |
return false; | |
if (interval.intersects(x.interval)) { | |
list.add(x.interval); | |
found1 = true; | |
} | |
if (x.left != null && x.left.max >= interval.low) | |
found2 = searchAll(x.left, interval, list); | |
if (found2 || x.left == null || x.left.max < interval.low) | |
found3 = searchAll(x.right, interval, list); | |
return found1 || found2 || found3; | |
} | |
/************************************************************************* | |
* useful binary tree functions | |
*************************************************************************/ | |
// return number of nodes in subtree rooted at x | |
public int size() { return size(root); } | |
private int size(Node x) { | |
if (x == null) return 0; | |
else return x.N; | |
} | |
// height of tree (empty tree height = 0) | |
public int height() { return height(root); } | |
private int height(Node x) { | |
if (x == null) return 0; | |
return 1 + Math.max(height(x.left), height(x.right)); | |
} | |
/************************************************************************* | |
* helper BST functions | |
*************************************************************************/ | |
// fix auxilliar information (subtree count and max fields) | |
private void fix(Node x) { | |
if (x == null) return; | |
x.N = 1 + size(x.left) + size(x.right); | |
x.max = max3(x.interval.high, max(x.left), max(x.right)); | |
} | |
private int max(Node x) { | |
if (x == null) return Integer.MIN_VALUE; | |
return x.max; | |
} | |
// precondition: a is not null | |
private int max3(int a, int b, int c) { | |
return Math.max(a, Math.max(b, c)); | |
} | |
// right rotate | |
private Node rotR(Node h) { | |
Node x = h.left; | |
h.left = x.right; | |
x.right = h; | |
fix(h); | |
fix(x); | |
return x; | |
} | |
// left rotate | |
private Node rotL(Node h) { | |
Node x = h.right; | |
h.right = x.left; | |
x.left = h; | |
fix(h); | |
fix(x); | |
return x; | |
} | |
/************************************************************************* | |
* Debugging functions that test the integrity of the tree | |
*************************************************************************/ | |
// check integrity of subtree count fields | |
public boolean check() { return checkCount() && checkMax(); } | |
// check integrity of count fields | |
private boolean checkCount() { return checkCount(root); } | |
private boolean checkCount(Node x) { | |
if (x == null) return true; | |
return checkCount(x.left) && checkCount(x.right) && (x.N == 1 + size(x.left) + size(x.right)); | |
} | |
private boolean checkMax() { return checkMax(root); } | |
private boolean checkMax(Node x) { | |
if (x == null) return true; | |
return x.max == max3(x.interval.high, max(x.left), max(x.right)); | |
} | |
/************************************************************************* | |
* test client | |
*************************************************************************/ | |
public static void main(String[] args) { | |
int N = Integer.parseInt(args[0]); | |
// generate N random intervals and insert into data structure | |
IntervalST<String> st = new IntervalST<String>(); | |
for (int i = 0; i < N; i++) { | |
int low = (int) (Math.random() * 1000); | |
int high = (int) (Math.random() * 50) + low; | |
Interval1D interval = new Interval1D(low, high); | |
StdOut.println(interval); | |
st.put(interval, "" + i); | |
} | |
// print out tree statistics | |
System.out.println("height: " + st.height()); | |
System.out.println("size: " + st.size()); | |
System.out.println("integrity check: " + st.check()); | |
System.out.println(); | |
// generate random intervals and check for overlap | |
for (int i = 0; i < N; i++) { | |
int low = (int) (Math.random() * 100); | |
int high = (int) (Math.random() * 10) + low; | |
Interval1D interval = new Interval1D(low, high); | |
StdOut.println(interval + ": " + st.search(interval)); | |
StdOut.print(interval + ": "); | |
for (Interval1D x : st.searchAll(interval)) | |
StdOut.print(x + " "); | |
StdOut.println(); | |
StdOut.println(); | |
} | |
} | |
} | |
Copyright © 2002–2010, Robert Sedgewick and Kevin Wayne. | |
Last updated: Wed Jun 22 22:55:45 EDT 2011. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment