Skip to content

Instantly share code, notes, and snippets.

@JosephCatrambone
Last active March 7, 2018 05:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JosephCatrambone/608a45ffaa0f2e2e9ecde74fac3102ec to your computer and use it in GitHub Desktop.
Save JosephCatrambone/608a45ffaa0f2e2e9ecde74fac3102ec to your computer and use it in GitHub Desktop.
A small octree-style spatial index in pure Python (structurally a k-d tree: each node splits in two at the mean along a cycling axis) that supports arbitrary point dimensions and uses naive splitting.
#!/usr/bin/env python
#author: Joseph Catrambone
"""
Copyright 2018 Joseph Catrambone
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
class Octree(object):
    """A small spatial index over points of arbitrary dimensionality.

    Despite the name, every internal node has exactly two children (``left``
    and ``right``) separated by a plane perpendicular to ``split_axis``, and
    the axis advances by one at each level. Structurally this is a k-d tree
    rather than an 8-way octree.

    Every node keeps *all* points of its subtree in ``values`` (internal nodes
    included), so ``find_nearby`` can answer a query by filtering the list of a
    single node that is large enough to cover the query.
    """

    def __init__(self, data_access_method=lambda pt, axis: pt[axis], bucket_capacity=100, dims=3):
        """data_access_method is a function taking a point and a dimension/axis.

        :parameter data_access_method a lambda with two parameters, a point and an axis.
            Returns the point's value on that axis. Example: p = [1, 2, 3, 4, 5]. axis=2. data_access_method(p, axis) = 3
        :parameter bucket_capacity The maximum number of points to be stored on a leaf node before splitting.
            A leaf only exceeds this when its points cannot be separated on any axis (e.g. duplicates).
        :parameter dims The dimensionality of the point data. A 3D point will have dims=3.
        """
        self.max_capacity = bucket_capacity
        self.values = list()  # Every point stored anywhere in this subtree.
        self.parent = None
        self.left = None  # Child for points whose split_axis value is < split_point.
        self.right = None  # Child for points whose split_axis value is >= split_point.
        self.split_point = None  # None while this node is still a leaf.
        self.split_axis = 0
        self.data_access_method = data_access_method
        self.dims = dims
        # Tight bounding box of the stored points: [min corner, max corner].
        # Starting from +/-inf (an empty box) keeps the box inside the node's
        # region of space; starting from 0.0 would drag every box out to the
        # origin and let find_nearby stop at a node that misses neighbours.
        self.bounds = [[float("inf")] * self.dims, [float("-inf")] * self.dims]

    @property
    def has_split(self):
        """True once this node has been divided into ``left`` and ``right``."""
        return self.split_point is not None

    def contains_volume(self, minimum, maximum):
        """Return True if this node's bounding box encloses the box [minimum, maximum].

        ``minimum`` and ``maximum`` are sequences of length ``self.dims`` holding
        the lower and upper corners of an axis-aligned box. An empty node (bounds
        still +/-inf) never encloses a finite box.
        """
        for axis in range(self.dims):
            if self.bounds[0][axis] > minimum[axis]:
                return False
            if self.bounds[1][axis] < maximum[axis]:
                return False
        return True  # This does contain the volume.

    def add(self, point):
        """Insert ``point`` into this subtree, splitting a leaf once it overflows."""
        # Every node tracks all points beneath it (see the class docstring).
        self.values.append(point)
        # Grow the bounding box to cover the new point.
        for axis in range(self.dims):
            v = self.data_access_method(point, axis)
            if v < self.bounds[0][axis]:  # Lower
                self.bounds[0][axis] = v
            if v > self.bounds[1][axis]:  # Upper
                self.bounds[1][axis] = v
        if self.has_split:
            # Forward the point to the child whose half-space contains it.
            if self.data_access_method(point, self.split_axis) < self.split_point:
                self.left.add(point)
            else:
                self.right.add(point)
        elif len(self.values) > self.max_capacity:
            self._split()

    def _split(self):
        """Divide this overflowing leaf into two children at a mean value.

        Axes are tried starting from ``self.split_axis``; the first axis whose
        mean puts at least one point on each side is used (in the ordinary case
        that is ``self.split_axis`` itself). If no axis separates the points --
        e.g. they all coincide -- the node stays a leaf: splitting anyway would
        send every point to one child, which would overflow and split again
        without end.
        """
        # TODO: Pick a better axis/split value than the mean on a rolling axis.
        count = len(self.values)
        for offset in range(self.dims):
            axis = (self.split_axis + offset) % self.dims
            coords = [self.data_access_method(pt, axis) for pt in self.values]
            mean = sum(coords, 0.0) / count
            # Count against the computed mean instead of assuming it lies
            # strictly between min and max: float rounding can land it on an end.
            num_left = sum(1 for c in coords if c < mean)
            if 0 < num_left < count:
                break
        else:
            return  # No axis separates these points; remain an oversized leaf.

        self.split_axis = axis
        self.split_point = mean  # Makes this tree split.
        child_axis = (axis + 1) % self.dims
        self.left = Octree(self.data_access_method, self.max_capacity, self.dims)
        self.right = Octree(self.data_access_method, self.max_capacity, self.dims)
        for child in (self.left, self.right):
            child.split_axis = child_axis
            child.parent = self
        # Re-insert the points: now that split_point is set, add() keeps each
        # one in self.values and forwards it to the correct child.
        pending = self.values
        self.values = list()
        for pt in pending:
            self.add(pt)

    def distance_squared(self, p1, p2):
        """Return the squared Euclidean distance between two points.

        The square root is skipped; find_nearby compares the result against
        ``distance * distance``.
        """
        accum = 0.0
        for axis in range(self.dims):
            delta = self.data_access_method(p2, axis) - self.data_access_method(p1, axis)
            accum += delta * delta
        return accum

    def find_nearby(self, point, distance):
        """Return every stored point strictly closer than ``distance`` to ``point``.

        Starts at the leaf whose region contains ``point`` and climbs toward the
        root until a node's bounding box encloses the axis-aligned cube of
        half-width ``distance`` centred on ``point`` (or the root is reached).
        A node's bounding box lies inside the region of space it owns, so once
        the cube fits, every point within ``distance`` is in that node's
        ``values`` and only that list needs filtering.
        """
        bucket = self.find_in_tree(point)
        volume_min = [self.data_access_method(point, axis) - distance for axis in range(self.dims)]
        volume_max = [self.data_access_method(point, axis) + distance for axis in range(self.dims)]
        while bucket.parent is not None and not bucket.contains_volume(volume_min, volume_max):
            bucket = bucket.parent
        dist_sq = distance * distance
        return [p for p in bucket.values if bucket.distance_squared(point, p) < dist_sq]

    def find_in_tree(self, point):
        """Return the leaf node (bucket) whose region of space contains ``point``."""
        node = self
        while node.has_split:
            if node.data_access_method(point, node.split_axis) < node.split_point:
                node = node.left
            else:
                node = node.right
        return node
if __name__ == "__main__":
    print("Running octree as test.")
    t = Octree()
    all_points = list()
    from random import random

    def make_point(scale):
        """Return a random 3D point with every coordinate in [0, scale)."""
        return [random() * scale, random() * scale, random() * scale]

    for _ in range(100000):
        p = make_point(100)
        all_points.append(p)
        t.add(p)
    # Finished adding data. Find nearby data.
    p = make_point(100)
    q_dist = 2
    nearby = t.find_nearby(p, q_dist)
    # Soundness: every returned point must lie within q_dist of the query.
    for q in nearby:
        dist = (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2 + (p[2] - q[2]) ** 2
        if dist > q_dist * q_dist:
            raise Exception("Found point in octree more than {} units away: p: {} q: {}".format(q_dist, p, q))
    # Completeness: a brute-force scan over every inserted point (using the
    # tree's own metric so rounding matches) must not find anything the tree missed.
    expected = [q for q in all_points if t.distance_squared(p, q) < q_dist * q_dist]
    found_ids = set(id(q) for q in nearby)
    missing = [q for q in expected if id(q) not in found_ids]
    if missing:
        raise Exception("Octree missed {} of {} points within {} units of {}: {}".format(
            len(missing), len(expected), q_dist, p, missing))
    print("Found {} points near {}".format(len(nearby), p))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment