Skip to content

Instantly share code, notes, and snippets.

@fsmv
Last active August 29, 2015 14:04
Show Gist options
  • Save fsmv/311ba8bb550bd6d36851 to your computer and use it in GitHub Desktop.
Save fsmv/311ba8bb550bd6d36851 to your computer and use it in GitHub Desktop.
A K-D Tree implementation which performs a nearest neighbor search on any point type and returns the distance between them in any type using a user-defined coordinate accessor function.
/**
* A K-D Tree implementation which performs a nearest neighbor search on any point type and returns the
* distance to the nearest point in any numeric type, using a user-defined coordinate accessor function.
*
* Rather than returning the actual nearest object this class returns the index of the nearest object in
* the original vector that was passed in.
*
* Provided under the MIT License
* Copyright (c) 2014 Andrew Kallmeyer <fsmv@sapium.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#pragma once
#include <vector>
/**
 * One node of the K-D tree.
 *
 * A node does not hold a point itself; it holds the index of a point and
 * links to its two children. A freshly constructed node is a leaf.
 */
struct Node {
    int index;              // index of the point this node represents
    Node* left = nullptr;   // left child, or nullptr when absent
    Node* right = nullptr;  // right child, or nullptr when absent

    // Children start out null so that isLeaf() and tree teardown never read
    // indeterminate pointers on a node whose links were not assigned yet.
    Node(int index) : index(index) {}

    // True when the node has neither a left nor a right child.
    bool isLeaf() const { return left == nullptr && right == nullptr; }
};
template <typename PointType, typename DistType = double>
class KDTree {
public:
KDTree(DistType (*getDimVal)(const PointType&, int dim), int dimension = 2);
~KDTree(void);
/*
* add all entries in the array to the data structure
*
* Note: If two items have the same coordinates, then do not add the new item
* that has the same coordinates as another item.
*/
void build(const std::vector<PointType> &c);
/*
* Return a pointer to the entry that is closest to the given coordinates.
*/
int getNearest(const PointType &point, DistType *dist) const;
private:
Node* root;
std::vector<int> dataIndicies;
DistType getDist(const PointType &lhs, const PointType &rhs) const;
bool comp(int dim, const PointType &lhs, const PointType &rhs) const;
DistType (*getDimVal)(const PointType&, int dim);
int dimension;
void destroy(Node* root);
/*
* Recursively builds a kd-tree, n is the length of c
*/
Node* build(int start, int n, int depth);
/*
* Recursively finds the nearest neighbor to a given point (x, y) root is the root of the kd-tree to search and currResult is the current closest (arbitrary initial choice)
*/
int getNearest(const PointType &point, Node* node, int best, int depth = 0) const;
};
//======= Implementation =======
#include <cmath>
#include <algorithm>
#include <functional>
#include <iterator>
/*
 * Stores the coordinate accessor and the number of axes. The tree starts out
 * empty: root is null until build() is called, which makes destroying or
 * rebuilding a never-built tree safe.
 */
template<typename T, typename V>
KDTree<T, V>::KDTree(V (*getDimVal)(const T&, int dim), int dimension)
    : root(nullptr), getDimVal(getDimVal), dimension(dimension) {}

/*
 * Releases every node of the tree.
 */
template<typename T, typename V>
KDTree<T, V>::~KDTree(void) {
    destroy(root);
}

/*
 * Deletes the subtree rooted at root, children first. A null root is a no-op.
 */
template<typename T, typename V>
void KDTree<T, V>::destroy(Node* root) {
    if(root == nullptr) {
        return;
    }
    destroy(root->left);
    destroy(root->right);
    delete root;
}

/*
 * Copies the points of c into the tree and builds the search structure.
 * Any tree left over from an earlier build() is discarded first, so the
 * function may be called repeatedly. An empty c leaves the tree empty.
 */
template<typename T, typename V>
void KDTree<T, V>::build(const std::vector<T> &c) {
    // Free the previous tree; without this a second build() would leak it.
    destroy(root);
    root = nullptr;

    data = c;

    // Start from a fresh identity permutation. Appending to indices left
    // over from a previous build would leave stale entries in the list.
    dataIndicies.clear();
    dataIndicies.reserve(data.size());
    for(std::size_t i = 0; i < data.size(); ++i) {
        dataIndicies.push_back(static_cast<int>(i));
    }

    root = build(0, static_cast<int>(data.size()), 0);
}
/*
 * Builds the subtree for the n indices held in dataIndicies[start, start + n).
 *
 * The slice is sorted by the coordinate along axis (depth % dimension); the
 * middle entry becomes the subtree's node, the entries before it form the
 * left subtree and the entries after it the right subtree. Returns nullptr
 * for an empty slice.
 */
template<typename T, typename V>
Node* KDTree<T, V>::build(int start, int n, int depth) {
    if(n == 0) {
        return nullptr;
    }

    if(n == 1) {
        Node* leaf = new Node(dataIndicies[start]);
        leaf->left = nullptr;
        leaf->right = nullptr;
        return leaf;
    }

    const int axis = depth % dimension;
    const auto first = dataIndicies.begin() + start;
    std::sort(first, first + n, [this, axis](int a, int b) {
        return comp(axis, data[a], data[b]);
    });

    // The median sits at offset leftCount; everything after it goes right.
    const int leftCount = n / 2;
    const int rightCount = n - leftCount - 1;

    Node* parent = new Node(dataIndicies[start + leftCount]);
    parent->left = build(start, leftCount, depth + 1);
    parent->right = build(start + leftCount + 1, rightCount, depth + 1);
    return parent;
}
/*
 * Finds the stored point closest to `point`.
 *
 * Returns the index of that point in the vector that was passed to build().
 * When dist is non-null, the Euclidean distance to the point is written to
 * *dist (the square root of the squared distance used internally).
 *
 * When the tree holds no nodes (root is null, e.g. after build() on an empty
 * vector) there is no nearest point: -1 is returned and *dist is left
 * untouched.
 */
template<typename T, typename V>
int KDTree<T, V>::getNearest(const T &point, V *dist) const {
    if(root == nullptr) {
        return -1;
    }

    // The root's own point serves as the initial "best so far" candidate.
    const int result = getNearest(point, root, root->index);

    if(dist != nullptr) {
        *dist = static_cast<V>(std::sqrt(static_cast<double>(getDist(point, data[result]))));
    }
    return result;
}
/*
 * Searches the subtree rooted at node for the point nearest to `point`.
 *
 * best is the index of the closest point found so far and depth selects the
 * splitting axis (depth % dimension). Returns the index of the closest point
 * after this subtree has been taken into account. node must not be null.
 */
template<typename T, typename V>
int KDTree<T, V>::getNearest(const T &point, Node* node, int best, int depth) const {
    const int axis = depth % dimension;

    // Pick the child on the query's side of the splitting plane as the near one.
    Node *nearSide;
    Node *farSide;
    if(comp(axis, point, data[node->index])) {
        nearSide = node->left;
        farSide = node->right;
    } else {
        nearSide = node->right;
        farSide = node->left;
    }

    // This node's own point may beat the current candidate.
    if(getDist(point, data[node->index]) < getDist(point, data[best])) {
        best = node->index;
    }

    if(nearSide != nullptr) {
        best = getNearest(point, nearSide, best, depth + 1);
    }

    // getDist yields squared distances, so the gap to the splitting plane is
    // squared as well before the two are compared.
    V planeGap = getDimVal(point, axis) - getDimVal(data[node->index], axis);
    planeGap *= planeGap;

    // The far side can only hold a closer point when the plane itself is no
    // farther away than the best candidate.
    if(planeGap <= getDist(point, data[best]) && farSide != nullptr) {
        best = getNearest(point, farSide, best, depth + 1);
    }

    return best;
}
/*
 * Squared Euclidean distance between lhs and rhs: the sum over every axis of
 * the squared coordinate difference. No square root is taken here.
 */
template<typename T, typename V>
V KDTree<T,V>::getDist(const T &lhs, const T &rhs) const {
    V total = static_cast<V>(0);
    int axis = 0;
    while(axis < dimension) {
        V delta = getDimVal(lhs, axis) - getDimVal(rhs, axis);
        total += delta * delta;
        ++axis;
    }
    return total;
}
/*
 * Strict "less than" on a single axis: true when lhs has a smaller
 * coordinate than rhs along axis dim.
 */
template<typename T, typename V>
bool KDTree<T,V>::comp(int dim, const T &lhs, const T &rhs) const {
    const V left = getDimVal(lhs, dim);
    const V right = getDimVal(rhs, dim);
    return left < right;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment