Skip to content

Instantly share code, notes, and snippets.

@fferri
Created January 4, 2017 23:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fferri/4f3160b4eb6f8fb1b137107f18ade2dc to your computer and use it in GitHub Desktop.
Save fferri/4f3160b4eb6f8fb1b137107f18ade2dc to your computer and use it in GitHub Desktop.
def dist(x, y):
    """Return the squared Euclidean distance between two points.

    No square root is taken, so the value is meant for comparing distances
    (or comparing against a squared radius). Coordinates are paired
    positionally via ``zip``; if the points differ in length, the extra
    coordinates of the longer one are ignored.
    """
    squared_gaps = [(a - b) ** 2 for a, b in zip(x, y)]
    return sum(squared_gaps)
class KDTree:
    """A k-d tree mapping points (indexable sequences of numbers) to payloads.

    Points are inserted one at a time without rebalancing, so the tree's
    shape depends on insertion order. Every traversal is iterative, so a
    degenerate tree (e.g. one built from sorted points) does not hit
    Python's recursion limit.

    Distances reported by this class are *squared* Euclidean distances, as
    computed by the module-level ``dist`` helper.
    """

    class Node:
        """One stored point plus its two subtrees.

        Invariant established by ``insert``: every point ``p`` in the left
        subtree ``l`` has ``p[axis] <= x[axis]``, and every point in the
        right subtree ``r`` has ``p[axis] > x[axis]``.
        """

        def __init__(self, x, payload=None, axis=0):
            self.x, self.axis, self.payload = x, axis, payload
            self.l, self.r = None, None

        def insert(self, x, payload=None):
            """Add point *x* (with *payload*) somewhere below this node.

            A new node splits on the axis after its parent's, cycling through
            all ``len(x)`` dimensions.
            """
            node = self
            ndim = len(x)
            while True:
                axis = node.axis
                if x[axis] <= node.x[axis]:
                    if node.l is None:
                        node.l = KDTree.Node(x, payload, (axis + 1) % ndim)
                        return
                    node = node.l
                else:
                    if node.r is None:
                        node.r = KDTree.Node(x, payload, (axis + 1) % ndim)
                        return
                    node = node.r

        def search(self, xmin, xmax, result):
            """Append ``(point, payload)`` for each point in the closed box.

            A point ``p`` matches when ``xmin[a] <= p[a] <= xmax[a]`` on every
            axis. Matches are appended to *result* in pre-order (node, then
            left subtree, then right subtree). A subtree is skipped when the
            split invariant proves it cannot intersect the box.
            """
            stack = [self]
            while stack:
                node = stack.pop()
                if all(lo <= c <= hi for lo, c, hi in zip(xmin, node.x, xmax)):
                    result.append((node.x, node.payload))
                axis = node.axis
                split = node.x[axis]
                # Right is pushed first so the left subtree is popped (visited)
                # first, keeping pre-order output.
                # Right subtree only holds p[axis] > split.
                if node.r is not None and xmax[axis] > split:
                    stack.append(node.r)
                # Left subtree only holds p[axis] <= split.
                if node.l is not None and xmin[axis] <= split:
                    stack.append(node.l)

        def nearestNeighbour(self, x, best, bestDist):
            """Find the point in this subtree nearest to *x* (branch and bound).

            *best* is the current best node and *bestDist* its squared
            distance to *x*; only strictly closer points replace it.
            Returns ``(best.x, best.payload, bestDist)``, where the distance is
            squared.
            """
            # Each stack entry is (node, lower bound on the squared distance
            # from x to any point in that node's subtree).
            stack = [(self, 0)]
            while stack:
                node, bound = stack.pop()
                if bound >= bestDist:
                    continue  # nothing in this subtree can be strictly closer
                d = dist(x, node.x)
                if d < bestDist:
                    best, bestDist = node, d
                diff = x[node.axis] - node.x[node.axis]
                # Same routing rule as insert: <= goes left.
                if diff <= 0:
                    near, far = node.l, node.r
                else:
                    near, far = node.r, node.l
                # Points across the splitting plane are at least |diff| away
                # along this axis, hence at least diff**2 in squared distance.
                if far is not None:
                    stack.append((far, max(bound, diff * diff)))
                # Pushed last so the near side is explored first, which tends
                # to shrink bestDist early and prune the far side.
                if near is not None:
                    stack.append((near, bound))
            return best.x, best.payload, bestDist

    def __init__(self):
        self.root = None

    def insert(self, x, payload=None):
        """Insert point *x* with an optional *payload*."""
        if self.root is None:
            self.root = KDTree.Node(x, payload)
        else:
            self.root.insert(x, payload)

    def radiusSearch(self, x, radius):
        """Return ``[(point, payload), ...]`` for points within *radius* of *x*.

        A point ``p`` matches when its Euclidean distance to *x* is at most
        *radius*, i.e. ``dist(p, x) <= radius ** 2``. The tree is first
        narrowed to the bounding box ``x +/- radius``. Box hits outside the
        ball (the box corners) are then filtered out.
        """
        if self.root is None:
            return []
        in_box = []
        self.root.search([xi - radius for xi in x], [xi + radius for xi in x], in_box)
        limit = radius * radius
        return [(p, payload) for p, payload in in_box if dist(p, x) <= limit]

    def nearestNeighbour(self, x):
        """Return ``(point, payload, squared_distance)`` nearest to *x*.

        Returns ``None`` when the tree is empty.
        """
        if self.root is None:
            return None
        return self.root.nearestNeighbour(x, self.root, dist(self.root.x, x))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment