Skip to content

Instantly share code, notes, and snippets.

@jamesgregson
Last active September 24, 2019 14:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesgregson/20d9dfb4f7f49c33a90bf27324c00e91 to your computer and use it in GitHub Desktop.
Save jamesgregson/20d9dfb4f7f49c33a90bf27324c00e91 to your computer and use it in GitHub Desktop.
Simple KDTree (3D)
#ifndef __KD_TREE_HEADER_H
#define __KD_TREE_HEADER_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <limits>
#include <list>
#include <queue>
#include <set>
#include <tuple>
#include <utility>
#include <vector>
namespace graphics {

/// Axis-aligned box in 3D, stored as its minimum corner `lo` and maximum
/// corner `hi`. A point is represented as a degenerate box with lo == hi.
struct kdtree_bounds {
    float lo[3];
    float hi[3];

    /// Smallest box that contains both *this and `in`.
    kdtree_bounds box_union( const kdtree_bounds& in ) const {
        return {
            { std::min(lo[0],in.lo[0]), std::min(lo[1],in.lo[1]), std::min(lo[2],in.lo[2]) },
            { std::max(hi[0],in.hi[0]), std::max(hi[1],in.hi[1]), std::max(hi[2],in.hi[2]) }
        };
    }

    /// Overlap of *this and `in`. The result is not clamped: for disjoint
    /// boxes it has lo > hi on at least one axis.
    kdtree_bounds box_intersection( const kdtree_bounds& in ) const {
        return {
            { std::max(lo[0],in.lo[0]), std::max(lo[1],in.lo[1]), std::max(lo[2],in.lo[2]) },
            { std::min(hi[0],in.hi[0]), std::min(hi[1],in.hi[1]), std::min(hi[2],in.hi[2]) }
        };
    }

    /// True when the boxes overlap or merely touch (closed intervals on every axis).
    bool intersects( const kdtree_bounds& in ) const {
        return hi[0] >= in.lo[0] && lo[0] <= in.hi[0]
            && hi[1] >= in.lo[1] && lo[1] <= in.hi[1]
            && hi[2] >= in.lo[2] && lo[2] <= in.hi[2];
    }
};

/// One node of the flat node array owned by kdtree.
struct kdtree_node {
    int split_axis = 0;         // axis (0, 1, 2) of the split plane; meaningful for interior nodes only
    float split_coord = 0.0f;   // position of the split plane along split_axis
    int children = -1;          // index of the first of two consecutive children; <= 0 marks a leaf
    std::vector<int> items;     // item ids; only leaves keep items
};

/// Bounding-box kd-tree over a fixed set of items with a best-first
/// nearest-item query. Nodes live in one flat array and the two children of
/// an interior node occupy consecutive indices. An item whose box touches or
/// straddles a split plane is stored in both children.
class kdtree {
public:
    /// Build the tree by splitting over-full nodes, one level at a time.
    /// @param num_items  number of items; item ids are 0 .. num_items-1
    /// @param items      indexable so that items[i] yields something convertible
    ///                   to kdtree_bounds; read only during construction
    /// @param max_items  a node holding more than this many items is split
    /// @param max_levels depth limit counted in levels including the root, i.e.
    ///                   at most max_levels-1 rounds of splitting
    template< typename ItemList >
    kdtree( const size_t num_items, const ItemList& items, int max_items=20, int max_levels=25 ){
        // The root always exists; an empty tree is simply an empty leaf.
        m_nodes.resize( 1 );
        if( num_items == 0 )
            return;

        // Root bounds: union of every item box.
        kdtree_bounds root_bnds = items[0];
        for( size_t i=1; i<num_items; ++i )
            root_bnds = root_bnds.box_union( items[i] );

        m_nodes[0].items.reserve( num_items );
        for( size_t i=0; i<num_items; ++i )
            m_nodes[0].items.push_back( static_cast<int>(i) );

        // Breadth-first subdivision. Nodes are appended on demand; pre-sizing
        // the array to 2^max_levels would mean over 33 million nodes at the
        // default depth, and the shift overflows for max_levels >= 31.
        std::vector<std::pair<int,kdtree_bounds>> level, next_level;
        level.emplace_back( 0, root_bnds );
        for( int depth=0; depth<max_levels-1 && !level.empty(); ++depth ){
            next_level.clear();
            for( const auto& pending : level ){
                const int nid = pending.first;
                const kdtree_bounds& nbnds = pending.second;
                // Same int -> size_t comparison as before: a negative
                // max_items therefore disables splitting.
                if( m_nodes[nid].items.size() <= static_cast<size_t>(max_items) )
                    continue;

                int axis;
                float split;
                std::tie(axis, split) = longest_axis_midpoint( nbnds );

                // Grow the array BEFORE taking references: resize may reallocate.
                const int cid = static_cast<int>( m_nodes.size() );
                m_nodes.resize( m_nodes.size() + 2 );
                kdtree_node& node     = m_nodes[nid];
                kdtree_node& lo_child = m_nodes[cid];
                kdtree_node& hi_child = m_nodes[cid+1];
                node.split_axis  = axis;
                node.split_coord = split;
                node.children    = cid;
                lo_child.children = 0;
                hi_child.children = 0;

                // Distribute items; boxes touching the plane go to both sides.
                for( int idx : node.items ){
                    if( items[idx].lo[axis] <= split )
                        lo_child.items.push_back( idx );
                    if( items[idx].hi[axis] >= split )
                        hi_child.items.push_back( idx );
                }
                // Interior nodes keep no items; release the storage as well.
                std::vector<int>().swap( node.items );

                // Clip the parent bounds at the plane for each child.
                kdtree_bounds lo_bnds = nbnds;
                lo_bnds.hi[axis] = split;
                kdtree_bounds hi_bnds = nbnds;
                hi_bnds.lo[axis] = split;
                next_level.emplace_back( cid,   lo_bnds );
                next_level.emplace_back( cid+1, hi_bnds );
            }
            level.swap( next_level );
        }
    }

    /// Find the item closest to `bnd` with a best-first traversal.
    /// @param dis_fn callable `(const std::vector<int>& item_ids, const kdtree_bounds& bnd)`
    ///        returning a tuple `(distance, item_id)` for the closest of `item_ids`.
    ///        It is expected to return SQUARED Euclidean distance, because pruning
    ///        compares its result against the squared gap to a split plane.
    ///        It is never called with an empty id list.
    /// @param bnd query box (a point query uses lo == hi)
    /// @return (distance, item id) of the closest item as reported by dis_fn,
    ///         or (+infinity, -1) when the tree holds no items.
    template< typename distance_func >
    std::tuple<float,int> query_closest( distance_func& dis_fn, const kdtree_bounds& bnd ) const {
        // (lower bound on the distance to anything in the node, node id),
        // popped smallest-first, ties broken by node id.
        using entry = std::tuple<float,int>;
        std::priority_queue<entry, std::vector<entry>, std::greater<entry>> open;
        open.emplace( 0.0f, 0 );

        float best_dis = std::numeric_limits<float>::infinity();
        int best_id = -1;
        while( !open.empty() ){
            float bound;
            int nid;
            std::tie(bound, nid) = open.top();
            open.pop();
            // Every remaining node is at least `bound` away: nothing can improve.
            if( bound >= best_dis )
                break;

            const kdtree_node& node = m_nodes[nid];
            if( node.children <= 0 ){
                if( node.items.empty() )
                    continue;
                float d;
                int id;
                std::tie(d, id) = dis_fn( node.items, bnd );
                if( d < best_dis ){
                    best_dis = d;
                    best_id = id;
                }
                continue;
            }

            // Gap from the query box to each child's half-space. A child lies
            // inside its parent, so its bound is never below the parent's.
            const int axis = node.split_axis;
            const float split = node.split_coord;
            const float gap_lo = std::max( 0.0f, bnd.lo[axis] - split );
            const float gap_hi = std::max( 0.0f, split - bnd.hi[axis] );
            open.emplace( std::max( bound, gap_lo*gap_lo ), node.children );
            open.emplace( std::max( bound, gap_hi*gap_hi ), node.children + 1 );
        }
        return std::make_tuple( best_dis, best_id );
    }

private:
    /// Split heuristic: midpoint of the longest edge (ties favour the lower axis).
    /// Simple, but the original author notes it performs poorly for ray tracing.
    static std::tuple<int,float> longest_axis_midpoint( const kdtree_bounds& b ){
        const float d[] = { b.hi[0]-b.lo[0], b.hi[1]-b.lo[1], b.hi[2]-b.lo[2] };
        int best = 0;
        if( d[1] > d[best] ) best = 1;
        if( d[2] > d[best] ) best = 2;
        return std::make_tuple( best, b.lo[best] + d[best]*0.5f );
    }

    std::vector<kdtree_node> m_nodes;
};

} // namespace graphics
#endif
#include "kd_tree.h"
#include <ctime>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <sys/time.h>
// Wall-clock time in seconds (microsecond resolution) read via gettimeofday();
// callers subtract two readings to measure elapsed time.
double curr_time(){
    timeval now;
    gettimeofday( &now, nullptr );
    const double whole_seconds = static_cast<double>( now.tv_sec );
    const double fraction = static_cast<double>( now.tv_usec ) * 1e-6;
    return whole_seconds + fraction;
}
// Build a degenerate (zero-volume) box located at the point (x, y, z):
// both corners are set to the same coordinates.
graphics::kdtree_bounds make_bounds( float x, float y, float z ){
    const float coords[3] = { x, y, z };
    graphics::kdtree_bounds box;
    for( int axis = 0; axis < 3; ++axis ){
        box.lo[axis] = coords[axis];
        box.hi[axis] = coords[axis];
    }
    return box;
}
// Print only the lower corner of a box as "[x, y, z]". For the point boxes
// built by make_bounds this is the point itself.
std::ostream& operator<<( std::ostream& os, const graphics::kdtree_bounds& bnd ){
    os << "[";
    for( int axis = 0; axis < 3; ++axis ){
        if( axis > 0 )
            os << ", ";
        os << bnd.lo[axis];
    }
    return os << "]";
}
// Adapter that lets kdtree index a packed x,y,z float array as point boxes.
// It stores a non-owning reference, so `pnts` must outlive the point_set.
class point_set {
public:
    point_set( const std::vector<float>& pnts ) : m_pnts(pnts) {
    }

    // Number of points (three floats per point).
    size_t size() const {
        const size_t floats_per_point = 3;
        return m_pnts.size() / floats_per_point;
    }

    // Degenerate box at point `idx`; no bounds checking is performed.
    graphics::kdtree_bounds operator[]( const size_t &idx ) const {
        const float* xyz = m_pnts.data() + 3*idx;
        return make_bounds( xyz[0], xyz[1], xyz[2] );
    }

private:
    const std::vector<float> &m_pnts;
};
int main( int argc, char **argv ){
    // 100000 reference points in the unit cube, packed as x,y,z triples.
    // They are generated before srand48() below, i.e. with drand48's default seed.
    std::vector<int> items;
    std::vector<float> pnts;
    for( auto i=0; i<100000; i++ ){
        pnts.push_back( drand48() );
        pnts.push_back( drand48() );
        pnts.push_back( drand48() );
        items.push_back(i);
    }

    // Brute-force distance callback, also handed to kdtree::query_closest.
    // Returns (squared distance, point index) of the point in `ids` nearest
    // to bnd.lo, or (1e10f, -1) when `ids` is empty.
    auto min_dis_func = [&]( const std::vector<int>& ids, const graphics::kdtree_bounds& bnd ){
        int best = -1;  // initialized so an empty `ids` never yields an indeterminate index
        float dis, min_dis = 1e10f;
        for( auto idx : ids ){
            float delta[] = { pnts[idx*3+0]-bnd.lo[0], pnts[idx*3+1]-bnd.lo[1], pnts[idx*3+2]-bnd.lo[2] };
            dis = delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2];
            if( dis < min_dis ){
                min_dis = dis;
                best = idx;
            }
        }
        return std::make_tuple( min_dis, best );
    };

    graphics::kdtree kdtree( pnts.size()/3, point_set(pnts) );
    std::cout << "done building tree..." << std::endl;

    srand(time(NULL));
    srand48(time(NULL));

    // Correctness: compare 1000 random queries against brute force and
    // report every disagreement.
    int best_kd, best_bf;
    float min_dis_kd, min_dis_bf;
    int num_mismatches = 0;
    for( auto i=0; i<1000; ++i ){
        auto bnd = make_bounds(drand48(),drand48(),drand48());
        std::tie(min_dis_kd,best_kd) = kdtree.query_closest( min_dis_func, bnd );
        std::tie(min_dis_bf,best_bf) = min_dis_func( items, bnd );
        if( best_kd != best_bf ){
            ++num_mismatches;
            std::cout << best_kd << " " << best_bf << std::endl;
            std::cout << min_dis_kd << " " << min_dis_bf << std::endl;
        }
    }

    // Timing: average cost per query over N fresh random points.
    int N = 100000;
    std::vector<float> x,y,z;
    x.reserve(N);
    y.reserve(N);
    z.reserve(N);
    for( auto i=0; i<N; ++i ){
        x.push_back(drand48());
        y.push_back(drand48());
        z.push_back(drand48());
    }
    // Accumulated in double: summing 100000 small values in float drops low-order bits.
    double total_dis = 0.0;
    double t = curr_time();
    for( auto i=0; i<N; ++i ){
        auto bnd = make_bounds(x[i],y[i],z[i]);
        std::tie(min_dis_kd,best_kd) = kdtree.query_closest( min_dis_func, bnd );
        total_dis += min_dis_kd;
    }
    std::cout << "Total time: " << (curr_time()-t)*1e6/double(N) << "us/item" << std::endl;
    std::cout << "Total distance: " << total_dis << std::endl;

    // Non-zero exit status so scripted runs notice a kd-tree/brute-force disagreement.
    return num_mismatches == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment