Skip to content

Instantly share code, notes, and snippets.

@qzdc00
Created January 7, 2015 20:26
Show Gist options
  • Save qzdc00/cef65b01cf2f555639b3 to your computer and use it in GitHub Desktop.
Save qzdc00/cef65b01cf2f555639b3 to your computer and use it in GitHub Desktop.
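Homework for algs4 (Algorithms, 4th Edition), week 5: a 2d-tree (KdTree), a brute-force PointSET, and the Point2D and RectHV support classes.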
/**
 * A mutable set of points in the plane stored in a 2d-tree (a kd-tree with k = 2).
 * The root compares points by x-coordinate and the levels below alternate between
 * y and x. A point whose compared coordinate is less than or equal to a node's goes
 * into that node's left subtree; a strictly greater one goes right. insert, contains,
 * range and nearest all rely on this single tie rule.
 */
public class KdTree {
    private static final boolean X = true;   // node compares by x (vertical splitting line)
    private static final boolean Y = false;  // node compares by y (horizontal splitting line)

    private Node root;

    /** A tree node: one point, the coordinate it splits on, its children and subtree size. */
    private static class Node {
        private final Point2D p;
        private final boolean coord;  // X or Y
        private int N;                // number of points in the subtree rooted at this node
        private Node left, right;

        Node(Point2D p, boolean coord) {
            this.p = p;
            this.coord = coord;
            this.N = 1;
        }
    }

    /** Constructs an empty set of points. */
    public KdTree() {
        root = null;
    }

    /** Returns true if the set contains no points. */
    public boolean isEmpty() {
        return root == null;
    }

    /** Returns the number of distinct points in the set. */
    public int size() {
        return size(root);
    }

    // Number of points in the subtree rooted at n (0 for an empty subtree).
    private static int size(Node n) {
        return (n == null) ? 0 : n.N;
    }

    /**
     * Adds p to the set unless a point with the same x and y is already present.
     * @param p the point to add
     * @throws NullPointerException if p is null
     */
    public void insert(Point2D p) {
        if (p == null) throw new NullPointerException("argument to insert() is null");
        root = insert(root, p, X);
    }

    // Inserts p into the subtree rooted at n (a new node there would compare on coord)
    // and returns the root of the updated subtree.
    private static Node insert(Node n, Point2D p, boolean coord) {
        if (n == null) return new Node(p, coord);
        // Already present: a set keeps one copy, and size() must not count duplicates.
        if (samePoint(n.p, p)) return n;
        if (goesLeft(n, p)) {
            n.left = insert(n.left, p, !n.coord);
        } else {
            n.right = insert(n.right, p, !n.coord);
        }
        n.N = size(n.left) + 1 + size(n.right);
        return n;
    }

    /**
     * Does the set contain a point with the same coordinates as p?
     * @param p the point to look for
     * @throws NullPointerException if p is null
     */
    public boolean contains(Point2D p) {
        if (p == null) throw new NullPointerException("argument to contains() is null");
        Node cur = root;
        while (cur != null) {
            if (samePoint(cur.p, p)) return true;
            // Follow the same path insert() would take; equal keys were sent left.
            cur = goesLeft(cur, p) ? cur.left : cur.right;
        }
        return false;
    }

    /** Draws every point in the set with standard draw (points only, no splitting lines). */
    public void draw() {
        draw(root);
    }

    private static void draw(Node cur) {
        if (cur == null) return;
        cur.p.draw();
        draw(cur.right);
        draw(cur.left);
    }

    /**
     * Returns all points inside rect (boundary included).
     * @param rect the query rectangle
     * @throws NullPointerException if rect is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) throw new NullPointerException("argument to range() is null");
        Queue<Point2D> q = new Queue<Point2D>();
        range(rect, root, q);
        return q;
    }

    // Enqueues (pre-order) every point of the subtree rooted at cur that lies in rect,
    // skipping subtrees whose half-plane cannot intersect rect.
    private static void range(RectHV rect, Node cur, Queue<Point2D> q) {
        if (cur == null) return;
        double split = key(cur, cur.p);
        double lo = (cur.coord == X) ? rect.xmin() : rect.ymin();
        double hi = (cur.coord == X) ? rect.xmax() : rect.ymax();
        if (rect.contains(cur.p)) q.enqueue(cur.p);
        if (lo <= split) range(rect, cur.left, q);   // left keys are <= split
        if (hi > split) range(rect, cur.right, q);   // right keys are  > split
    }

    /**
     * Returns a point in the set closest to p, or null if the set is empty.
     * @param p the query point
     * @throws NullPointerException if p is null
     */
    public Point2D nearest(Point2D p) {
        if (p == null) throw new NullPointerException("argument to nearest() is null");
        if (isEmpty()) return null;
        return nearest(root, p, root.p);
    }

    // Returns the closer of best and the nearest point to p within the subtree rooted at n.
    private static Point2D nearest(Node n, Point2D p, Point2D best) {
        if (n == null) return best;
        if (p.distanceSquaredTo(n.p) < p.distanceSquaredTo(best)) best = n.p;
        double diff = key(n, p) - key(n, n.p);
        Node near = (diff < 0) ? n.left : n.right;
        Node far = (diff < 0) ? n.right : n.left;
        best = nearest(near, p, best);
        // Every point on the far side is at least |diff| away from p, so it can only
        // improve on best when best is farther than that (compared squared, no sqrt).
        if (p.distanceSquaredTo(best) > diff * diff) {
            best = nearest(far, p, best);
        }
        return best;
    }

    // The coordinate of q that node n compares on.
    private static double key(Node n, Point2D q) {
        return (n.coord == Y) ? q.y() : q.x();
    }

    // True if q belongs in n's left subtree (ties go left).
    private static boolean goesLeft(Node n, Point2D q) {
        return key(n, q) <= key(n, n.p);
    }

    // Coordinate-wise point equality.
    private static boolean samePoint(Point2D a, Point2D b) {
        return a.x() == b.x() && a.y() == b.y();
    }

    /** Smoke test: random inserts, membership check, and duplicate insert. */
    public static void main(String[] args) {
        KdTree kdt = new KdTree();
        for (int i = 0; i < 10000; i++) {
            kdt.insert(new Point2D(Math.random(), Math.random()));
        }
        Point2D contain = new Point2D(0.45, 0.6);
        kdt.insert(contain);
        System.out.println(kdt.contains(contain));
        int before = kdt.size();
        kdt.insert(contain);  // inserting an existing point must not change the size
        System.out.println(kdt.size() == before);
    }
}
/*************************************************************************
* Compilation: javac Point2D.java
* Execution: java Point2D x0 y0 N
* Dependencies: StdDraw.java StdRandom.java
*
* Immutable point data type for points in the plane.
*
*************************************************************************/
import java.util.Arrays;
import java.util.Comparator;
/**
 * The <tt>Point2D</tt> class is an immutable data type to encapsulate a
 * two-dimensional point with real-value coordinates.
 * <p>
 * Note: in order to deal with the different behavior of double and
 * Double with respect to -0.0 and +0.0, the Point2D constructor converts
 * any coordinates that are -0.0 to +0.0.
 *
 * For additional documentation, see <a href="/algs4/12oop">Section 1.2</a> of
 * <i>Algorithms, 4th Edition</i> by Robert Sedgewick and Kevin Wayne.
 *
 * @author Robert Sedgewick
 * @author Kevin Wayne
 */
public class Point2D implements Comparable<Point2D> {

    /**
     * Compares two points by x-coordinate.
     */
    public static final Comparator<Point2D> X_ORDER = new XOrder();

    /**
     * Compares two points by y-coordinate.
     */
    public static final Comparator<Point2D> Y_ORDER = new YOrder();

    /**
     * Compares two points by polar radius.
     */
    public static final Comparator<Point2D> R_ORDER = new ROrder();

    /**
     * Compares two points by polar angle (between 0 and 2pi) with respect to this point.
     */
    public final Comparator<Point2D> POLAR_ORDER = new PolarOrder();

    /**
     * Compares two points by atan2() angle (between -pi and pi) with respect to this point.
     */
    public final Comparator<Point2D> ATAN2_ORDER = new Atan2Order();

    /**
     * Compares two points by distance to this point.
     */
    public final Comparator<Point2D> DISTANCE_TO_ORDER = new DistanceToOrder();

    private final double x;    // x coordinate
    private final double y;    // y coordinate

    /**
     * Initializes a new point (x, y).
     * @param x the x-coordinate
     * @param y the y-coordinate
     * @throws IllegalArgumentException if either <tt>x</tt> or <tt>y</tt>
     *    is <tt>Double.NaN</tt>, <tt>Double.POSITIVE_INFINITY</tt> or
     *    <tt>Double.NEGATIVE_INFINITY</tt>
     */
    public Point2D(double x, double y) {
        if (Double.isInfinite(x) || Double.isInfinite(y))
            throw new IllegalArgumentException("Coordinates must be finite");
        if (Double.isNaN(x) || Double.isNaN(y))
            throw new IllegalArgumentException("Coordinates cannot be NaN");
        // -0.0 == 0.0 is true, so these assignments replace -0.0 with +0.0 and keep
        // equals() (which uses ==) consistent with hashCode() (which uses Double).
        if (x == 0.0) x = 0.0;
        if (y == 0.0) y = 0.0;
        this.x = x;
        this.y = y;
    }

    /**
     * Returns the x-coordinate.
     * @return the x-coordinate
     */
    public double x() {
        return x;
    }

    /**
     * Returns the y-coordinate.
     * @return the y-coordinate
     */
    public double y() {
        return y;
    }

    /**
     * Returns the polar radius of this point.
     * @return the polar radius of this point in polar coordinates: sqrt(x*x + y*y)
     */
    public double r() {
        return Math.sqrt(x*x + y*y);
    }

    /**
     * Returns the angle of this point in polar coordinates.
     * @return the angle (in radians) of this point in polar coordinates, as computed by
     *    Math.atan2(y, x) (between -pi and pi)
     */
    public double theta() {
        return Math.atan2(y, x);
    }

    /**
     * Returns the angle between this point and that point.
     * @return the angle in radians (between -pi and pi) between this point and that point (0 if equal)
     */
    private double angleTo(Point2D that) {
        double dx = that.x - this.x;
        double dy = that.y - this.y;
        return Math.atan2(dy, dx);
    }

    /**
     * Is a-&gt;b-&gt;c a counterclockwise turn?
     * @param a first point
     * @param b second point
     * @param c third point
     * @return { -1, 0, +1 } if a-&gt;b-&gt;c is a { clockwise, collinear, counterclockwise } turn.
     */
    public static int ccw(Point2D a, Point2D b, Point2D c) {
        double area2 = (b.x-a.x)*(c.y-a.y) - (b.y-a.y)*(c.x-a.x);
        if      (area2 < 0) return -1;
        else if (area2 > 0) return +1;
        else                return  0;
    }

    /**
     * Returns twice the signed area of the triangle a-b-c.
     * @param a first point
     * @param b second point
     * @param c third point
     * @return twice the signed area of the triangle a-b-c
     */
    public static double area2(Point2D a, Point2D b, Point2D c) {
        return (b.x-a.x)*(c.y-a.y) - (b.y-a.y)*(c.x-a.x);
    }

    /**
     * Returns the Euclidean distance between this point and that point.
     * @param that the other point
     * @return the Euclidean distance between this point and that point
     */
    public double distanceTo(Point2D that) {
        double dx = this.x - that.x;
        double dy = this.y - that.y;
        return Math.sqrt(dx*dx + dy*dy);
    }

    /**
     * Returns the square of the Euclidean distance between this point and that point.
     * @param that the other point
     * @return the square of the Euclidean distance between this point and that point
     */
    public double distanceSquaredTo(Point2D that) {
        double dx = this.x - that.x;
        double dy = this.y - that.y;
        return dx*dx + dy*dy;
    }

    /**
     * Compares this point to that point by y-coordinate, breaking ties by x-coordinate.
     * @param that the other point
     * @return { a negative integer, zero, a positive integer } if this point is
     *    { less than, equal to, greater than } that point
     */
    @Override
    public int compareTo(Point2D that) {
        if (this.y < that.y) return -1;
        if (this.y > that.y) return +1;
        if (this.x < that.x) return -1;
        if (this.x > that.x) return +1;
        return 0;
    }

    // compare points according to their x-coordinate
    private static class XOrder implements Comparator<Point2D> {
        @Override
        public int compare(Point2D p, Point2D q) {
            if (p.x < q.x) return -1;
            if (p.x > q.x) return +1;
            return 0;
        }
    }

    // compare points according to their y-coordinate
    private static class YOrder implements Comparator<Point2D> {
        @Override
        public int compare(Point2D p, Point2D q) {
            if (p.y < q.y) return -1;
            if (p.y > q.y) return +1;
            return 0;
        }
    }

    // compare points according to their polar radius
    private static class ROrder implements Comparator<Point2D> {
        @Override
        public int compare(Point2D p, Point2D q) {
            double delta = (p.x*p.x + p.y*p.y) - (q.x*q.x + q.y*q.y);
            if (delta < 0) return -1;
            if (delta > 0) return +1;
            return 0;
        }
    }

    // compare other points relative to the atan2 angle (between -pi and pi) they make with this Point
    private class Atan2Order implements Comparator<Point2D> {
        @Override
        public int compare(Point2D q1, Point2D q2) {
            double angle1 = angleTo(q1);
            double angle2 = angleTo(q2);
            if      (angle1 < angle2) return -1;
            else if (angle1 > angle2) return +1;
            else                      return  0;
        }
    }

    // compare other points relative to the polar angle (between 0 and 2pi) they make with this Point
    private class PolarOrder implements Comparator<Point2D> {
        @Override
        public int compare(Point2D q1, Point2D q2) {
            double dx1 = q1.x - x;
            double dy1 = q1.y - y;
            double dx2 = q2.x - x;
            double dy2 = q2.y - y;

            if      (dy1 >= 0 && dy2 < 0) return -1;    // q1 above; q2 below
            else if (dy2 >= 0 && dy1 < 0) return +1;    // q1 below; q2 above
            else if (dy1 == 0 && dy2 == 0) {            // 3-collinear and horizontal
                if      (dx1 >= 0 && dx2 < 0) return -1;
                else if (dx2 >= 0 && dx1 < 0) return +1;
                else                          return  0;
            }
            else return -ccw(Point2D.this, q1, q2);     // both above or below

            // Note: ccw() recomputes dx1, dy1, dx2, and dy2
        }
    }

    // compare points according to their distance to this point
    private class DistanceToOrder implements Comparator<Point2D> {
        @Override
        public int compare(Point2D p, Point2D q) {
            double dist1 = distanceSquaredTo(p);
            double dist2 = distanceSquaredTo(q);
            if      (dist1 < dist2) return -1;
            else if (dist1 > dist2) return +1;
            else                    return  0;
        }
    }

    /**
     * Does this point equal the other point?
     * @param other the other point
     * @return true if this point equals the other point; false otherwise
     */
    @Override
    public boolean equals(Object other) {
        if (other == this) return true;
        if (other == null) return false;
        if (other.getClass() != this.getClass()) return false;
        Point2D that = (Point2D) other;
        return this.x == that.x && this.y == that.y;
    }

    /**
     * Return a string representation of this point.
     * @return a string representation of this point in the format (x, y)
     */
    @Override
    public String toString() {
        return "(" + x + ", " + y + ")";
    }

    /**
     * Returns an integer hash code for this point.
     * @return an integer hash code for this point
     */
    @Override
    public int hashCode() {
        int hashX = ((Double) x).hashCode();
        int hashY = ((Double) y).hashCode();
        return 31*hashX + hashY;
    }

    /**
     * Plot this point using standard draw.
     */
    public void draw() {
        StdDraw.point(x, y);
    }

    /**
     * Plot a line from this point to that point using standard draw.
     * @param that the other point
     */
    public void drawTo(Point2D that) {
        StdDraw.line(this.x, this.y, that.x, that.y);
    }

    /**
     * Unit tests the point data type: draws N random points and then the segments
     * from (x0, y0) to each of them in polar order. Usage: java Point2D x0 y0 N
     */
    public static void main(String[] args) {
        int x0 = Integer.parseInt(args[0]);
        int y0 = Integer.parseInt(args[1]);
        int N = Integer.parseInt(args[2]);

        StdDraw.setCanvasSize(800, 800);
        StdDraw.setXscale(0, 100);
        StdDraw.setYscale(0, 100);
        StdDraw.setPenRadius(.005);
        Point2D[] points = new Point2D[N];
        for (int i = 0; i < N; i++) {
            int x = StdRandom.uniform(100);
            int y = StdRandom.uniform(100);
            points[i] = new Point2D(x, y);
            points[i].draw();
        }

        // draw p = (x0, y0) in red
        Point2D p = new Point2D(x0, y0);
        StdDraw.setPenColor(StdDraw.RED);
        StdDraw.setPenRadius(.02);
        p.draw();

        // draw line segments from p to each point, one at a time, in polar order
        StdDraw.setPenRadius();
        StdDraw.setPenColor(StdDraw.BLUE);
        Arrays.sort(points, p.POLAR_ORDER);
        for (int i = 0; i < N; i++) {
            p.drawTo(points[i]);
            StdDraw.show(100);
        }
    }
}
/**
 * Brute-force set of points in the plane, used as a reference for KdTree.
 * Points are the keys of a red-black BST (ordered by Point2D.compareTo); the stored
 * value is a placeholder and is never read back except to test for presence.
 */
public class PointSET {
    private final RedBlackBST<Point2D, Double> points;

    /** Constructs an empty set of points. */
    public PointSET() {
        points = new RedBlackBST<Point2D, Double>();
    }

    /** Returns true if the set contains no points. */
    public boolean isEmpty() {
        return points.size() == 0;
    }

    /** Returns the number of points in the set. */
    public int size() {
        return points.size();
    }

    /**
     * Adds p to the set (re-putting an existing key does not add a second copy).
     * @throws NullPointerException if p is null
     */
    public void insert(Point2D p) {
        if (p == null) throw new NullPointerException("argument to insert() is null");
        points.put(p, p.x());  // value is only a non-null placeholder
    }

    /**
     * Does the set contain point p?
     * @throws NullPointerException if p is null
     */
    public boolean contains(Point2D p) {
        if (p == null) throw new NullPointerException("argument to contains() is null");
        return points.get(p) != null;
    }

    /** Draws every point in the set with standard draw. */
    public void draw() {
        if (isEmpty()) return;
        // keys() is declared as Iterable; iterate it directly instead of casting to Queue.
        for (Point2D p : points.keys()) {
            p.draw();
        }
    }

    /**
     * Returns all points inside rect (boundary included); never null.
     * @throws NullPointerException if rect is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) throw new NullPointerException("argument to range() is null");
        Queue<Point2D> inside = new Queue<Point2D>();
        if (isEmpty()) return inside;
        for (Point2D p : points.keys()) {
            if (rect.contains(p)) {
                inside.enqueue(p);
            }
        }
        return inside;
    }

    /**
     * Returns a point in the set closest to p (p itself if it is in the set), or null
     * if the set is empty. Ties go to the first point in key order.
     * @throws NullPointerException if p is null
     */
    public Point2D nearest(Point2D p) {
        if (p == null) throw new NullPointerException("argument to nearest() is null");
        if (isEmpty()) return null;
        Point2D best = null;
        double bestDist = Double.POSITIVE_INFINITY;
        for (Point2D candidate : points.keys()) {
            // Squared distance preserves the ordering and avoids a sqrt per point.
            double d = p.distanceSquaredTo(candidate);
            if (d < bestDist) {
                best = candidate;
                bestDist = d;
            }
        }
        return best;
    }

    /** Smoke test: draw a few points and a rectangle, then query range and nearest. */
    public static void main(String[] args) {
        PointSET pset = new PointSET();
        Point2D p1 = new Point2D(0.5, 0.5);
        pset.insert(p1);
        StdDraw.clear();
        pset.draw();
        Point2D p2 = new Point2D(0.7, 0.6);
        pset.insert(p2);
        StdDraw.clear();
        pset.draw();
        RectHV rect = new RectHV(0.4, 0.4, 0.6, 0.6);
        Point2D p3 = new Point2D(0.8, 0.6);
        pset.insert(p3);
        StdDraw.clear();
        pset.draw();
        rect.draw();
        for (Point2D point : pset.range(rect)) {
            System.out.println("contains p? : " + point);
        }
        System.out.println("p1's nearest: " + pset.nearest(p1));
    }
}
/*************************************************************************
* Compilation: javac RectHV.java
* Execution: java RectHV
* Dependencies: Point2D.java
*
* Implementation of 2D axis-aligned rectangle.
*
*************************************************************************/
/**
 * Immutable axis-aligned rectangle [xmin, xmax] x [ymin, ymax].
 */
public class RectHV {
    private final double xmin, ymin;   // minimum x- and y-coordinates
    private final double xmax, ymax;   // maximum x- and y-coordinates

    /**
     * Constructs the axis-aligned rectangle [xmin, xmax] x [ymin, ymax].
     * @throws IllegalArgumentException if any coordinate is NaN, or if xmax &lt; xmin
     *    or ymax &lt; ymin
     */
    public RectHV(double xmin, double ymin, double xmax, double ymax) {
        // NaN would slip past the ordering check (every comparison with NaN is false).
        if (Double.isNaN(xmin) || Double.isNaN(ymin) || Double.isNaN(xmax) || Double.isNaN(ymax)) {
            throw new IllegalArgumentException("Coordinates cannot be NaN");
        }
        if (xmax < xmin || ymax < ymin) {
            throw new IllegalArgumentException("Invalid rectangle");
        }
        // -0.0 == 0.0 is true, so these replace -0.0 with +0.0; this keeps equals()
        // (which uses ==) consistent with hashCode() (which uses Double).
        this.xmin = (xmin == 0.0) ? 0.0 : xmin;
        this.ymin = (ymin == 0.0) ? 0.0 : ymin;
        this.xmax = (xmax == 0.0) ? 0.0 : xmax;
        this.ymax = (ymax == 0.0) ? 0.0 : ymax;
    }

    // accessor methods for the 4 coordinates
    public double xmin() { return xmin; }
    public double ymin() { return ymin; }
    public double xmax() { return xmax; }
    public double ymax() { return ymax; }

    // width and height of the rectangle
    public double width()  { return xmax - xmin; }
    public double height() { return ymax - ymin; }

    // does this axis-aligned rectangle intersect that one? (shared boundaries count)
    public boolean intersects(RectHV that) {
        return this.xmax >= that.xmin && this.ymax >= that.ymin
            && that.xmax >= this.xmin && that.ymax >= this.ymin;
    }

    // draw this axis-aligned rectangle with standard draw
    public void draw() {
        StdDraw.line(xmin, ymin, xmax, ymin);
        StdDraw.line(xmax, ymin, xmax, ymax);
        StdDraw.line(xmax, ymax, xmin, ymax);
        StdDraw.line(xmin, ymax, xmin, ymin);
    }

    // distance from p to the closest point on this axis-aligned rectangle (0 if inside)
    public double distanceTo(Point2D p) {
        return Math.sqrt(this.distanceSquaredTo(p));
    }

    // squared distance from p to the closest point on this axis-aligned rectangle (0 if inside)
    public double distanceSquaredTo(Point2D p) {
        double dx = 0.0, dy = 0.0;
        if      (p.x() < xmin) dx = p.x() - xmin;
        else if (p.x() > xmax) dx = p.x() - xmax;
        if      (p.y() < ymin) dy = p.y() - ymin;
        else if (p.y() > ymax) dy = p.y() - ymax;
        return dx*dx + dy*dy;
    }

    // does this axis-aligned rectangle contain p? (boundary included)
    public boolean contains(Point2D p) {
        return (p.x() >= xmin) && (p.x() <= xmax)
            && (p.y() >= ymin) && (p.y() <= ymax);
    }

    // are the two axis-aligned rectangles equal?
    @Override
    public boolean equals(Object y) {
        if (y == this) return true;
        if (y == null) return false;
        if (y.getClass() != this.getClass()) return false;
        RectHV that = (RectHV) y;
        if (this.xmin != that.xmin) return false;
        if (this.ymin != that.ymin) return false;
        if (this.xmax != that.xmax) return false;
        if (this.ymax != that.ymax) return false;
        return true;
    }

    // hash code consistent with equals(): equal rectangles have equal coordinates
    @Override
    public int hashCode() {
        int hash1 = ((Double) xmin).hashCode();
        int hash2 = ((Double) ymin).hashCode();
        int hash3 = ((Double) xmax).hashCode();
        int hash4 = ((Double) ymax).hashCode();
        return 31*(31*(31*hash1 + hash2) + hash3) + hash4;
    }

    // return a string representation of this axis-aligned rectangle
    @Override
    public String toString() {
        return "[" + xmin + ", " + xmax + "] x [" + ymin + ", " + ymax + "]";
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment