Skip to content

Instantly share code, notes, and snippets.

@arrieta
Created November 8, 2021 14:22
Show Gist options
  • Save arrieta/54e945e99735ca6fb33216678134ef2b to your computer and use it in GitHub Desktop.
Save arrieta/54e945e99735ca6fb33216678134ef2b to your computer and use it in GitHub Desktop.
Trivial k-d tree implementation in C++
// Educational implementation of a k-d tree.
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
struct KDTree;
using Node = std::unique_ptr<KDTree>;
using Point = std::array<int, 3>;
using Points = std::vector<Point>;

// One node of a k-d tree over 3-D integer points. A node owns both of its
// subtrees; an empty subtree is represented by a null pointer.
struct KDTree {
  Point p = {};  // point stored at this node; it also defines the split
  Node lc = {};  // left subtree (for trees from build(): keys <= p's key)
  Node rc = {};  // right subtree (for trees from build(): keys >= p's key)
};

// Builds the subtree for the points in [first, last), rearranging that range
// in place. `level` is the depth of the subtree's root and selects the split
// axis as level % 3. Returns nullptr for an empty range.
static Node build_range(Points::iterator first, Points::iterator last,
                        unsigned int level) {
  const auto size = std::distance(first, last);
  if (size == 0) {
    return nullptr;
  }

  const auto k = level % std::tuple_size<Point>::value;
  const auto by_axis = [k](const Point& a, const Point& b) {
    return a[k] < b[k];
  };

  // Only the median has to land in its sorted position, with not-greater keys
  // before it and not-smaller keys after it. std::nth_element does exactly
  // that in linear average time; a full sort at every level would make the
  // whole build O(n log^2 n) instead of O(n log n).
  const auto mid = std::next(first, size / 2);
  std::nth_element(first, mid, last, by_axis);

  auto node = std::make_unique<KDTree>();
  node->p = *mid;
  // The halves are disjoint sub-ranges of the same buffer, so recursing on
  // them needs no per-level copies.
  node->lc = build_range(first, mid, level + 1u);
  node->rc = build_range(std::next(mid), last, level + 1u);
  return node;
}

// Builds a balanced k-d tree from `ps`. The root of the returned tree sits at
// depth `level`, and the split axis cycles x, y, z from there.
//
// The median point along the current axis becomes the node. The floor(n/2)
// points before it form the left subtree and the rest form the right subtree,
// so sibling subtree sizes differ by at most one. Returns nullptr when `ps` is
// empty.
std::unique_ptr<KDTree> build(Points ps, unsigned int level = 0u) {
  // `ps` is taken by value: the caller's vector is never reordered, and this
  // private copy can be partitioned in place.
  return build_range(ps.begin(), ps.end(), level);
}
auto show(const KDTree* node, int level = 0) {
if (node == nullptr) {
return;
}
std::cout << std::string(4 * level, ' ') << "L" << level << "[" << node->p[0]
<< ", " << node->p[1] << ", " << node->p[2] << "]\n";
show(node->lc.get(), level + 1);
show(node->rc.get(), level + 1);
}
int main(int argc, char* argv[]) {
auto rng = std::default_random_engine(0u);
auto dis = std::uniform_int_distribution<>(-100, 100);
const auto N = argc == 1 ? 100 : std::stoi(argv[1]);
Points ps(N);
for (auto k = 0u; k < ps.size(); ++k) {
ps[k] = {dis(rng), dis(rng), dis(rng)};
}
show(build(ps).get());
}
@arrieta
Copy link
Author

arrieta commented Nov 8, 2021

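Step-by-step construction view. This variant builds no nodes. It rearranges the points in place and prints each splitting point, indented four spaces per tree level: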

// Prints the k-d tree that would be built from [beg, end), one splitting
// point per line, indented four spaces per tree level. While doing so it
// rearranges the range in place into that tree's implicit layout: the median
// of every subrange, on axis level % dims, ends up in the subrange's middle
// position. Not-greater keys come before it and not-smaller keys after it.
//
// Requirements:
//   - I must be a random-access iterator (needed by std::nth_element).
//   - Elements must be indexable with the axis number.
//   - Elements must be printable with operator<<.
//   - `dims` must be non-zero. It defaults to 3 for 3-D points; pass 2 for
//     planar points.
template <typename I>
void sort(I beg, I end, std::uint32_t level = 0u, std::uint32_t dims = 3u) {
  const auto size = std::distance(beg, end);

  if (size <= 0) {
    return;
  }

  const auto k    = level % dims;
  const auto pred = [k](const auto& p, const auto& q) { return p[k] < q[k]; };

  // Only the median position matters, so a selection (linear on average) is
  // enough. A full sort at every level would cost O(n log^2 n) overall.
  const auto mid = std::next(beg, size / 2);
  std::nth_element(beg, mid, end, pred);

  std::cout << std::string(4u * level, ' ') << (*mid) << "\n";

  level += 1u;
  kdtree::sort(beg, mid, level, dims);
  kdtree::sort(std::next(mid), end, level, dims);
}

@arrieta
Copy link
Author

arrieta commented Nov 8, 2021

Python visualization of a 2-D k-d tree (requires NumPy and Matplotlib):

import sys
import random
import numpy as np
import matplotlib.pyplot as plt


class Node:
    """A k-d tree node: one splitting point plus optional child subtrees."""

    def __init__(self, point, lc=None, rc=None):
        # ``lc``/``rc`` are the left/right subtrees; ``None`` means empty.
        self.point, self.lc, self.rc = point, lc, rc


def build(points, level=0, dims=2):
    """Recursively build a balanced k-d tree.

    Parameters
    ----------
    points : sequence
        Indexable points, e.g. a list of tuples or the rows of a NumPy array.
        The input sequence itself is not reordered; a sorted copy is used.
    level : int
        Depth of the returned root. The split axis is ``level % dims``.
    dims : int
        Number of coordinates the split axis cycles through. Defaults to 2
        (planar points, as drawn by ``plot_node``); pass e.g. 3 for 3-D data.

    Returns
    -------
    Node or None
        The root node, or ``None`` when ``points`` is empty. The median along
        the split axis becomes the node. The ``len(points) // 2`` points before
        it go left and the rest go right.
    """
    if len(points) == 0:
        return None

    axis = level % dims
    # sorted() always returns a new list, so the caller's data is untouched.
    points = sorted(points, key=lambda p: p[axis])

    m = len(points) // 2

    lc = build(points[0:m], level + 1, dims)
    rc = build(points[m + 1:], level + 1, dims)

    return Node(points[m], lc, rc)


def plot_points(ax, points):
    """Scatter the first two columns of ``points`` on ``ax`` as small black dots.

    ``points`` must support NumPy-style column slicing (``points[:, j]``),
    e.g. an ``(n, 2)`` array. Returns ``ax`` so calls can be chained.
    """
    xs = points[:, 0]
    ys = points[:, 1]
    ax.scatter(xs, ys, s=3, color="k")
    return ax


def plot_node(ax, node, level=0, xmin=0, xmax=1, ymin=0, ymax=1):
    """Draw the splitting lines of the 2-D k-d (sub)tree rooted at ``node``.

    ``[xmin, xmax] x [ymin, ymax]`` is the cell that ``node`` splits; it
    defaults to the unit square.

    - Even levels split on x: a vertical line, drawn after both child cells.
    - Odd levels split on y: a horizontal line, drawn before the child cells.

    Lines are solid black with width 1. ``node is None`` draws nothing.
    """
    if node is None:
        return

    child_level = level + 1

    if level % 2 == 0:
        cut = node.point[0]
        plot_node(ax, node.lc, child_level,
                  xmin=xmin, xmax=cut, ymin=ymin, ymax=ymax)
        plot_node(ax, node.rc, child_level,
                  xmin=cut, xmax=xmax, ymin=ymin, ymax=ymax)
        ax.plot([cut, cut], [ymin, ymax], "k-", linewidth=1)
        return

    cut = node.point[1]
    ax.plot([xmin, xmax], [cut, cut], "k-", linewidth=1)
    plot_node(ax, node.lc, child_level,
              xmin=xmin, xmax=xmax, ymin=ymin, ymax=cut)
    plot_node(ax, node.rc, child_level,
              xmin=xmin, xmax=xmax, ymin=cut, ymax=ymax)


def main():
    """Show the k-d partition of N uniform random points in the unit square.

    N comes from the first command-line argument and defaults to 100.
    """
    n = 100 if len(sys.argv) == 1 else int(sys.argv[1])
    points = np.random.rand(n, 2)

    ax = plt.subplot(111)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    root = build(points)
    plot_node(ax, root)
    plot_points(ax, points)

    plt.show()


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

@arrieta
Copy link
Author

arrieta commented Nov 8, 2021

Sample 2-D partition of 1,024 random points in (0, 1) × (0, 1):

kdtree

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment