Skip to content

Instantly share code, notes, and snippets.

@TheDIM47
Created October 18, 2015 21:14
Show Gist options
  • Save TheDIM47/2f1c6a09817c7ad8c0ff to your computer and use it in GitHub Desktop.
Save TheDIM47/2f1c6a09817c7ad8c0ff to your computer and use it in GitHub Desktop.
Coursera Algorithms I
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.SET;
import edu.princeton.cs.algs4.StdDraw;
import java.util.ArrayDeque;
import java.util.Deque;
/**
 * A mutable set of points in the plane, stored in a 2d-tree.
 *
 * <p>Tree levels alternate between splitting on the x-coordinate (the root and every even
 * depth) and splitting on the y-coordinate (odd depths). A point whose splitting coordinate
 * is strictly less than the node's goes to the left/bottom subtree; otherwise it goes to the
 * right/top subtree. Each node also records the axis-aligned rectangle covered by its
 * subtree, which lets {@link #range} and {@link #nearest} skip subtrees that cannot
 * contribute to the answer.
 *
 * <p>The root's rectangle is fixed to the unit square [0, 1] x [0, 1].
 * NOTE(review): points outside the unit square are not guarded against; confirm that
 * callers never insert them.
 *
 * <p>All tree walks are iterative (explicit stacks instead of recursion). A degenerate,
 * list-shaped tree built from sorted input therefore cannot overflow the call stack.
 */
public class KdTree {

    /** A tree node: one point, its splitting orientation and the region its subtree covers. */
    private static class Node {
        private final Point2D p;        // the point
        private final RectHV rect;      // the axis-aligned rectangle corresponding to this node
        private final boolean vertical; // true: splits on x (vertical line); false: splits on y
        private Node lb;                // the left/bottom subtree
        private Node rt;                // the right/top subtree

        Node(Point2D p, RectHV rect, boolean vertical) {
            this.p = p;
            this.rect = rect;
            this.vertical = vertical;
        }
    }

    private int sz = 0;
    private Node root = null;

    /** Constructs an empty set of points. */
    public KdTree() {
    }

    /**
     * Is the set empty?
     *
     * @return {@code true} if no point has been inserted yet
     */
    public boolean isEmpty() {
        return root == null;
    }

    /**
     * Returns the number of distinct points in the set.
     *
     * @return the number of points; duplicates passed to {@link #insert} are not counted
     */
    public int size() {
        return sz;
    }

    /**
     * Adds the point to the set if it is not already in it.
     *
     * @param p the point to add
     * @throws NullPointerException if {@code p} is {@code null}
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new java.lang.NullPointerException();
        }
        if (root == null) {
            // The root covers the whole unit square and splits on x.
            root = new Node(p, new RectHV(0, 0, 1, 1), true);
            sz++;
            return;
        }
        Node node = root;
        while (true) {
            if (node.p.equals(p)) {
                return; // already present: a set stores each point once
            }
            final boolean less = isLess(node, p);
            final Node child = less ? node.lb : node.rt;
            if (child == null) {
                // Empty slot found: attach the point here; the child splits on the other axis.
                final Node created = new Node(p, getRectHV(node, less), !node.vertical);
                if (less) {
                    node.lb = created;
                } else {
                    node.rt = created;
                }
                sz++;
                return;
            }
            node = child;
        }
    }

    /**
     * Decides which subtree of {@code node} the point {@code p} belongs to.
     *
     * @return {@code true} for the left/bottom subtree, {@code false} for the right/top one
     */
    private static boolean isLess(Node node, Point2D p) {
        return node.vertical ? p.x() < node.p.x() : p.y() < node.p.y();
    }

    /**
     * Computes the rectangle of a new child of {@code parent}: the parent's rectangle cut by
     * the parent's splitting line, keeping the lower/left half when {@code less} is true and
     * the upper/right half otherwise.
     */
    private static RectHV getRectHV(Node parent, boolean less) {
        double xmin = parent.rect.xmin();
        double ymin = parent.rect.ymin();
        double xmax = parent.rect.xmax();
        double ymax = parent.rect.ymax();
        if (parent.vertical) {
            if (less) {
                xmax = parent.p.x();
            } else {
                xmin = parent.p.x();
            }
        } else {
            if (less) {
                ymax = parent.p.y();
            } else {
                ymin = parent.p.y();
            }
        }
        return new RectHV(xmin, ymin, xmax, ymax);
    }

    /**
     * Does the set contain point {@code p}?
     *
     * @param p the point to look up
     * @return {@code true} if an equal point has been inserted
     * @throws NullPointerException if {@code p} is {@code null}
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new java.lang.NullPointerException();
        }
        Node node = root;
        while (node != null) {
            if (node.p.equals(p)) {
                return true;
            }
            // Follow exactly the path insert() would take for p.
            node = isLess(node, p) ? node.lb : node.rt;
        }
        return false;
    }

    /**
     * Draws all points and splitting lines to standard draw. Red vertical segments mark
     * x-splitting nodes, blue horizontal segments mark y-splitting nodes, and each segment is
     * clipped to its node's rectangle.
     */
    public void draw() {
        StdDraw.show(0);
        final Deque<Node> stack = new ArrayDeque<>();
        pushIfPresent(stack, root);
        while (!stack.isEmpty()) {
            final Node n = stack.pop();
            drawNode(n);
            // Push rt before lb so lb is drawn first (pre-order, left/bottom before right/top).
            pushIfPresent(stack, n.rt);
            pushIfPresent(stack, n.lb);
        }
        StdDraw.show();
    }

    /** Draws one node: its splitting segment, then the point itself in black. */
    private static void drawNode(Node n) {
        StdDraw.setPenColor(n.vertical ? StdDraw.RED : StdDraw.BLUE);
        StdDraw.setPenRadius(.001);
        if (n.vertical) {
            StdDraw.line(n.p.x(), n.rect.ymin(), n.p.x(), n.rect.ymax());
        } else {
            StdDraw.line(n.rect.xmin(), n.p.y(), n.rect.xmax(), n.p.y());
        }
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(.01);
        StdDraw.point(n.p.x(), n.p.y());
    }

    /** Pushes {@code node} onto {@code stack} unless it is null (ArrayDeque rejects nulls). */
    private static void pushIfPresent(Deque<Node> stack, Node node) {
        if (node != null) {
            stack.push(node);
        }
    }

    /**
     * Returns all points of the set that lie inside the query rectangle.
     *
     * @param rect the query rectangle
     * @return the matching points
     * @throws NullPointerException if {@code rect} is {@code null}
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new java.lang.NullPointerException();
        }
        final SET<Point2D> set = new SET<>();
        final Deque<Node> stack = new ArrayDeque<>();
        pushIfPresent(stack, root);
        while (!stack.isEmpty()) {
            final Node node = stack.pop();
            if (!node.rect.intersects(rect)) {
                continue; // the query rectangle misses this node's whole region
            }
            if (rect.contains(node.p)) {
                set.add(node.p);
            }
            pushIfPresent(stack, node.rt);
            pushIfPresent(stack, node.lb);
        }
        return set;
    }

    /**
     * Returns a nearest neighbor of {@code p} in the set (by squared Euclidean distance).
     * Ties keep the first point found.
     *
     * @param p the query point
     * @return a closest point, or {@code null} if the set is empty
     * @throws NullPointerException if {@code p} is {@code null}
     */
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new java.lang.NullPointerException();
        }
        if (root == null) {
            return null;
        }
        Point2D best = root.p;
        double bestDist = Double.POSITIVE_INFINITY;
        final Deque<Node> stack = new ArrayDeque<>();
        stack.push(root);
        while (!stack.isEmpty()) {
            final Node node = stack.pop();
            // Prune: no point in this subtree is closer than its rectangle. The check runs at
            // pop time so it uses the best distance found so far.
            if (node.rect.distanceSquaredTo(p) >= bestDist) {
                continue;
            }
            final double d = p.distanceSquaredTo(node.p);
            if (d < bestDist) {
                best = node.p;
                bestDist = d;
            }
            // Explore the child on p's side of the splitting line first: it is the likelier to
            // hold a close point, which tightens bestDist before the far side is considered.
            final boolean less = isLess(node, p);
            pushIfPresent(stack, less ? node.rt : node.lb); // far side, explored later
            pushIfPresent(stack, less ? node.lb : node.rt); // near side, explored next
        }
        return best;
    }

    /** Unit testing of the methods (optional); intentionally empty. */
    public static void main(String[] args) {
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment