{{ message }}

Instantly share code, notes, and snippets.

# tompaton/kdtree.py

Created Mar 10, 2011
Python kd-tree spatial index and nearest neighbour search
 #!/usr/bin/env python
# kd-tree index and nearest neighbour search
# includes doctests, run with: python -m doctest kdtree.py


class KDTree(object):
    """
    kd-tree spatial index and nearest neighbour search
    http://en.wikipedia.org/wiki/Kd-tree

    Points are indexable sequences of numbers (e.g. tuples) that all have the
    same number of coordinates k.  The splitting axis cycles through
    0, 1, ..., k-1 as the tree gets deeper.  An empty (sub)tree is a node
    whose ``location`` is None and whose children are None.
    """

    def __init__(self, point_list, _depth=0):
        """
        Initialize kd-tree index with points.

        point_list -- sequence of points, all with the same number of
                      coordinates; may be empty.
        _depth     -- internal: depth of this node, used to choose the axis.

        >>> KDTree([])
        None
        >>> KDTree([(1,1)])
        (0, (1, 1), None, None)
        >>> KDTree([(1,1),(2,2)])
        (0, (2, 2), (1, (1, 1), None, None), None)
        >>> KDTree([(1,1),(2,2),(3,3)])
        (0, (2, 2), (1, (1, 1), None, None), (1, (3, 3), None, None))
        """
        if point_list:
            # Cycle the axis through the point's dimensions (the length of a
            # point), not through the number of points in this subtree.
            self.axis = _depth % len(point_list[0])

            # Sort point list and choose median as pivot element
            point_list = sorted(point_list, key=lambda point: point[self.axis])
            median = len(point_list) // 2

            # Create node and construct subtrees
            self.location = point_list[median]
            self.child_left = KDTree(point_list[:median], _depth + 1)
            self.child_right = KDTree(point_list[median + 1:], _depth + 1)
        else:
            self.axis = 0
            self.location = None
            self.child_left = None
            self.child_right = None

    def closest_point(self, point, _best=None):
        """
        Efficient recursive search for nearest neighbour to point.

        point -- query point with the same number of coordinates as the
                 indexed points.
        _best -- internal: best candidate found so far.

        Returns the indexed point nearest to ``point``, or None if the tree
        is empty.  If ``point`` itself is in the index, it is returned.

        >>> t = KDTree([(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)])
        >>> t
        (0, (7, 2), (1, (5, 4), (0, (2, 3), None, None), (0, (4, 7), None, None)), (1, (9, 6), (0, (8, 1), None, None), None))
        >>> t.closest_point( (7,2) )
        (7, 2)
        >>> t.closest_point( (8,1) )
        (8, 1)
        >>> t.closest_point( (1,1) )
        (2, 3)
        >>> t.closest_point( (5,5) )
        (5, 4)
        """
        if self.location is None:
            return _best

        if _best is None:
            _best = self.location

        # consider the current node
        if distance(self.location, point) < distance(_best, point):
            _best = self.location

        # search the near branch
        _best = self._child_near(point).closest_point(point, _best)

        # search the away branch only if the splitting plane is closer than the
        # best match so far (both sides are squared distances)
        if self._distance_axis(point) < distance(_best, point):
            _best = self._child_away(point).closest_point(point, _best)

        return _best

    # internal methods

    def __repr__(self):
        """
        Simple representation for doctests:
        "(axis, location, left, right)", or "None" for an empty node.
        """
        if self.location is not None:
            return "(%d, %s, %s, %s)" % (self.axis, repr(self.location),
                                         repr(self.child_left),
                                         repr(self.child_right))
        else:
            return "None"

    def _distance_axis(self, point):
        """
        Squared distance from point to this node's splitting plane

        >>> KDTree([(1,1)])._distance_axis((2,3))
        1
        >>> KDTree([(1,1),(2,2)]).child_left._distance_axis((2,3))
        4
        """
        # Projecting point onto the splitting plane changes only the
        # coordinate on self.axis, so the squared distance to the projection
        # is just the squared difference along that axis.
        return (point[self.axis] - self.location[self.axis]) ** 2

    def _child_near(self, point):
        """
        Either left or right child, whichever is on the same side of the
        splitting plane as the point
        """
        if point[self.axis] < self.location[self.axis]:
            return self.child_left
        else:
            return self.child_right

    def _child_away(self, point):
        """
        Either left or right child, whichever is on the opposite side of the
        splitting plane from the point
        """
        if self._child_near(point) is self.child_left:
            return self.child_right
        else:
            return self.child_left


# helper function

def distance(a, b):
    """
    Squared Euclidean distance between points a & b (any number of dimensions).

    Only squared distances are needed because the search just compares them.

    >>> distance((1, 1), (2, 3))
    5
    >>> distance((0, 0, 0), (1, 2, 2))
    9
    """
    return sum((x - y) ** 2 for x, y in zip(a, b))

### pushkarprasad007 commented Mar 19, 2014

 The distance function only works if the points in the kd-tree are two-dimensional. However, the rest of the code does not assume that the points in point_list are two-dimensional, so simply generalizing the global distance function would fix the issue.

### alqahtani-abdulaziz commented Dec 6, 2016

 Hi there, thanks for the code. It really helped me. One thing, though: the closest_point function always returns the same original point passed to it. Any clues as to why that is happening?

### Couhp commented Dec 6, 2018

 Thanks for the code; it's really helpful.