"""KD-tree query using PyOpenCL to find the k nearest neighbors.

Created December 5, 2022 06:46.
"""
import numpy as np
import pyopencl as cl
from pymatting import KDTree
import pymatting
import time
# OpenCL C source of the k-nearest-neighbor KD-tree query kernel.
#
# One work item handles one query point. It walks the tree depth-first with an
# explicit stack and keeps its k best candidates sorted by ascending squared
# distance in its slice of out_indices / out_distances.
source = """
typedef int int32_t;
typedef long int64_t;

// Capacity of the per-work-item traversal stack. Each visited inner node pops
// one entry and pushes two, so the stack never holds more than the tree depth
// plus one entries.
#define STACK_CAPACITY 100

__kernel void _tree_query_cl(
    __global const int64_t *i0_inds,
    __global const int64_t *i1_inds,
    __global const int64_t *less_inds,
    __global const int64_t *more_inds,
    __global const int64_t *split_dims,
    __global const float *bounds,
    __global const float *split_values,
    __global const float *points,
    __global const float *query_points,
    __global int64_t *out_indices,
    __global float *out_distances,
    int32_t k,
    int32_t n_query,
    int32_t dimension
){
    int i_query = get_global_id(0);
    if (i_query >= n_query) return;

    __global const float *query_point = query_points + i_query * dimension;
    __global float *distances = out_distances + i_query * k;
    __global int64_t *indices = out_indices + i_query * k;

    int64_t stack[STACK_CAPACITY];
    stack[0] = 0;  // start at the root node
    int stack_size = 1;

    int n_neighbors = 0;

    // While there are nodes to visit
    while (stack_size > 0){
        // Pop the most recently pushed node.
        int64_t i_node = stack[--stack_size];

        // If we already found enough neighbors, skip this node when its
        // bounding box is farther away than the current k-th neighbor.
        if (n_neighbors >= k){
            float dist = 0.0f;
            for (int d = 0; d < dimension; d++){
                float p = query_point[d];
                // bounds shape is (n_nodes, 2, dimension): lower, then upper.
                float lower_bound = bounds[i_node * dimension * 2 + 0 * dimension + d];
                float upper_bound = bounds[i_node * dimension * 2 + 1 * dimension + d];
                // Distance from p to the interval, zero when p lies inside it.
                float dp = p - fmax(lower_bound, fmin(p, upper_bound));
                dist += dp * dp;
            }
            if (dist > distances[n_neighbors - 1]){
                continue;
            }
        }

        if (split_dims[i_node] == -1){
            // Leaf: test every point stored in this node.
            for (int64_t i = i0_inds[i_node]; i < i1_inds[i_node]; i++){
                float distance = 0.0f;
                for (int d = 0; d < dimension; d++){
                    float dd = query_point[d] - points[i * dimension + d];
                    distance += dd * dd;
                }

                // Find insert position in the ascending candidate list.
                int insert_pos = n_neighbors;
                for (int j = n_neighbors - 1; j >= 0; j--){
                    if (distances[j] > distance) insert_pos = j;
                    else break;
                }

                if (insert_pos < k){
                    // Move [insert_pos:k-1] one to the right to make space;
                    // when the list is full the last candidate falls off.
                    int j = k - 1;
                    if (j > n_neighbors) j = n_neighbors;
                    for (; j > insert_pos; j--){
                        distances[j] = distances[j - 1];
                        indices[j] = indices[j - 1];
                    }

                    // Insert new neighbor
                    indices[insert_pos] = i;
                    distances[insert_pos] = distance;

                    n_neighbors++;
                    if (n_neighbors > k) n_neighbors = k;
                }
            }
        } else {
            // Inner node: descend to the child nodes. The child on the
            // query's side of the split is pushed last so it is popped first.
            int64_t split_dim = split_dims[i_node];
            int64_t less = less_inds[i_node];
            int64_t more = more_inds[i_node];

            if (query_point[split_dim] < split_values[i_node]){
                stack[stack_size++] = more;
                stack[stack_size++] = less;
            } else {
                stack[stack_size++] = less;
                stack[stack_size++] = more;
            }
        }
    }
}
"""
def _find_devices(device_type):
    """Return every OpenCL device of ``device_type`` across all platforms.

    pyopencl's ``Platform.get_devices`` may raise (e.g. DEVICE_NOT_FOUND)
    instead of returning an empty list when a platform has no matching
    device. Such platforms are skipped so the search can continue and the
    caller can fall back to other device types.
    """
    found = []
    for candidate_platform in cl.get_platforms():
        try:
            found.extend(candidate_platform.get_devices(device_type))
        except cl.Error:
            # No device of this type on this platform; try the next one.
            continue
    return found


# Prefer a GPU; otherwise accept any OpenCL device.
devices = _find_devices(cl.device_type.GPU)
if not devices:
    print("WARNING: OpenCL could not find any GPU device. Trying other devices.")
    devices = _find_devices(cl.device_type.ALL)
if not devices:
    # Explicit raise instead of assert so the check survives `python -O`.
    raise RuntimeError("Could not find any OpenCL-capable device")
device = devices[0]
# Platform of the selected device (kept as a module-level name as before).
platform = device.platform
context = cl.Context([device])
queue = cl.CommandQueue(context)
program = cl.Program(context, source).build()
def upload(array, flags=cl.mem_flags.READ_ONLY):
    """Copy ``array`` into a new OpenCL buffer on ``context`` and return it.

    The data is copied as a flat, C-contiguous array of its own dtype, so the
    kernel sees it in row-major order.

    Args:
        array: NumPy array to copy to the device.
        flags: Kernel-side access flags. Defaults to ``READ_ONLY``, which is
            right for kernel inputs; pass ``cl.mem_flags.READ_WRITE`` (or
            ``WRITE_ONLY``) for buffers the kernel writes to.

    Returns:
        The ``cl.Buffer`` holding a copy of the data.
    """
    # COPY_HOST_PTR copies the data when the buffer is created, so a view of
    # the caller's array is enough; no extra astype/flatten copies are needed.
    hostbuf = np.ascontiguousarray(array).ravel()
    return cl.Buffer(
        context,
        flags | cl.mem_flags.COPY_HOST_PTR,
        hostbuf=hostbuf,
    )
def download(device_buf, shape, dtype):
    """Read ``device_buf`` back into a newly allocated host array.

    Args:
        device_buf: OpenCL buffer to read from.
        shape: Shape of the returned array.
        dtype: Element type of the returned array.

    Returns:
        A NumPy array of the given shape and dtype holding the buffer contents.
    """
    result = np.empty(shape, dtype=dtype)
    # Device -> host copy on the shared command queue.
    cl.enqueue_copy(queue, result, device_buf)
    return result.reshape(shape)
def tree_query_cl(tree, query_points, k):
    """Find the ``k`` nearest data points of ``tree`` for each query point with OpenCL.

    Args:
        tree: KD tree (e.g. ``pymatting.KDTree``) exposing ``i0_inds``,
            ``i1_inds``, ``less_inds``, ``more_inds``, ``split_dims``,
            ``bounds``, ``split_values``, ``shuffled_data_points`` and
            ``shuffled_indices``. Node arrays are converted to the element
            types the kernel declares (int64 indices, float32 values).
        query_points: Array of shape ``(n_query, dimension)``, converted to
            float32 if necessary. ``dimension`` must match the tree's points.
        k: Number of neighbors per query point, ``1 <= k <= n_data``.

    Returns:
        ``(distances, indices)``, both of shape ``(n_query, k)``. ``distances``
        are Euclidean distances (float32), ascending per row. ``indices`` are
        ``tree.shuffled_indices`` looked up at the kernel's positions into
        ``shuffled_data_points`` (presumably the original data indices).

    Raises:
        ValueError: If ``query_points`` is not 2-D with the tree's dimension,
            or ``k`` is outside ``[1, n_data]``.
    """
    query_points = np.asarray(query_points, dtype=np.float32)
    data_points = np.asarray(tree.shuffled_data_points, dtype=np.float32)
    n_data, data_dimension = data_points.shape
    if query_points.ndim != 2 or query_points.shape[1] != data_dimension:
        raise ValueError(
            f"query_points must have shape (n_query, {data_dimension}), "
            f"got {query_points.shape}"
        )
    if not 1 <= k <= n_data:
        # Fewer than k points would leave output rows partly uninitialized.
        raise ValueError(f"k must be between 1 and {n_data}, got {k}")

    n_query, dimension = query_points.shape
    if n_query == 0:
        # OpenCL rejects zero-sized buffers and launches; answer directly.
        return (
            np.empty((0, k), np.float32),
            np.empty((0, k), tree.shuffled_indices.dtype),
        )

    # Kernel inputs in the order of the kernel's parameters, converted to the
    # element types it declares (int64_t / float).
    host_inputs = [
        np.asarray(tree.i0_inds, dtype=np.int64),
        np.asarray(tree.i1_inds, dtype=np.int64),
        np.asarray(tree.less_inds, dtype=np.int64),
        np.asarray(tree.more_inds, dtype=np.int64),
        np.asarray(tree.split_dims, dtype=np.int64),
        np.asarray(tree.bounds, dtype=np.float32),
        np.asarray(tree.split_values, dtype=np.float32),
        data_points,
        query_points,
    ]
    n_out = n_query * k

    buffers = []
    try:
        for array in host_inputs:
            buffers.append(upload(array))
        # Outputs are only written by the kernel: allocate them WRITE_ONLY
        # instead of uploading uninitialized host memory.
        gpu_indices = cl.Buffer(
            context, cl.mem_flags.WRITE_ONLY, size=n_out * np.dtype(np.int64).itemsize
        )
        buffers.append(gpu_indices)
        gpu_squared_distances = cl.Buffer(
            context, cl.mem_flags.WRITE_ONLY, size=n_out * np.dtype(np.float32).itemsize
        )
        buffers.append(gpu_squared_distances)

        # One work item per query point; let the driver pick the work-group size.
        # `buffers` is in kernel-parameter order: inputs, then the two outputs.
        program._tree_query_cl(
            queue,
            (n_query,),
            None,
            *buffers,
            np.int32(k),
            np.int32(n_query),
            np.int32(dimension),
        )

        # The downloads are enqueued on the same queue after the kernel.
        indices = download(gpu_indices, (n_query, k), np.int64)
        squared_distances = download(gpu_squared_distances, (n_query, k), np.float32)
    finally:
        # Free device memory even if the upload, launch or download failed.
        for buf in buffers:
            buf.release()

    indices = tree.shuffled_indices[indices]
    distances = np.sqrt(squared_distances)
    return distances, indices
def main():
    """Benchmark ``tree_query_cl`` against pymatting's ``KDTree.query``.

    Builds a KD tree over 256 * 512 random 3-D float32 points. Then, ten times,
    it finds the k = 20 nearest neighbors of as many random query points with
    both implementations. Each timing is printed in milliseconds, and the
    returned distances are asserted to agree.
    """
    k = 20
    n_points = 256 * 512
    data_points = np.random.rand(n_points, 3).astype(np.float32)
    query_points = np.random.rand(n_points, 3).astype(np.float32)

    build_start = time.perf_counter()
    tree = pymatting.KDTree(data_points)
    build_seconds = time.perf_counter() - build_start
    print("\nbuild KD tree:", build_seconds * 1000, "ms\n")

    for _ in range(10):
        cpu_start = time.perf_counter()
        expected_distances, expected_indices = tree.query(query_points, k=k)
        cpu_seconds = time.perf_counter() - cpu_start
        print("numba", cpu_seconds * 1000, "ms")

        gpu_start = time.perf_counter()
        distances, indices = tree_query_cl(tree, query_points, k=k)
        gpu_seconds = time.perf_counter() - gpu_start
        print("opencl", gpu_seconds * 1000, "ms\n")

        # Only the distances are compared; neighbor indices are not checked.
        squared_errors = np.square(distances - expected_distances)
        mean_squared_error = np.mean(squared_errors)
        assert mean_squared_error < 1e-10
        assert np.allclose(distances, expected_distances)
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()