

@brendancol
Last active November 9, 2018 16:42
Distributed KDTree using Dask
import numpy as np
from scipy.spatial import cKDTree

from dask.delayed import delayed
import dask.bag as db


class DaskKDTree(object):
    """Flat collection of per-chunk scipy cKDTrees, persisted on a Dask cluster.

    Usage Example:
    --------------
    from dask.distributed import Client

    client = Client('52.91.203.58:8786')
    tree = DaskKDTree(client, leafsize=1000)
    tree.load_random(num_points=int(1e8), chunk_size=int(1e6))

    # find all points within 10km of Bourbon Street
    bourbon_street = (-10026195.958134, 3498018.476606)
    radius = 10000  # meters
    result = tree.query_ball_point(x=bourbon_street, r=radius)
    """

    def __init__(self, client, leafsize):
        self.client = client
        self.leafsize = leafsize
        self.trees = []

    def load_random(self, num_points=int(1e6), chunk_size=300):
        # build one kd-tree per chunk of random points and keep the
        # delayed trees persisted in cluster memory
        parts = int(num_points / chunk_size)
        self.trees = [delayed(DaskKDTree._run_load_random)(int(chunk_size), leafsize=self.leafsize)
                      for _ in range(parts)]
        self.trees = self.client.persist(self.trees)

    @staticmethod
    def _run_load_random(count, leafsize):
        # random points spanning roughly the web-mercator extent
        xs = np.random.uniform(int(-20e6), int(20e6), count)
        ys = np.random.uniform(int(-20e6), int(20e6), count)
        points = np.dstack((xs, ys))[0, :]
        return cKDTree(points, leafsize=leafsize)

    def query_ball_point(self, **kwargs):
        # query every per-chunk tree and gather the matching points
        nearest = [delayed(DaskKDTree._run_query_ball_point)(d, kwargs) for d in self.trees]
        b = db.from_delayed(nearest)
        return b.compute()

    @staticmethod
    def _run_query_ball_point(tree, query_info):
        indices = tree.query_ball_point(**query_info)
        return tree.data[indices]
@mrocklin

My understanding here is that the approach is to have a flat collection of single-machine kd-trees and, when we need to query something, we check them all. Is this right?

I would want to know more about how this is likely to be used, and other possible options. My guess is that people have analyzed other possible arrangements, such as where we partition the dataset ahead of time so that different sections of a large tree are on different machines. I suspect that there is some tradeoff between the two (or more) possible cases.
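
One minimal sketch of that ahead-of-time partitioning, assuming a simple uniform grid (the partition_by_grid helper and the cell count here are hypothetical, not part of the gist):

    import numpy as np
    from scipy.spatial import cKDTree

    def partition_by_grid(points, cells_per_axis=4):
        """Bucket 2-D points into a coarse grid; one kd-tree per bucket."""
        mins = points.min(axis=0)
        maxs = points.max(axis=0)
        # scale coordinates into [0, cells_per_axis) and clip the top edge
        scaled = (points - mins) / (maxs - mins) * cells_per_axis
        cells = np.minimum(scaled.astype(int), cells_per_axis - 1)
        keys = cells[:, 0] * cells_per_axis + cells[:, 1]
        return {k: points[keys == k] for k in np.unique(keys)}

    points = np.random.uniform(-20e6, 20e6, size=(100_000, 2))
    trees = {k: cKDTree(pts, leafsize=1000)
             for k, pts in partition_by_grid(points).items()}

With Dask, each bucket's tree would be built and persisted on a different worker, and a query ball would only be routed to the buckets whose bounding boxes it intersects, instead of to every tree.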

Small style feedback: you might want to avoid using partial here:

    query_obj = partial(DaskKDTree._run_query_ball_point, query_info=kwargs)
    nearest = [delayed(query_obj)(d) for d in self.trees]

Instead, do something like the following:

    nearest = [delayed(DaskKDTree._run_query_ball_point)(d, query_info=kwargs) for d in self.trees]

This is both somewhat more direct and easier for Dask to serialize.

@munitech4u

What's the approximate time to complete the nearest-neighbor search on a computer with 16 GB of RAM and an i7 processor? For a single point it was running for more than 15 minutes. Is there something I might be doing wrong? I want to run query_ball_point for millions of points.
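
Hard to say without profiling, but one common cause of slow bulk queries is calling query_ball_point once per point; SciPy's cKDTree.query_ball_point also accepts an (m, k) array of query points, so a whole batch can be searched in one vectorized call. A rough single-machine sketch, with illustrative sizes:

    import numpy as np
    from scipy.spatial import cKDTree

    # stand-in for one distributed chunk: 1M random web-mercator points
    data = np.random.uniform(-20e6, 20e6, size=(1_000_000, 2))
    tree = cKDTree(data, leafsize=1000)

    # query 10k points at once instead of looping point by point;
    # the result is an object array of per-point index lists
    queries = np.random.uniform(-20e6, 20e6, size=(10_000, 2))
    results = tree.query_ball_point(queries, r=10_000)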

@munitech4u

I'm facing this error while replicating your example:

    File "C:\Anaconda3\lib\site-packages\dask\bag\core.py", line 1460, in reify
      if seq and isinstance(seq[0], Iterator):

    ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
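
Not a confirmed fix, but the traceback suggests dask.bag truth-tests each partition (if seq), which is ambiguous when the partition is a NumPy array with more than one row. One possible workaround is to return a plain Python list from the per-tree query:

    @staticmethod
    def _run_query_ball_point(tree, query_info):
        indices = tree.query_ball_point(**query_info)
        # return a plain list so dask.bag can treat the partition
        # as an ordinary sequence rather than a NumPy array
        return list(tree.data[indices])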

@pelson

pelson commented Jul 31, 2018

For a small number of kd-trees, this is an interesting approach.

For truly distributed kd-trees, I recently read a paper, "Highly Parallel Fast KD-Tree Construction for Interactive Ray Tracing of Dynamic Scenes" [1], that looks like it might be a more efficient way to do this in parallel.

[1]: https://onlinelibrary.wiley.com/doi/pdf/10.1111/j.1467-8659.2007.01062.x
