Skip to content

Instantly share code, notes, and snippets.

@juaxix
Last active March 9, 2022 19:38
Show Gist options
  • Save juaxix/458b9a61654803017bcbe249a51d4be7 to your computer and use it in GitHub Desktop.
Save juaxix/458b9a61654803017bcbe249a51d4be7 to your computer and use it in GitHub Desktop.
C++ KD Tree with Vector3 and a condition flag
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <math.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
using namespace std;
// Left-pad `str` in place with `paddingChar` until it is `num` characters
// long. A string that is already at least `num` characters is not modified.
void padTo(string &str, const size_t num, const char paddingChar = ' ')
{
    const size_t current = str.size();
    if (current >= num) {
        return;
    }
    str.insert(str.begin(), num - current, paddingChar);
}
// Plain three-component float vector used for tree positions and query points.
// Members are left uninitialised by default construction; use brace
// initialisation (e.g. Vector3{0, 0, 0}) when a defined value is needed.
struct Vector3{
public:
    float X;
    float Y;
    float Z;

    // Euclidean distance between v1 and v2.
    static float Distance(const Vector3 &v1, const Vector3 &v2)
    {
        const float dx = v1.X - v2.X;
        const float dy = v1.Y - v2.Y;
        const float dz = v1.Z - v2.Z;
        return std::sqrt(dx * dx + dy * dy + dz * dz);
    }

    // Component access by axis index: 0 -> X, 1 -> Y, 2 -> Z.
    // NOTE: any other index silently yields X rather than failing.
    float operator[](int i) const
    {
        switch (i) {
        case 1: return Y;
        case 2: return Z;
        case 0:
        default: return X;
        }
    }

    // Component-wise difference (*this - b).
    Vector3 operator-(const Vector3 &b) const
    {
        return Vector3{X - b.X, Y - b.Y, Z - b.Z};
    }

    // Exact floating-point equality of all three components (no tolerance).
    bool operator==(const Vector3 &b) const
    {
        return X == b.X && Y == b.Y && Z == b.Z;
    }

    // Squared length; avoids the square root when only comparisons are needed.
    float sqrMagnitude() const
    {
        return X * X + Y * Y + Z * Z;
    }

    // The origin (0, 0, 0).
    static Vector3 zero()
    {
        return Vector3{0.0f, 0.0f, 0.0f};
    }

    // "(X,Y,Z)" with each component formatted by std::to_string.
    std::string ToString() const
    {
        return "(" + std::to_string(X) + "," + std::to_string(Y) + "," + std::to_string(Z) + ")";
    }
};
// One point stored in / returned by the k-d tree.
struct TreeNode
{
    // Location of the point.
    Vector3 position;
    // Condition flag: KDTree::Search never reports a node whose flag is set.
    bool healed;
};
class KDTree
{
private:
static int callcounter;
public:
static TreeNode lastPivot;
vector<shared_ptr<KDTree>> lr;
TreeNode pivot;
int pivotIndex;
int axis;
// Change this value to 3 if you need three-dimensional X,Y,Z points. The search will be quicker in two dimensions.
static const int numDims = 3;
KDTree() : lr()
{
lr.reserve(2);
}
// Make a new tree from a list of points.
static shared_ptr<KDTree> MakeFromPoints(vector<TreeNode>& points) {
vector<int> indices = Iota(points.size());
return MakeFromPointsInner(0, 0, points.size() - 1, points, indices);
}
// Recursively build a tree by separating points at plane boundaries.
static shared_ptr<KDTree> MakeFromPointsInner(
int depth,
int stIndex, int enIndex,
vector<TreeNode>& points,
vector<int>& inds
)
{
shared_ptr<KDTree> root = make_shared<KDTree>();
root->axis = depth % KDTree::numDims;
int splitPoint = FindPivotIndex(points, inds, stIndex, enIndex, root->axis);
root->pivotIndex = inds[splitPoint];
root->pivot = points[root->pivotIndex];
int leftEndIndex = splitPoint - 1;
if (leftEndIndex >= stIndex) {
root->lr[0] = MakeFromPointsInner(depth + 1, stIndex, leftEndIndex, points, inds);
}
int rightStartIndex = splitPoint + 1;
if (rightStartIndex <= enIndex) {
root->lr[1] = MakeFromPointsInner(depth + 1, rightStartIndex, enIndex, points, inds);
}
return root;
}
static void SwapElements(vector<int> &arr, int a, int b) {
int temp = arr[a];
arr[a] = arr[b];
arr[b] = temp;
}
// Simple "median of three" heuristic to find a reasonable splitting plane.
static int FindSplitPoint(vector<TreeNode>& points, const vector<int>& inds, int stIndex, int enIndex, int axis) {
float a = points[inds[stIndex]].position[axis];
float b = points[inds[enIndex]].position[axis];
int midIndex = (stIndex + enIndex) / 2;
float m = points[inds[midIndex]].position[axis];
if (a > b) {
if (m > a) {
return stIndex;
}
if (b > m) {
return enIndex;
}
return midIndex;
} else {
if (a > m) {
return stIndex;
}
if (m > b) {
return enIndex;
}
return midIndex;
}
}
// Find a new pivot index from the range by splitting the points that fall either side
// of its plane.
static int FindPivotIndex(vector<TreeNode>& points, vector<int>& inds, int stIndex, int enIndex, int axis) {
int splitPoint = FindSplitPoint(points, inds, stIndex, enIndex, axis);
// int splitPoint = Random.Range(stIndex, enIndex);
Vector3 pivot = points[inds[splitPoint]].position;
SwapElements(inds, stIndex, splitPoint);
int currPt = stIndex + 1;
int endPt = enIndex;
while (currPt <= endPt) {
Vector3 curr = points[inds[currPt]].position;
if ((curr[axis] > pivot[axis])) {
SwapElements(inds, currPt, endPt);
endPt--;
} else {
SwapElements(inds, currPt - 1, currPt);
currPt++;
}
}
return currPt - 1;
}
static vector<int> Iota(int num) {
vector<int> result(num);
for (int i = 0; i < num; i++) {
result[i] = i;
}
return result;
}
// Find the nearest point in the set to the supplied point.
int FindNearest(Vector3 pt) {
float bestSqDist = std::numeric_limits<float>::max();
int bestIndex = -1;
Search(pt, bestSqDist, bestIndex);
return bestIndex;
}
// Recursively search the tree.
void Search(Vector3& pt, float &bestSqSoFar, int &bestIndex)
{
float mySqDist = std::numeric_limits<float>::max();
if (!pivot.healed) {
mySqDist = (pivot.position - pt).sqrMagnitude();
}
if (mySqDist < bestSqSoFar) {
bestSqSoFar = mySqDist;
bestIndex = pivotIndex;
if (lastPivot.position==Vector3::zero() ||
Vector3::Distance(
pt,pivot.position
)<
Vector3::Distance(
pt,lastPivot.position
)
){
callcounter++;
lastPivot = TreeNode();
lastPivot.position=pivot.position;
lastPivot.healed=pivot.healed;
//Debug.Log(callcounter.ToString() + ": New position " + bestIndex.ToString()+ "; " + lastPivot.position.ToString());
}
}
float planeDist = pt[axis] - pivot.position[axis]; //DistFromSplitPlane(pt, pivot, axis);
int selector = planeDist <= 0 ? 0 : 1;
if (lr[selector] != nullptr) {
lr[selector]->Search(pt, bestSqSoFar, bestIndex);
}
selector = (selector + 1) % 2;
float sqPlaneDist = planeDist * planeDist;
if ((lr[selector] != nullptr) && (bestSqSoFar > sqPlaneDist)) {
lr[selector]->Search(pt, bestSqSoFar, bestIndex);
}
}
// Get a point's distance from an axis-aligned plane.
float DistFromSplitPlane(Vector3 pt, Vector3 planePt, int axis)
{
return pt[axis] - planePt[axis];
}
// Simple output of tree structure - mainly useful for getting a rough
// idea of how deep the tree is (and therefore how well the splitting
// heuristic is performing).
string Dump(int level)
{
string result = to_string(pivotIndex);
padTo(result,level);
result += "\n";
if (lr[0] != nullptr) {
result += lr[0]->Dump(level + 2);
}
if (lr[1] != nullptr) {
result += lr[1]->Dump(level + 2);
}
return result;
}
};
// Out-of-class definitions of KDTree's static debug-tracking members; both
// start zeroed (counter 0, lastPivot at the origin with healed == false).
int KDTree::callcounter{0};
TreeNode KDTree::lastPivot{};
int main(int argc,char **argv)
{
const int total_trees = 600;
vector<TreeNode> vectors(total_trees);
if(argc<4) {
cout << "Please use X,Y,Z coordinates to find in the tree of trees"<<endl;
cout << argv[0] << " X Y Z"<<endl;
}
Vector3 coordinates;
coordinates.X = atof(argv[1]);
coordinates.Y = atof(argv[2]);
coordinates.Z = atof(argv[3]);
for (int i=0; i<total_trees; i++)
{
vectors[i] = TreeNode();
vectors[i].position.X = i*6;
vectors[i].position.Y = i*2;
vectors[i].position.Z = i;
vectors[i].healed = false;
}
shared_ptr<KDTree> TreeInstances = KDTree::MakeFromPoints(vectors);
cout << "Trees created: " << total_trees << endl;
cout << "---------------------"<<endl;
cout << "Searching for the index of the tree nearest to "<<
coordinates.ToString()<<endl;
int index = TreeInstances->FindNearest(coordinates);
cout << "Nearest tree at: " << index << vectors[index].position.ToString();
TreeInstances.reset();
vectors.clear();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment