Skip to content

Instantly share code, notes, and snippets.

@armollica
Last active April 7, 2023 07:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save armollica/64ffc3bd8fc76c5657719a842e39c4e3 to your computer and use it in GitHub Desktop.
Save armollica/64ffc3bd8fc76c5657719a842e39c4e3 to your computer and use it in GitHub Desktop.
K-D Tree Nearest Neighbor

k-d tree nearest neighbor search (1-NN). The red dot is the nearest neighbor. Orange dots were scanned during the search but not selected.

Compare with the nearest neighbor search using quadtrees from this block. The k-d tree technique seems to scan more points, although the two techniques limit the search set in different ways, so the number of scanned points isn't really a direct measure of which is more efficient.

Here's a more up-to-date version of this block that works for k nearest neighbors.

<html>
<head>
<style>
/* Splitting lines of the k-d tree: the <path> elements created with class "line". */
.line {
fill: none;
stroke: #ccc;
}
/* Every point of the tree: the <circle> elements created with class "point". */
.point {
fill: #999;
stroke: #fff;
}
/* Points visited by the most recent search; the "scanned" class is toggled in update(). */
.point.scanned {
fill: orange;
stroke: #999;
}
/* Nearest neighbor of the most recent search; the "selected" class is toggled in update().
   The selected node is also one of the scanned nodes, and this rule comes after
   .point.scanned so that its red fill is the one applied. */
.point.selected {
fill: red;
stroke: #999;
}
/* Circle centred on the query point; update() sets its radius to the nearest-neighbor distance. */
.halo {
fill: none;
stroke: red;
}
</style>
</head>
<body>
<script src="kd-tree.js"></script>
<script src="https://d3js.org/d3.v3.min.js" charset="utf-8"></script>
<script>
// Size of the SVG drawing area, in pixels.
var width = 960,
height = 500;
// Root <svg> element that every mark below is appended to.
var svg = d3.select("body").append("svg")
.attr("width", width)
.attr("height", height);
// 2000 random points with x in [0, width) and y in [0, height).
var data = d3.range(2000)
.map(function() {
return {
x: width * Math.random(),
y: height * Math.random(),
value: d3.random.normal()() // just for testing purposes
};
});
// Configure a KDTree builder with x/y accessors and invoke it on the data.
// The builder returns the root node of the tree, so `tree` exposes
// lines(), flatten() and find().
var tree = KDTree()
.x(function(d) { return d.x; })
.y(function(d) { return d.y; })
(data);
// One <path> per splitting line. tree.lines() returns two-point segments
// bounded by the given [[x0, y0], [x1, y1]] extent.
svg.append("g").attr("class", "lines")
.selectAll(".line").data(tree.lines([[0,0], [width, height]]))
.enter().append("path")
.attr("class", "line")
.attr("d", d3.svg.line());
// One <circle> per tree node, positioned at the node's location.
// The selection is kept so update() can toggle classes on it.
var points = svg.append("g").attr("class", "points")
.selectAll(".point").data(tree.flatten())
.enter().append("circle")
.attr("class", "point")
.attr("cx", function(d) { return d.location[0]; })
.attr("cy", function(d) { return d.location[1]; })
.attr("r", 4);
// Circle whose centre and radius are set by update().
var halo = svg.append("circle").attr("class", "halo");
// Run one search up front so a result is shown before any mouse movement.
update([width/3, height/2]);
// Fully transparent rect covering the drawing area, appended after the other
// elements; its mousemove handler re-runs the search at the cursor position.
svg.append("rect").attr("class", "event-canvas")
.attr("width", width)
.attr("height", height)
.attr("fill-opacity", 0)
.on("mousemove", function() { update(d3.mouse(this)); });
/**
 * Runs a nearest-neighbor query for `point` and refreshes the display.
 *
 * Reads the outer `tree`, `points` and `halo` variables.
 *
 * @param {Array<number>} point - Query position as [x, y].
 */
function update(point) {
  var result = tree.find(point);
  var visited = result.scannedNodes;
  var winner = result.node;

  // Mark every node the search touched, then the one it settled on.
  points.classed("scanned", function(node) {
    return visited.indexOf(node) >= 0;
  });
  points.classed("selected", function(node) {
    return winner === node;
  });

  // Centre the halo on the query point; its radius is the distance to the winner.
  halo.attr("cx", point[0]);
  halo.attr("cy", point[1]);
  halo.attr("r", result.distance);
}
</script>
</body>
</html>
/**
 * A single node of a k-d tree.
 *
 * @param {Array<number>} location - Coordinates of the point stored at this node.
 * @param {number} axis - Index of the coordinate this node splits on.
 * @param {Array} subnodes - Children as [left child, right child]; an absent
 *     child is represented by a falsy value (null when built by KDTree).
 * @param {*} datum - Data item associated with the point.
 */
function Node(location, axis, subnodes, datum) {
  this.location = location;
  this.axis = axis;
  this.subnodes = subnodes; // = children nodes = [left child, right child]
  this.datum = datum;
}

/**
 * Converts the subtree rooted at this node into nested arrays.
 *
 * @returns {Array} [location, leftSubtree, rightSubtree], where an absent
 *     subtree is null. The array also carries an `axis` property holding this
 *     node's splitting axis.
 */
Node.prototype.toArray = function() {
  var left = this.subnodes[0],
      right = this.subnodes[1];
  // Each child is checked on its own: a node may have only one of the two.
  var array = [
    this.location,
    left ? left.toArray() : null,
    right ? right.toArray() : null
  ];
  array.axis = this.axis;
  return array;
};

/**
 * Lists every node of the subtree rooted at this node.
 *
 * @returns {Array<Node>} This node first, then the nodes of the left subtree,
 *     then the nodes of the right subtree.
 */
Node.prototype.flatten = function() {
  var nodes = [this],
      left = this.subnodes[0],
      right = this.subnodes[1];
  if (left) nodes = nodes.concat(left.flatten());
  if (right) nodes = nodes.concat(right.flatten());
  return nodes;
};

/**
 * Nearest neighbor search (1-NN) in the subtree rooted at this node.
 *
 * Uses the file-level `distance` helper to compare points.
 *
 * 1-NN algorithm outlined here:
 * http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf
 *
 * @param {Array<number>} target - Query point.
 * @returns {{node: Node, distance: number, scannedNodes: Array<Node>}} The
 *     closest node found, its distance to `target`, and every node that was
 *     visited, in visiting order.
 */
Node.prototype.find = function(target) {
  var guess = this,
      bestDist = Infinity,
      scannedNodes = []; // keep track of these just for testing purpose

  function search(node) {
    if (!node) return;
    scannedNodes.push(node);

    // If the current location is better than the best known location,
    // update the best known location
    var nodeDist = distance(node.location, target);
    if (nodeDist < bestDist) {
      bestDist = nodeDist;
      guess = node;
    }

    // Recursively search the half of the tree that contains the target
    var axis = node.axis;
    var targetIsLeft = target[axis] < node.location[axis];
    var nearChild = node.subnodes[targetIsLeft ? 0 : 1],
        farChild = node.subnodes[targetIsLeft ? 1 : 0];
    search(nearChild);

    // If the candidate hypersphere crosses this splitting plane, look on the
    // other side of the plane by examining the other subtree
    if (farChild) {
      var delta = Math.abs(node.location[axis] - target[axis]);
      if (delta < bestDist) {
        search(farChild);
      }
    }
  }

  search(this);

  return {
    node: guess,
    distance: bestDist,
    scannedNodes: scannedNodes
  };
};

/**
 * Computes the splitting-line segments of the subtree rooted at this node.
 * Only works for 2D.
 *
 * @param {Array<Array<number>>} extent - Bounding box [[x0, y0], [x1, y1]]
 *     that this node's region is confined to.
 * @returns {Array} One [[xa, ya], [xb, yb]] segment per node: this node's
 *     segment first, then those of the left subtree, then the right subtree.
 */
Node.prototype.lines = function(extent) {
  var x0 = extent[0][0],
      y0 = extent[0][1],
      x1 = extent[1][0],
      y1 = extent[1][1],
      x = this.location[0],
      y = this.location[1];
  var line, leftExtent, rightExtent;

  if (this.axis === 0) {
    // Split on x: a vertical segment; children get the regions on either side.
    line = [[x, y0], [x, y1]];
    leftExtent = [[x0, y0], [x, y1]];
    rightExtent = [[x, y0], [x1, y1]];
  }
  else if (this.axis === 1) {
    // Split on y: a horizontal segment; children get the regions on either side.
    line = [[x0, y], [x1, y]];
    leftExtent = [[x0, y0], [x1, y]];
    rightExtent = [[x0, y], [x1, y1]];
  }
  else {
    // Unsupported axis (not 2D): there is no segment to compute, so yield a
    // single undefined entry and do not descend into the children.
    return [undefined];
  }

  var segments = [line];
  if (this.subnodes[0]) {
    segments = segments.concat(this.subnodes[0].lines(leftExtent));
  }
  if (this.subnodes[1]) {
    segments = segments.concat(this.subnodes[1].lines(rightExtent));
  }
  return segments;
};
/**
 * Creates a configurable builder for 2D k-d trees.
 *
 * The returned function takes an array of data items and returns the root
 * Node of a k-d tree built from them, or null when the array is empty. Each
 * node's `location` is the [x, y] pair produced by the accessors (the pair
 * also carries the source item as a `datum` property), and each node's
 * `datum` is the source item itself.
 *
 * The builder has chainable `x` and `y` accessor methods: called with a
 * function they set the accessor and return the builder; called without
 * arguments they return the current accessor. The defaults read d[0] and d[1].
 *
 * @returns {function(Array): ?Node} The tree builder.
 */
function KDTree() {
  var x = function(d) { return d[0]; },
      y = function(d) { return d[1]; };

  // Adapted from https://en.wikipedia.org/wiki/K-d_tree
  // Recursively builds the subtree for `points`; sorts `points` in place.
  function treeify(points, depth) {
    if (points.length === 0) return null;

    // Select axis based on depth so that axis cycles through all valid values
    var k = points[0].length;
    var axis = depth % k;

    // TODO: To speed up, consider splitting points based on approximation of
    //       median; take median of random sample of points (perhaps of 1/10th
    //       of the points)

    // Sort point list and choose median as pivot element
    points.sort(function(a, b) { return a[axis] - b[axis]; });
    var medianIndex = Math.floor(points.length / 2);

    // Create node and construct subtrees
    var point = points[medianIndex],
        leftPoints = points.slice(0, medianIndex),
        rightPoints = points.slice(medianIndex + 1);

    return new Node(
      point,
      axis,
      [treeify(leftPoints, depth + 1), treeify(rightPoints, depth + 1)],
      point.datum
    );
  }

  function tree(data) {
    // Work on a fresh array of [x, y] pairs so the caller's data is not reordered.
    var points = data.map(function(d) {
      var point = [x(d), y(d)];
      point.datum = d;
      return point;
    });
    return treeify(points, 0);
  }

  tree.x = function(_) {
    if (!arguments.length) return x;
    x = _;
    return tree;
  };

  tree.y = function(_) {
    if (!arguments.length) return y;
    y = _;
    return tree;
  };

  return tree;
}
/**
 * Smallest accessor value over the items of an array.
 *
 * @param {Array} array - Items to inspect; must not be empty (reduce is
 *     called without an initial value).
 * @param {function(*): *} accessor - Called with each item alone; returns the
 *     value to compare.
 * @returns {*} The smallest accessor value according to `<`.
 */
function min(array, accessor) {
  var values = array.map(function(item) {
    return accessor(item);
  });
  return values.reduce(function(smallest, candidate) {
    if (smallest < candidate) {
      return smallest;
    }
    return candidate;
  });
}
/**
 * Largest accessor value over the items of an array.
 *
 * @param {Array} array - Items to inspect; must not be empty (reduce is
 *     called without an initial value).
 * @param {function(*): *} accessor - Called with each item alone; returns the
 *     value to compare.
 * @returns {*} The largest accessor value according to `>`.
 */
function max(array, accessor) {
  var values = array.map(function(item) {
    return accessor(item);
  });
  return values.reduce(function(largest, candidate) {
    if (largest > candidate) {
      return largest;
    }
    return candidate;
  });
}
/**
 * Builds an accessor that reads one property from whatever it is given.
 *
 * @param {string|number} key - Property name or index to read.
 * @returns {function(*): *} Function returning `d[key]` for its argument `d`.
 */
function get(key) {
  var readKey = function(item) {
    return item[key];
  };
  return readKey;
}
// TODO: Make distance function work for k-dimensions
/**
 * Euclidean distance between two 2D points.
 *
 * Only the first two coordinates of each point are read.
 *
 * @param {Array<number>} p0 - First point as [x, y].
 * @param {Array<number>} p1 - Second point as [x, y].
 * @returns {number} Straight-line distance between p0 and p1.
 */
function distance(p0, p1) {
  var dx = p1[0] - p0[0],
      dy = p1[1] - p0[1];
  var sumOfSquares = Math.pow(dx, 2) + Math.pow(dy, 2);
  return Math.sqrt(sumOfSquares);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment