Last active
January 23, 2017 11:19
-
-
Save sang4lv/b1974d71fea6ffeb379f7f2eeeb306d2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Uses python3 | |
import sys | |
import math | |
class Node:
    """One point stored in a k-d tree, plus links to its two children."""

    def __init__(self, coords):
        # The point held by this node (a sequence of coordinates).
        self.coords = coords
        # Both children start empty; KDTree.insert attaches them later.
        self.left = self.right = None
class KDTree:
    """A k-d tree over points with `k` coordinates, built by repeated insertion.

    The tree is never rebalanced, so its shape depends on insertion order.
    """

    def __init__(self, k):
        # Number of coordinates used for splitting; depth d splits on axis d % k.
        self.k = k
        self.root = None

    def insert(self, coords):
        """Add the point `coords` to the tree.

        Descends from the root comparing on axis depth % k: a point whose
        coordinate is <= the node's goes left, otherwise it goes right.
        """
        node = Node(coords)
        if self.root is None:
            self.root = node
            return
        dim = 0
        curr = self.root
        while True:
            if node.coords[dim] <= curr.coords[dim]:
                if curr.left is None:
                    curr.left = node
                    break
                curr = curr.left
            else:
                if curr.right is None:
                    curr.right = node
                    break
                curr = curr.right
            dim = (dim + 1) % self.k

    def show(self, curr=None):
        """Print the subtree rooted at `curr` (default: the whole tree).

        Nodes are printed in pre-order. Each child's coordinates are
        preceded by a "Left subtree:" or "Right subtree:" label. Prints
        nothing when there is no node to show.
        """
        start = curr or self.root
        if start is None:
            return
        # Explicit stack instead of recursion so deep (unbalanced) trees do
        # not hit the recursion limit. Right is pushed before left so the
        # whole left subtree is printed before the right one.
        pending = [(None, start)]
        while pending:
            label, node = pending.pop()
            if label is not None:
                print(label)
            print(node.coords)
            if node.right is not None:
                pending.append(("Right subtree:", node.right))
            if node.left is not None:
                pending.append(("Left subtree:", node.left))

    def get_sum_square(self, pair):
        """Return (pair[1] - pair[0]) ** 2, the squared gap along one axis."""
        return (pair[1] - pair[0]) ** 2

    def get_axes_square_sum(self, coord1, coord2):
        """Return the squared Euclidean distance between two points.

        Coordinates are paired with zip(), so only the common prefix of the
        two sequences is compared.
        """
        # sum() replaces the Python-2-only builtin reduce(); both start at 0.
        return sum(self.get_sum_square(pair) for pair in zip(coord1, coord2))

    def get_minimum_distance(self, coords, node, dim, best, visited=None):
        """Search the subtree rooted at `node` for the point nearest `coords`.

        Args:
            coords: the query point.
            node: root of the subtree to search (may be None).
            dim: the axis `node` splits on (its depth % k).
            best: current best squared distance; only closer points matter.
            visited: optional container of nodes to skip entirely (neither
                measured nor descended into). It is not modified.

        Returns:
            min(best, squared distance from `coords` to the nearest point
            found in the subtree).
        """
        skip = visited if visited is not None else ()
        # Each entry: (subtree root, its split axis, squared lower bound on
        # the distance from `coords` to any point inside that subtree).
        stack = [(node, dim, 0)]
        while stack:
            curr, axis, bound = stack.pop()
            # Prune: nothing in this subtree can beat the current best.
            if curr is None or curr in skip or bound >= best:
                continue
            best = min(best, self.get_axes_square_sum(curr.coords, coords))
            if best == 0:
                # An exact match cannot be improved upon.
                break
            diff = coords[axis] - curr.coords[axis]
            # Same `<=` rule as insert(): the query's own side is "near".
            if diff <= 0:
                near, far = curr.left, curr.right
            else:
                near, far = curr.right, curr.left
            child_axis = (axis + 1) % self.k
            # Push the far side first so the near side is popped and searched
            # first, tightening `best` before the far side's bound is tested.
            stack.append((far, child_axis, max(bound, diff * diff)))
            stack.append((near, child_axis, bound))
        return best

    def findNN(self, coords):
        """Return the squared Euclidean distance from `coords` to its nearest
        stored point, or float("inf") when the tree is empty."""
        return self.get_minimum_distance(coords, self.root, 0, float("inf"), None)
def minimum_distance(x, y, query=(6, 7)):
    """Build a 2-d tree from the points (x[i], y[i]) and report how far
    `query` lies from its nearest stored point.

    Args:
        x: x-coordinates of the points.
        y: y-coordinates of the points. Points are formed with zip(), so
            surplus entries in the longer sequence are silently ignored.
        query: the point to look up; defaults to (6, 7), the point this
            function previously always queried.

    Returns:
        The Euclidean distance (square root of KDTree.findNN's result),
        which is also printed.
    """
    # NOTE(review): the name suggests a closest-pair computation, but this
    # performs a single nearest-neighbour query for `query` — confirm intent.
    tree = KDTree(2)
    for pair in zip(x, y):
        tree.insert(list(pair))
    result = math.sqrt(tree.findNN(list(query)))
    print(result)
    return result
# Sample coordinates for the demo run below.
x = [7, 1, 4, 7, 8, 5]
y = [7, 100, 8, 7, 6, 101]

if __name__ == "__main__":
    # Guarded so that importing this module does not run (and print) the demo.
    minimum_distance(x, y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment