Skip to content

Instantly share code, notes, and snippets.

@chengluyu
Created January 18, 2018 15:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chengluyu/095a8dff5e43bff1a9941323d158e013 to your computer and use it in GitHub Desktop.
Save chengluyu/095a8dff5e43bff1a9941323d158e013 to your computer and use it in GitHub Desktop.
My Naïve Implementation of a K-d Tree
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <random>
#include <type_traits>
#include <vector>
/// Minimal stopwatch for measuring elapsed intervals.
///
/// Usage: call start(), run the work, call stop(), then read elapsedSeconds().
/// The interval is measured with std::chrono::steady_clock, which is
/// monotonic, so the reading cannot be distorted by system clock adjustments.
class Timer {
    using Clock = std::chrono::steady_clock;

    Clock::time_point start_{};
    Clock::time_point end_{};

public:
    /// Records the current instant as the beginning of the interval.
    inline void start() {
        start_ = Clock::now();
    }

    /// Records the current instant as the end of the interval.
    inline void stop() {
        end_ = Clock::now();
    }

    /// Returns the time between the recorded start and stop instants, in
    /// seconds. Yields 0 when neither start() nor stop() has been called.
    inline double elapsedSeconds() const {
        return std::chrono::duration<double>(end_ - start_).count();
    }
};
/// A pair of bounds (min, max) over a totally ordered value type.
///
/// contains() and overlap() only require operator< on the value type;
/// neighbourhood() additionally requires operator- and operator+.
template <typename TotalOrderComparable>
struct Range {
    TotalOrderComparable min;
    TotalOrderComparable max;

    /// Stores the bounds as given; no check is made that mi <= mx.
    Range(TotalOrderComparable mi, TotalOrderComparable mx) : min(mi), max(mx) { }

    /// Returns true when value lies strictly between the bounds
    /// (both endpoints are excluded).
    inline bool contains(TotalOrderComparable value) const {
        return min < value && value < max;
    }

    /// Returns true when the two ranges share at least one point; ranges that
    /// merely touch at an endpoint count as overlapping.
    inline bool overlap(const Range &rhs) const {
        // Expressed with operator< only, matching contains().
        return !(max < rhs.min || rhs.max < min);
    }

    /// Builds the range [center - radius, center + radius].
    static Range neighbourhood(TotalOrderComparable center, TotalOrderComparable radius) {
        return Range{ center - radius, center + radius };
    }
};
/// Returns x multiplied by itself. Usable in constant expressions.
constexpr float square(float x) {
    const float squared = x * x;
    return squared;
}
/// Absolute tolerance for approximate floating-point comparisons, expressed
/// in the requested real type (e.g. epsilon<float>).
///
/// It is a compile-time constant, and the double literal is converted to
/// RealType explicitly rather than through an implicit narrowing conversion.
template <typename RealType>
constexpr RealType epsilon = static_cast<RealType>(1e-7);
/// Two-dimensional point/vector with single-precision coordinates.
struct Vector2F {
    float x, y;

    /// Pointer-to-member selecting one coordinate of a Vector2F.
    using Axis = float Vector2F::*;

    /// Table of coordinate selectors, defined out of class. Entry 0 selects x
    /// and entry 1 selects y; the tree builder indexes it with depth % 2.
    static const Axis axes[3];

    Vector2F(float x, float y) : x(x), y(y) { }

    /// Returns the dot product of this vector and rhs.
    inline float dot(const Vector2F &rhs) const {
        return x * rhs.x + y * rhs.y;
    }

    /// Returns the Euclidean distance between this point and rhs.
    inline float distance(const Vector2F &rhs) const {
        return std::sqrt(square(x - rhs.x) + square(y - rhs.y));
    }

    /// Approximate equality: both coordinates must differ by less than
    /// epsilon<float> in absolute value.
    inline bool operator== (const Vector2F &rhs) const {
        // The absolute value is essential: a signed difference is "less than
        // epsilon" whenever this coordinate is smaller than rhs's, which would
        // make clearly different points compare equal.
        return std::fabs(x - rhs.x) < epsilon<float> &&
               std::fabs(y - rhs.y) < epsilon<float>;
    }

    /// Negation of operator==.
    inline bool operator!= (const Vector2F &rhs) const {
        return !(*this == rhs);
    }
};
/// Writes vec to out in the form "(x, y)" and returns out for chaining.
std::ostream &operator<< (std::ostream &out, const Vector2F &vec) {
    out << '(' << vec.x;
    out << ", " << vec.y;
    out << ')';
    return out;
}
// Coordinate selectors in split order: x first, then y. The array is declared
// with three slots; the last one holds a null member pointer.
const Vector2F::Axis Vector2F::axes[3] = {
    &Vector2F::x,
    &Vector2F::y,
    nullptr,
};
struct KdTreeNode {
Vector2F data;
Vector2F::Axis axis;
KdTreeNode *left_child;
KdTreeNode *right_child;
float value;
Range<float> range = { 0.0f, 0.0f };
KdTreeNode(const Vector2F &data,
Vector2F::Axis axis,
Range<float> range,
KdTreeNode *lc = nullptr,
KdTreeNode *rc = nullptr)
: data(data), axis(axis), left_child(lc), right_child(rc), value(data.*axis), range(range) { }
~KdTreeNode() {
delete left_child;
delete right_child;
}
static void dump(KdTreeNode *node, std::ostream &out, size_t depth = 0) {
for (size_t i = 0; i < depth; i++)
out << " ";
if (node) {
out << node->data << '\n';
if (node->left_child || node->right_child) {
dump(node->left_child, out, depth + 1);
dump(node->right_child, out, depth + 1);
}
} else {
out << '*' << '\n';
}
}
template <typename RandomAccessIt>
static KdTreeNode *build(RandomAccessIt begin, RandomAccessIt end, size_t depth = 0) {
if (begin == end)
return nullptr;
auto middle = begin + std::distance(begin, end) / 2;
auto axis = Vector2F::axes[depth % 2];
auto comp = [axis](auto lhs, auto rhs) {
return lhs.*axis < rhs.*axis;
};
std::nth_element(begin, middle, end, comp);
auto result = std::minmax_element(begin, end, comp);
return new KdTreeNode(*middle,
Vector2F::axes[depth % 2],
Range{ (*result.first).*axis, (*result.second).*axis },
build(begin, middle, depth + 1),
build(middle + 1, end, depth + 1));
}
};
/// Nearest-neighbour search for one source point over a KdTreeNode tree.
///
/// Construct it with the source point, call query(root), then read result()
/// and distance(). The stored result points into the tree that was searched,
/// so that tree must still be alive when result() is called.
class KdTreeQuery {
    Vector2F source_point_;
    Vector2F *nearest_point_;
    float min_distance_;

public:
    /// Starts a search for `source` with no candidate and an initial best
    /// distance of the largest finite float.
    explicit KdTreeQuery(const Vector2F &source)
        : source_point_(source),
          nearest_point_(nullptr),
          min_distance_(std::numeric_limits<float>::max()) { }

    /// Searches the subtree rooted at `node` and updates the best candidate.
    /// A null `node` is ignored.
    void query(KdTreeNode *node) {
        if (node == nullptr)
            return;

        const float candidate_distance = source_point_.distance(node->data);
        if (candidate_distance < min_distance_) {
            min_distance_ = candidate_distance;
            nearest_point_ = &node->data;
        }

        // Descend first into the side of the splitting line that contains the
        // source point. The far side is visited only if the splitting line is
        // closer than the best distance found so far; min_distance_ is re-read
        // after the first descent because that descent may have reduced it.
        const float coordinate = source_point_.*(node->axis);
        if (coordinate < node->value) {
            query(node->left_child);
            if (coordinate + min_distance_ > node->value)
                query(node->right_child);
        } else {
            query(node->right_child);
            if (coordinate - min_distance_ < node->value)
                query(node->left_child);
        }
    }

    /// Returns true when no point has been found yet.
    inline bool isNull() const {
        return nearest_point_ == nullptr;
    }

    /// Returns a copy of the nearest point found.
    /// Precondition: !isNull(); checked with assert in debug builds.
    inline Vector2F result() const {
        assert(nearest_point_ != nullptr && "result() called before any point was found");
        return *nearest_point_;
    }

    /// Returns the distance from the source point to the nearest point found,
    /// or the largest finite float when nothing has been found.
    inline float distance() const {
        return min_distance_;
    }
};
std::vector<Vector2F> generateRandomPoints(size_t n) {
std::random_device rd;
std::mt19937_64 engine(rd());
std::uniform_real_distribution<float> dist{ 0, +100.0f };
std::vector<Vector2F> points;
points.reserve(n);
for (size_t i = 0; i < n; i++)
points.emplace_back(dist(engine), dist(engine));
return points;
}
/// Returns one point whose x and y coordinates are each drawn from a uniform
/// distribution over [0, 100).
Vector2F generateRandomPoint() {
    // Seed the engine once per thread. Opening std::random_device and
    // re-seeding a Mersenne Twister on every call is costly when this function
    // is called in a tight loop, and is unnecessary for randomness.
    static thread_local std::mt19937_64 engine{ std::random_device{}() };
    std::uniform_real_distribution<float> dist{ 0, +100.0f };
    return Vector2F{ dist(engine), dist(engine) };
}
int main() {
std::vector<Vector2F> points = generateRandomPoints(100);
Timer timer;
KdTreeNode *root = KdTreeNode::build(points.begin(), points.end());
while (true) {
auto source = generateRandomPoint();
KdTreeQuery query{source};
query.query(root);
assert(!query.isNull());
auto gt = *std::min_element(points.begin(), points.end(), [source](auto lhs, auto rhs) {
return source.distance(lhs) < source.distance(rhs);
});
if (query.result() != gt) {
std::cout << "Test failed.\n";
std::cout << "Source point is " << source << '\n';
std::cout << "Point find by K-d tree is " << query.result() << " while the truth is " << gt << '\n';
std::copy(points.begin(), points.end(), std::ostream_iterator<Vector2F>(std::cout, ",\n"));
std::cout << "The tree:\n";
KdTreeNode::dump(root, std::cout);
break;
}
}
delete root;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment