Skip to content

Instantly share code, notes, and snippets.

@foysavas
Created December 17, 2009 05:48
Show Gist options
  • Save foysavas/258553 to your computer and use it in GitHub Desktop.
Save foysavas/258553 to your computer and use it in GitHub Desktop.
import std.stdio;
/**
 * A k-d tree over fixed-length `double[]` points, grown by plain insertion
 * (nodes are never rebalanced).
 *
 * Nodes are handled through `KdNode*`, i.e. a pointer to a variable holding
 * the class reference. The tree stores those pointers as given, so the
 * variable passed to `addNode` must outlive every use of the tree.
 */
class KdTree {
    private int _dimensions;

    /// Number of coordinates every inserted vector must have.
    int dimensions() {
        return _dimensions;
    }

    /// Root of the tree; null while the tree is empty.
    KdNode *rootNode;

    private int _depth = -1;

    /// Largest `depthInTree` of any inserted node; -1 while the tree is empty.
    int depth() {
        return _depth;
    }

    /// Creates an empty tree for vectors of `__dimensions` coordinates.
    this(int __dimensions) {
        _dimensions = __dimensions;
    }

    /**
     * Walks down from `startNode` (or from the root when `startNode` is null),
     * at each node following the side chosen by `arbCompareAtDepth(vec)`, and
     * returns the node at which the walk cannot continue.
     *
     * Params:
     *   vec       = the point to route through the tree
     *   startNode = node to begin at; null means "start at the root"
     *
     * Returns: null when the tree is empty; otherwise the first node reached
     *          that is a leaf or has no child on the side `vec` falls on.
     */
    KdNode* searchByComparison(double[] vec, KdNode *startNode = null) {
        if (rootNode is null)
            return null;

        KdNode *current = (startNode is null) ? rootNode : startNode;

        // Iterative descent: the tree is never rebalanced, so a degenerate
        // (list-shaped) tree would otherwise cost one stack frame per node.
        while (true) {
            // A leaf is returned without being compared against `vec`.
            if (current.leftNode is null && current.rightNode is null)
                return current;

            KdNode *next;
            if (current.arbCompareAtDepth(vec) == -1)
                next = current.leftNode;
            else
                next = current.rightNode;

            if (next is null)
                return current;
            current = next;
        }
    }

    /**
     * Inserts `node` below the node found by `searchByComparison`, on the side
     * chosen by `arbCompareAtDepth`, and updates the node's parent/depth
     * fields and the tree depth.
     *
     * Params:
     *   node = non-null pointer to the node to insert; its vector length must
     *          equal `dimensions` (checked with `assert`)
     */
    void addNode(KdNode *node) {
        assert(node !is null, "addNode requires a non-null node pointer");
        assert(node.vector.length == dimensions);

        if (rootNode is null) {
            node.depthInTree = 0;
            rootNode = node;
            _depth = _depth + 1;
        } else {
            // The returned node always has a free slot on the side the new
            // vector falls on, so the assignment below never overwrites a child.
            KdNode *parent = searchByComparison(node.vector);
            short dir = parent.arbCompareAtDepth(node.vector);
            if (dir == -1)
                parent.leftNode = node;
            else
                parent.rightNode = node;
            node.parentNode = parent;
            node.depthInTree = parent.depthInTree + 1;
            if (node.depthInTree > depth)
                _depth = node.depthInTree;
        }
    }
}
/**
 * One node of a k-d tree: a point plus links to its parent and children.
 *
 * The link fields and `depthInTree` are plain public members; in this file
 * they are filled in by `KdTree.addNode`.
 */
class KdNode {
    KdNode *parentNode;
    KdNode *leftNode;
    KdNode *rightNode;

    /// Level of this node in its tree; selects the axis used for comparisons.
    int depthInTree;

    private double[] _vector;

    /// The point stored in this node.
    double[] vector() {
        return _vector;
    }

    /// Wraps `__vector` in a node with no links and depth 0.
    this(double[] __vector) {
        _vector = __vector;
    }

    /**
     * Compares `ovector` with this node's point on a single axis, namely
     * `depthInTree % vector.length`.
     *
     * Returns:
     *   -1 when `ovector` is smaller on that axis (left),
     *    0 when neither value is greater than the other (exact match),
     *    1 when `ovector` is larger on that axis (right).
     */
    short compareAtDepth(double[] ovector) {
        auto axis = depthInTree % vector.length;
        double own = vector[axis];
        double other = ovector[axis];

        if (own > other)
            return -1;
        if (own < other)
            return 1;
        return 0;
    }

    /**
     * Same as `compareAtDepth`, except that a tie is resolved to 1 so the
     * result is always a direction (-1 or 1).
     */
    short arbCompareAtDepth(double[] ovector) {
        const short verdict = compareAtDepth(ovector);
        if (verdict != 0)
            return verdict;
        // arbitrary direction for tied comparisons
        return 1;
    }
}
// Checks insertion, depth tracking and lookup on a two-node tree, once with
// 1-D points and once with 2-D points.
unittest {
    writeln("Testing KDTree");

    writeln("-- Simple 1-D tree");
    {
        KdTree line = new KdTree(1);

        // Empty tree: no root, depth -1, nothing to find.
        assert(line.rootNode is null);
        assert(line.depth == -1);
        assert(line.searchByComparison([0]) is null);

        // The first node becomes the root at depth 0.
        KdNode origin = new KdNode([0]);
        line.addNode(&origin);
        assert(line.rootNode !is null);
        assert(line.depth == 0);
        assert(origin.depthInTree == 0);
        assert(line.rootNode is &origin);

        // Tie, smaller argument, larger argument.
        assert(origin.compareAtDepth([0]) == 0);
        assert(origin.compareAtDepth([-5]) == -1);
        assert(origin.compareAtDepth([2]) == 1);
        assert(line.searchByComparison([0]) is &origin);

        // A larger point is attached as the root's right child.
        KdNode unit = new KdNode([1]);
        line.addNode(&unit);
        assert(line.depth == 1);
        assert(line.rootNode.rightNode !is null);
        assert(line.rootNode.rightNode is &unit);
        assert(line.searchByComparison([-1]) is &origin);
        assert(line.searchByComparison([5]) is &unit);
    }
    writeln("---- Passed!");

    writeln("-- Simple 2-D tree");
    {
        KdTree plane = new KdTree(2);

        // Empty tree: no root, depth -1, nothing to find.
        assert(plane.rootNode is null);
        assert(plane.depth == -1);
        assert(plane.searchByComparison([0,0]) is null);

        // The first node becomes the root at depth 0.
        KdNode origin = new KdNode([0,0]);
        plane.addNode(&origin);
        assert(plane.rootNode !is null);
        assert(plane.depth == 0);
        assert(origin.depthInTree == 0);
        assert(plane.rootNode is &origin);

        // At depth 0 only the first coordinate is compared.
        assert(origin.compareAtDepth([0,0]) == 0);
        assert(origin.compareAtDepth([-5,0]) == -1);
        assert(origin.compareAtDepth([2,0]) == 1);
        assert(plane.searchByComparison([0,0]) is &origin);

        // A point with a larger first coordinate goes to the right.
        KdNode east = new KdNode([1,0]);
        plane.addNode(&east);
        assert(plane.depth == 1);
        assert(plane.rootNode.rightNode !is null);
        assert(plane.rootNode.rightNode is &east);
        assert(plane.searchByComparison([-1,0]) is &origin);
        assert(plane.searchByComparison([5,0]) is &east);
    }
    writeln("---- Passed!");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment