Week 5 編程作業:Kd-Trees

編程作業: Kd-Trees

100分關鍵在如果當前最短距離>區域到點最短距離時才找,區域到點的最短距離可先用點到線的距離比過

PointSET.java

import java.util.LinkedList;
import java.util.List;
import java.util.TreeSet;

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;

public class PointSET {

	private TreeSet<Point2D> set;

	public PointSET() {
		this.set = new TreeSet<>();
	}

	public boolean isEmpty() {
		return this.set.isEmpty();
	}

	public int size() {
		return this.set.size();
	}

	public void insert(Point2D p) {
		requireNonNull(p);
		this.set.add(p);
	}

	public boolean contains(Point2D p) {
		requireNonNull(p);
		return this.set.contains(p);
	}

	public void draw() {
		for (Point2D p : this.set)
			p.draw();
	}

	public Iterable<Point2D> range(RectHV rect) {
		requireNonNull(rect);
		List<Point2D> rtn = new LinkedList<>();
		for (Point2D anyP : this.set) {
			if (rect.contains(anyP)) {
				rtn.add(anyP);
			}
		}
		return rtn;
	}

	public Point2D nearest(Point2D p) {
		requireNonNull(p);

		double minDis = Double.MAX_VALUE;
		Point2D rtn = null;
		for (Point2D anyP : this.set) {
			double dis = p.distanceSquaredTo(anyP);
			if (dis < minDis) {
				minDis = dis;
				rtn = anyP;
			}
		}
		return rtn;
	}

	private static void requireNonNull(Object o) {
		if (o == null)
			throw new java.lang.IllegalArgumentException();
	}

}

KdTree.java

import java.util.LinkedList;
import java.util.List;
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;

public class KdTree {

	private Node root;
	private int size;

	public KdTree() {
	}

	public boolean isEmpty() {
		return this.root == null;
	}

	public int size() {
		return this.size;
	}

	public void insert(Point2D p) {
		requireNonNull(p);
		this.root = put(root, p, true);
	}

	private Node put(Node n, Point2D p, boolean byX) {
		if (n == null) {
			this.size++;
			Node rtn = new Node(p);
//			rtn.index = this.size;
			return rtn;
		}
		if (n.p.x() == p.x() && n.p.y() == p.y())
			return n;

		boolean goLeft = (byX && p.x() < n.p.x()) || (!byX && p.y() < n.p.y());

		if (goLeft) {
			n.left = put(n.left, p, !byX);
		} else {
			n.right = put(n.right, p, !byX);
		}
		return n;
	}

	public boolean contains(Point2D p) {
		requireNonNull(p);
		return find(this.root, p, true) != null;
	}

	private Node find(Node n, Point2D p, boolean byX) {
		if (n == null)
			return null;
		if (n.p.x() == p.x() && n.p.y() == p.y())
			return n;

		boolean goLeft = (byX && p.x() < n.p.x()) || (!byX && p.y() < n.p.y());

		if (goLeft) {
			return find(n.left, p, !byX);
		} else {
			return find(n.right, p, !byX);
		}
	}

	public void draw() {
		preorderDraw(this.root, true, new RectHV(0, 0, 1, 1));
	}

	private static void preorderDraw(Node n, boolean byX, RectHV rect) {
		if (n == null)
			return;

		StdDraw.setPenColor(StdDraw.BLACK);
		StdDraw.setPenRadius(0.01);
		n.p.draw();
//		StdDraw.text(n.p.x(), n.p.y(), String.valueOf(n.index));

		if (byX) {
			StdDraw.setPenColor(StdDraw.RED);
			StdDraw.setPenRadius(0.003);
			StdDraw.line(n.p.x(), rect.ymin(), n.p.x(), rect.ymax());
			preorderDraw(n.left, !byX, new RectHV(rect.xmin(), rect.ymin(), n.p.x(), rect.ymax()));
			preorderDraw(n.right, !byX, new RectHV(n.p.x(), rect.ymin(), rect.xmax(), rect.ymax()));
		} else {
			StdDraw.setPenColor(StdDraw.BLUE);
			StdDraw.setPenRadius(0.003);
			StdDraw.line(rect.xmin(), n.p.y(), rect.xmax(), n.p.y());
			preorderDraw(n.left, !byX, new RectHV(rect.xmin(), rect.ymin(), rect.xmax(), n.p.y()));
			preorderDraw(n.right, !byX, new RectHV(rect.xmin(), n.p.y(), rect.xmax(), rect.ymax()));
		}
	}

	public Iterable<Point2D> range(RectHV rect) {
		requireNonNull(rect);
		List<Point2D> rtn = new LinkedList<>();
		find(this.root, rect, true, rtn);
		return rtn;
	}

	private static void find(Node n, RectHV rect, boolean byX, List<Point2D> rtn) {
		if (n == null)
			return;
		if (rect.contains(n.p)) {
			rtn.add(n.p);
			find(n.left, rect, !byX, rtn);
			find(n.right, rect, !byX, rtn);
			return;
		}

		if (byX) {
			if (rect.xmin() <= n.p.x() && n.p.x() <= rect.xmax()) {
				find(n.left, rect, !byX, rtn);
				find(n.right, rect, !byX, rtn);
			} else if (rect.xmax() < n.p.x()) {
				find(n.left, rect, !byX, rtn);
			} else {
				find(n.right, rect, !byX, rtn);
			}
		} else {
			if (rect.ymin() <= n.p.y() && n.p.y() <= rect.ymax()) {
				find(n.left, rect, !byX, rtn);
				find(n.right, rect, !byX, rtn);
			} else if (rect.ymax() < n.p.y()) {
				find(n.left, rect, !byX, rtn);
			} else {
				find(n.right, rect, !byX, rtn);
			}
		}
	}

	public Point2D nearest(Point2D p) {
		requireNonNull(p);
		FindNearestData data = new FindNearestData();
		findNearest(this.root, p, true, data, new RectHV(0, 0, 1, 1));
		Point2D nearest = data.nearest == null ? null : data.nearest.p;
		return nearest;
	}

	private void findNearest(Node n, Point2D p, boolean byX, FindNearestData data, RectHV findIn) {
//		System.out.println("findNearest " + n);
		if (n == null)
			return;

		double disToLine = byX ? n.p.x() - p.x() : n.p.y() - p.y();
		disToLine = Math.abs(disToLine);

		if (data.minDis >= disToLine) {
			double disToP = n.p.distanceTo(p);
			if (disToP < data.minDis) {
				data.minDis = disToP;
				data.nearest = n;
			}
			if (data.minDis == 0)
				return;
		}

		boolean goLeftFirst = (byX && p.x() < n.p.x()) || (!byX && p.y() < n.p.y());
		if (goLeftFirst) {
			if (n.left != null) {
				findNearest(n.left, p, !byX, data, nextRect(findIn, n, byX, true));
			}
			if (n.right != null && data.minDis > disToLine) {
				RectHV nextRect = nextRect(findIn, n, byX, false);
				if (data.minDis > nextRect.distanceTo(p))
					findNearest(n.right, p, !byX, data, nextRect);
			}
		} else {
			if (n.right != null) {
				findNearest(n.right, p, !byX, data, nextRect(findIn, n, byX, false));
			}
			if (n.left != null && data.minDis > disToLine) {
				RectHV nextRect = nextRect(findIn, n, byX, true);
				if (data.minDis > nextRect.distanceTo(p))
					findNearest(n.left, p, !byX, data, nextRect);
			}
		}
	}

	private static RectHV nextRect(RectHV curFindIn, Node curNode, boolean curByX, boolean goLeft) {
		if (curByX) {
			if (goLeft)
				return new RectHV(curFindIn.xmin(), curFindIn.ymin(), curNode.p.x(), curFindIn.ymax());
			else
				return new RectHV(curNode.p.x(), curFindIn.ymin(), curFindIn.xmax(), curFindIn.ymax());
		} else {
			if (goLeft)
				return new RectHV(curFindIn.xmin(), curFindIn.ymin(), curFindIn.xmax(), curNode.p.y());
			else
				return new RectHV(curFindIn.xmin(), curNode.p.y(), curFindIn.xmax(), curFindIn.ymax());
		}
	}

	private static class FindNearestData {
		Node nearest = null;
		double minDis = Double.MAX_VALUE;
	}

	private static void requireNonNull(Object o) {
		if (o == null)
			throw new java.lang.IllegalArgumentException();
	}

	private static class Node {
		Node left;
		Node right;
		Point2D p;
//		int index = -1;

		Node(Point2D p) {
			this.p = p;
		}

		@Override
		public String toString() {
			return String.valueOf(p);
		}

	}

}