# Created
# January 4, 2017 23:49
#
#
# Save fferri/4f3160b4eb6f8fb1b137107f18ade2dc to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
def dist(x, y):
    """Return the squared Euclidean distance between points *x* and *y*.

    The square root is deliberately omitted: squared distances order points
    exactly like true distances and are cheaper to compute.  Coordinates are
    paired with ``zip``, so surplus coordinates of the longer point are ignored.
    """
    return sum((xi - yi) ** 2 for xi, yi in zip(x, y))


class KDTree:
    """An (unbalanced) k-d tree over points given as sequences of numbers.

    Each node splits space on one axis, cycling through the axes with depth.
    A point whose coordinate on the split axis is ``<=`` the node's goes into
    the left subtree; a strictly larger one goes into the right subtree.  The
    tree is never rebalanced, so its shape depends on insertion order.

    Every distance this class returns is a *squared* Euclidean distance
    (see ``dist``).
    """

    class Node:
        """One tree node: a point ``x``, its ``payload`` and its split ``axis``."""

        def __init__(self, x, payload=None, axis=0):
            self.x, self.axis, self.payload = x, axis, payload
            self.l, self.r = None, None  # left (<= split) / right (> split) children

        def insert(self, x, payload=None):
            """Insert point *x* with *payload* into the subtree rooted here.

            Iterative, so a degenerate tree (e.g. built from sorted input) cannot
            exceed Python's recursion limit.
            """
            node = self
            while True:
                axis = node.axis
                if x[axis] <= node.x[axis]:
                    if node.l is None:
                        node.l = KDTree.Node(x, payload, (axis + 1) % len(x))
                        return
                    node = node.l
                else:
                    if node.r is None:
                        node.r = KDTree.Node(x, payload, (axis + 1) % len(x))
                        return
                    node = node.r

        def search(self, xmin, xmax, result):
            """Append ``(x, payload)`` for every point inside the box [xmin, xmax].

            The box is closed on every axis.  Nodes are reported in pre-order
            (node, then left subtree, then right subtree).  A subtree is skipped
            when the box cannot reach its side of the split.
            """
            stack = [self]
            while stack:
                node = stack.pop()
                axis = node.axis
                split = node.x[axis]
                if all(lo <= c <= hi for lo, c, hi in zip(xmin, node.x, xmax)):
                    result.append((node.x, node.payload))
                # The right subtree only holds coordinates > split, and the left
                # subtree only holds coordinates <= split.  The right child is
                # pushed first so that the left child is popped first (pre-order).
                if node.r is not None and xmax[axis] > split:
                    stack.append(node.r)
                if node.l is not None and xmin[axis] <= split:
                    stack.append(node.l)

        def nearestNeighbour(self, x, best, bestDist):
            """Return ``(point, payload, sqdist)`` of the point nearest to *x*.

            *best* is the best node found so far and *bestDist* is its squared
            distance to *x*.  They are replaced only by strictly closer nodes, so
            ties keep the first node found.  The search also backtracks into the
            far side of a split whenever the splitting plane is closer than the
            current best, which makes the result exact rather than a greedy guess.
            """
            # Each entry pairs a subtree with a lower bound on the squared
            # distance from x to any point stored in it.
            stack = [(self, 0)]
            while stack:
                node, bound = stack.pop()
                if bound >= bestDist:
                    continue  # this subtree cannot contain a strictly closer point
                d = dist(x, node.x)
                if d < bestDist:
                    best, bestDist = node, d
                axis = node.axis
                gap = x[axis] - node.x[axis]
                if x[axis] <= node.x[axis]:
                    near, far = node.l, node.r
                else:
                    near, far = node.r, node.l
                # Every point on the far side is at least |gap| away along the
                # split axis.
                if far is not None:
                    stack.append((far, max(bound, gap * gap)))
                # The near side is pushed last so it is explored first.  This
                # tightens bestDist early and lets the far side be pruned.
                if near is not None:
                    stack.append((near, bound))
            return best.x, best.payload, bestDist

    def __init__(self):
        self.root = None

    def insert(self, x, payload=None):
        """Insert point *x* with an optional *payload*."""
        if self.root is None:
            self.root = KDTree.Node(x, payload)
        else:
            self.root.insert(x, payload)

    def radiusSearch(self, x, radius):
        """Return ``(point, payload)`` pairs within Euclidean *radius* of *x*.

        The tree is queried with the axis-aligned box that encloses the ball.
        Candidates lying in the box's corners, outside the ball, are then
        discarded.  An empty tree yields ``[]``.
        """
        ret = []
        if self.root is not None:
            self.root.search([xi - radius for xi in x], [xi + radius for xi in x], ret)
        limit = radius * radius  # dist() is squared, so compare with radius**2
        return [(p, payload) for p, payload in ret if dist(x, p) <= limit]

    def nearestNeighbour(self, x):
        """Return ``(point, payload, sqdist)`` nearest to *x*, or None if empty.

        ``sqdist`` is the squared Euclidean distance.
        """
        if self.root is None:
            return None
        return self.root.nearestNeighbour(x, self.root, dist(self.root.x, x))
# Sign up for free
# to join this conversation on GitHub.
# Already have an account?
# Sign in to comment.