Skip to content

Instantly share code, notes, and snippets.

@brendancol
Last active November 9, 2018 16:42
Show Gist options
  • Save brendancol/a3dd4a35ecd94660411112999923d561 to your computer and use it in GitHub Desktop.
Save brendancol/a3dd4a35ecd94660411112999923d561 to your computer and use it in GitHub Desktop.
Distributed KDTree using Dask
from functools import partial
import numpy as np
from scipy.spatial import cKDTree
from dask.delayed import delayed
import dask.bag as db
class DaskKDTree(object):
    """A set of independent ``cKDTree`` chunks built and queried through Dask.

    Points are split into chunks; one ``scipy.spatial.cKDTree`` is built per
    chunk as a Dask delayed task and persisted via the supplied client.  A
    query is run against every chunk tree and the matching points from all
    chunks are combined into a single list.

    Usage Example:
    --------------
    from dask.distributed import Client
    client = Client('52.91.203.58:8786')
    tree = DaskKDTree(client, leafsize=1000)
    tree.load_random(num_points=int(1e8), chunk_size=int(1e6))
    # find all points within 10km of Bourbon Street
    bourbon_street = (-10026195.958134, 3498018.476606)
    radius = 10000 # meters
    result = tree.query_ball_point(x=bourbon_street, r=radius)
    """

    def __init__(self, client, leafsize):
        """Store the Dask client and the leaf size used for every chunk tree.

        Parameters
        ----------
        client
            Object exposing ``persist`` (e.g. ``dask.distributed.Client``).
        leafsize
            Forwarded as ``leafsize`` to each ``cKDTree`` that gets built.
        """
        self.client = client
        self.leafsize = leafsize
        self.trees = []

    def load_random(self, num_points=int(1e6), chunk_size=300):
        """Build chunk trees over ``num_points`` uniformly random 2-D points.

        The points are split into chunks of at most ``chunk_size``; when
        ``num_points`` is not a multiple of ``chunk_size`` a final, smaller
        chunk holds the remainder so that exactly ``num_points`` points are
        loaded.  Any previously loaded trees are replaced.

        Raises
        ------
        ValueError
            If ``chunk_size`` is smaller than 1.
        """
        sizes = DaskKDTree._chunk_sizes(num_points, chunk_size)
        build = delayed(DaskKDTree._run_load_random)
        pending = [build(size, leafsize=self.leafsize) for size in sizes]
        # Nothing to schedule when there are no points to load.
        self.trees = self.client.persist(pending) if pending else []

    @staticmethod
    def _chunk_sizes(num_points, chunk_size):
        """Return the per-chunk point counts that add up to ``num_points``."""
        total = int(num_points)
        size = int(chunk_size)
        if size < 1:
            raise ValueError("chunk_size must be a positive integer")
        if total <= 0:
            return []
        full_chunks, remainder = divmod(total, size)
        sizes = [size] * full_chunks
        if remainder:
            sizes.append(remainder)
        return sizes

    @staticmethod
    def _run_load_random(count, leafsize):
        """Build a ``cKDTree`` over ``count`` random points.

        Each coordinate is drawn independently from a uniform distribution
        on [-20e6, 20e6).
        """
        xs = np.random.uniform(int(-20e6), int(20e6), count)
        ys = np.random.uniform(int(-20e6), int(20e6), count)
        # One (x, y) row per point -> array of shape (count, 2).
        points = np.column_stack((xs, ys))
        return cKDTree(points, leafsize=leafsize)

    def query_ball_point(self, **kwargs):
        """Run ``cKDTree.query_ball_point`` on every chunk and merge results.

        Parameters
        ----------
        **kwargs
            Passed unchanged to ``cKDTree.query_ball_point`` (e.g. ``x``,
            ``r``).  ``x`` is expected to be a single query point.

        Returns
        -------
        list
            The matching points from all chunks, one ``(x, y)`` array per
            point.  Empty when no trees have been loaded.
        """
        if not self.trees:
            return []
        run = delayed(DaskKDTree._query_partition)
        partitions = [run(tree, kwargs) for tree in self.trees]
        return db.from_delayed(partitions).compute()

    @staticmethod
    def _query_partition(tree, query_info):
        """Return one chunk's matching points as a list of row arrays.

        ``dask.bag.from_delayed`` expects every partition to evaluate to a
        list (or other concrete sequence).  Handing it a NumPy array makes
        the bag code test the array's truth value, which raises
        ``ValueError: The truth value of an array ... is ambiguous``.
        """
        return list(DaskKDTree._run_query_ball_point(tree, query_info))

    @staticmethod
    def _run_query_ball_point(tree, query_info):
        """Return the rows of ``tree.data`` that fall inside the query ball."""
        indices = tree.query_ball_point(**query_info)
        return tree.data[indices]
@munitech4u
Copy link

What is the approximate time to complete the nearest-neighbor search on a computer with 16 GB of RAM and an i7 processor? For a single point it ran for more than 15 minutes. Is there something I might be doing wrong? I want to run this query_ball_point for millions of points.

@munitech4u
Copy link

Facing this error, while replicating your example:
File "C:\Anaconda3\lib\site-packages\dask\bag\core.py", line 1460, in reify
if seq and isinstance(seq[0], Iterator):

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

@pelson
Copy link

pelson commented Jul 31, 2018

For a small number of KD-trees, this is an interesting approach.

For truly distributed KDTrees, I recently read a paper "Highly Parallel Fast KD‐tree Construction for Interactive Ray Tracing of Dynamic Scenes" [1] that looks like it might be a more optimal way to do this in parallel.

[1]: https://onlinelibrary.wiley.com/doi/pdf/10.1111/j.1467-8659.2007.01062.x

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment