Skip to content

Instantly share code, notes, and snippets.

@hoffrocket
Created November 3, 2015 22:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hoffrocket/d4ee2b805ae55634222c to your computer and use it in GitHub Desktop.
Save hoffrocket/d4ee2b805ae55634222c to your computer and use it in GitHub Desktop.
package j.nettytest;
/*
** JkKdTree.java by Julian Kent
**
** Licenced under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
**
** Licence summary:
** Under this licence you are free to:
** Share — copy and redistribute the material in any medium or format
** Adapt — remix, transform, and build upon the material
** The licensor cannot revoke these freedoms as long as you follow the license terms.
**
** Under the following terms:
** Attribution — You must give appropriate credit, provide a link to the license, and indicate
** if changes were made. You may do so in any reasonable manner, but not in any
** way that suggests the licensor endorses you or your use.
** NonCommercial — You may not use the material for commercial purposes.
** ShareAlike — If you remix, transform, or build upon the material, you must distribute your
** contributions under the same license as the original.
** No additional restrictions
** — You may not apply legal terms or technological measures that legally restrict
** others from doing anything the license permits.
**
** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
**
** For additional licencing rights please contact jkflying@gmail.com
**
*/
import java.util.ArrayList;
import java.util.Arrays;
/**
 * A bucketed k-d tree over {@code float[]} points carrying a {@code long} payload each.
 *
 * <p>Points are stored in leaf nodes as one packed {@code float[]} per leaf; every node also has an
 * axis-aligned bounding box kept in one shared array ({@link #nodeMinMaxBounds}). A leaf is split
 * (see {@code Node.split()}) each time its size reaches a multiple of the bucket size.
 *
 * <p>The distance metric is supplied by the concrete subclass ({@link Euclidean},
 * {@link Manhattan}, {@link WeightedManhattan}). All distances reported by and passed to this class
 * are in the metric's own units; note that {@link Euclidean} works with <em>squared</em> distances.
 *
 * <p>NOTE(review): nothing in this class synchronizes access; presumably it is meant to be used from
 * a single thread at a time - confirm before sharing an instance.
 */
public abstract class JkKdTree {
    //use a big bucketSize so that we have less node bounds (for more cache hits) and better splits
    private static final int _bucketSize = 50;
    /** Number of components every stored point and every query must have. */
    private final int _dimensions;
    /** Number of live nodes; also the index handed to the next node that is created. */
    private int _nodes;
    private final Node root;
    /** All live nodes, positioned so that {@code nodeList.get(n.index) == n}. */
    private final ArrayList<Node> nodeList = new ArrayList<Node>();
    //prevent GC from having to collect _bucketSize*dimensions*8 bytes each time a leaf splits
    private float[] mem_recycle;
    //the starting values for bounding boxes, for easy access: (+inf, -inf) per dimension
    private final float[] bounds_template;
    //one big self-expanding array to keep all the node bounding boxes so that they stay in cache
    // node bounds available at:
    //low: 2 * _dimensions * node.index + 2 * dim
    //high: 2 * _dimensions * node.index + 2 * dim + 1
    private final ContiguousFloatArrayList nodeMinMaxBounds;

    /**
     * Creates an empty tree.
     *
     * @param dimensions number of components per point; must be positive
     * @throws IllegalArgumentException if {@code dimensions} is not positive
     */
    private JkKdTree(int dimensions) {
        if (dimensions <= 0)
            throw new IllegalArgumentException("dimensions must be positive, got " + dimensions);
        _dimensions = dimensions;
        //initialise this big so that it ends up in 'old' memory (original author's intent)
        nodeMinMaxBounds = new ContiguousFloatArrayList(512 * 1024 / 8 + 2 * _dimensions);
        mem_recycle = new float[_bucketSize * dimensions];
        bounds_template = new float[2 * _dimensions];
        Arrays.fill(bounds_template, Float.NEGATIVE_INFINITY);
        for (int i = 0, max = 2 * _dimensions; i < max; i += 2)
            bounds_template[i] = Float.POSITIVE_INFINITY;
        //and.... start!
        root = new Node();
    }

    /** Rejects arrays whose length differs from the tree's dimensionality. */
    private void checkDimensions(float[] values, String name) {
        if (values.length != _dimensions)
            throw new IllegalArgumentException(
                    name + " must have " + _dimensions + " components, got " + values.length);
    }

    /** @return the number of nodes (stems and leaves) currently in the tree */
    public int nodes() {
        return _nodes;
    }

    /** @return the number of points stored in the tree */
    public int size() {
        return root.entries;
    }

    /**
     * Adds a point to the tree. The coordinates are copied, so the caller may reuse the array.
     *
     * @param location coordinates of the point; length must equal the tree's dimensionality
     * @param payload value returned by searches that match this point
     * @return the number of points in the tree after the insertion
     * @throws IllegalArgumentException if {@code location} has the wrong length
     */
    public int addPoint(float[] location, long payload) {
        //validate before touching any node: a wrong-sized array would corrupt the packed storage
        checkDimensions(location, "location");
        Node addNode = root;
        //descend to the leaf where 'location' should be stored, growing each box on the way down
        while (addNode.pointLocations == null) {
            addNode.expandBounds(location);
            if (location[addNode.splitDim] < addNode.splitVal)
                addNode = nodeList.get(addNode.lessIndex);
            else
                addNode = nodeList.get(addNode.moreIndex);
        }
        addNode.expandBounds(location);
        int nodeSize = addNode.add(location, payload);
        if (nodeSize % _bucketSize == 0)
            //try splitting again once every time the node passes a _bucketSize multiple
            //in case it is full of points of the same location and won't split
            addNode.split();
        return root.entries;
    }

    /**
     * Finds the points closest to {@code searchLocation}.
     *
     * @param searchLocation query coordinates; length must equal the tree's dimensionality
     * @param K maximum number of results wanted
     * @return at most {@code K} results ordered closest first; fewer when the tree holds fewer
     *         points (or when distances are not finite numbers). Distances are in the metric's
     *         units (squared for {@link Euclidean}).
     * @throws IllegalArgumentException if {@code K} is negative or {@code searchLocation} has the
     *         wrong length
     */
    public ArrayList<SearchResult> nearestNeighbours(float[] searchLocation, int K) {
        checkDimensions(searchLocation, "searchLocation");
        if (K < 0)
            throw new IllegalArgumentException("K must not be negative, got " + K);
        //never ask for more than we hold: avoids padding entries and oversized buffers
        final int k = Math.min(K, root.entries);
        ArrayList<SearchResult> returnResults = new ArrayList<SearchResult>(k);
        if (k == 0)
            return returnResults;
        IntStack stack = new IntStack();
        PrioQueue results = new PrioQueue(k, true);
        stack.push(root.index);
        int added = 0;
        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            //visit a node while the queue still has placeholders, or if its box could hold
            //something closer than the worst result kept so far
            if (added < k || results.peekPrio() > pointRectDist(nodeIndex, searchLocation)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    node.search(searchLocation, stack);
                else
                    added += node.search(searchLocation, results);
            }
        }
        //every insertion replaces one placeholder until the queue is full, so only the first
        //min(added, k) slots hold real points
        final int found = Math.min(added, k);
        float[] priorities = results.priorities;
        long[] elements = results.elements;
        for (int i = 0; i < found; i++) //forward (closest first)
            returnResults.add(new SearchResult(priorities[i], elements[i]));
        return returnResults;
    }

    /**
     * Collects the payloads of all points whose distance to {@code searchLocation} is at most
     * {@code radius}.
     *
     * @param searchLocation query coordinates; length must equal the tree's dimensionality
     * @param radius inclusive distance limit in the metric's units (for {@link Euclidean} this is
     *        the <em>squared</em> radius)
     * @return payloads of the matching points, in no particular order
     * @throws IllegalArgumentException if {@code searchLocation} has the wrong length
     */
    public ArrayList<Long> ballSearch(float[] searchLocation, double radius) {
        checkDimensions(searchLocation, "searchLocation");
        IntStack stack = new IntStack();
        ArrayList<Long> results = new ArrayList<Long>();
        stack.push(root.index);
        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            //'>=' so that points lying exactly on the ball's surface are not pruned away;
            //this matches the inclusive test applied to individual points
            if (radius >= pointRectDist(nodeIndex, searchLocation)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchBall(searchLocation, radius, results);
            }
        }
        return results;
    }

    /**
     * Collects the payloads of all points inside the axis-aligned box {@code [mins, maxs]}
     * (both ends inclusive).
     *
     * @param mins lower corner; length must equal the tree's dimensionality
     * @param maxs upper corner; length must equal the tree's dimensionality
     * @return payloads of the matching points, in no particular order
     * @throws IllegalArgumentException if either array has the wrong length
     */
    public ArrayList<Long> rectSearch(float[] mins, float[] maxs) {
        checkDimensions(mins, "mins");
        checkDimensions(maxs, "maxs");
        IntStack stack = new IntStack();
        ArrayList<Long> results = new ArrayList<Long>();
        stack.push(root.index);
        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            if (overlaps(mins, maxs, nodeIndex)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchRect(mins, maxs, results);
            }
        }
        return results;
    }

    /**
     * Distance from {@code location} to the bounding box of a node.
     *
     * @param offset index of the node (not an array offset)
     * @param location query coordinates
     */
    abstract float pointRectDist(int offset, final float[] location);

    /**
     * Distance from {@code location} to the {@code index}-th point packed in {@code arr}.
     *
     * @param arr packed point coordinates, one point every {@code _dimensions} floats
     * @param location query coordinates
     * @param index ordinal of the point inside {@code arr}
     */
    abstract float pointDist(float[] arr, float[] location, int index);

    /**
     * Tests whether the {@code index}-th point packed in {@code arr} lies inside
     * {@code [mins, maxs]}; the stride used is {@code mins.length}.
     */
    boolean contains(float[] arr, float[] mins, float[] maxs, int index) {
        int offset = (index + 1) * mins.length;
        for (int i = mins.length; i-- > 0; ) {
            float d = arr[--offset];
            //non-short-circuit '|' kept from the original
            if (mins[i] > d | d > maxs[i])
                return false;
        }
        return true;
    }

    /**
     * Tests whether the bounding box of a node intersects {@code [mins, maxs]}.
     *
     * @param offset index of the node (not an array offset)
     */
    boolean overlaps(float[] mins, float[] maxs, int offset) {
        offset *= (2 * maxs.length);
        final float[] array = nodeMinMaxBounds.array;
        for (int i = 0; i < maxs.length; i++, offset += 2) {
            double bmin = array[offset], bmax = array[offset + 1];
            if (mins[i] > bmax | maxs[i] < bmin)
                return false;
        }
        return true;
    }

    /** Tree using the squared Euclidean distance (sum of squared differences, no square root). */
    public static class Euclidean extends JkKdTree {
        public Euclidean(int dims) {
            super(dims);
        }

        @Override
        float pointRectDist(int offset, final float[] location) {
            offset *= (2 * super._dimensions);
            float distance = 0;
            final float[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {
                //per axis: gap to the box, or 0 when the coordinate lies within [low, high]
                float diff = 0;
                float bv = array[offset];
                float lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += sqr(diff);
            }
            return distance;
        }

        @Override
        float pointDist(float[] arr, float[] location, int index) {
            float distance = 0;
            int offset = (index + 1) * super._dimensions;
            for (int i = super._dimensions; i-- > 0; ) {
                distance += sqr(arr[--offset] - location[i]);
            }
            return distance;
        }
    }

    /** Tree using the Manhattan distance (sum of absolute differences). */
    public static class Manhattan extends JkKdTree {
        public Manhattan(int dims) {
            super(dims);
        }

        @Override
        float pointRectDist(int offset, final float[] location) {
            offset *= (2 * super._dimensions);
            float distance = 0;
            final float[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {
                //per axis: gap to the box, or 0 when the coordinate lies within [low, high]
                float diff = 0;
                float bv = array[offset];
                float lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += (diff);
            }
            return distance;
        }

        @Override
        float pointDist(float[] arr, float[] location, int index) {
            float distance = 0;
            int offset = (index + 1) * super._dimensions;
            for (int i = super._dimensions; i-- > 0; ) {
                distance += Math.abs(arr[--offset] - location[i]);
            }
            return distance;
        }
    }

    /**
     * Tree using a per-dimension weighted Manhattan distance. Weights default to 1 for every
     * dimension until {@link #setWeights(float[])} is called.
     */
    public static class WeightedManhattan extends JkKdTree {
        float[] weights;

        public WeightedManhattan(int dims) {
            super(dims);
            //sensible default so that searching before setWeights() does not hit a null array
            weights = new float[dims];
            Arrays.fill(weights, 1f);
        }

        /**
         * Replaces the weights. The array is stored as-is (not copied).
         *
         * @param newWeights one weight per dimension; extra trailing entries are ignored
         * @throws IllegalArgumentException if {@code newWeights} is null or has fewer entries than
         *         the tree has dimensions
         */
        public void setWeights(float[] newWeights) {
            if (newWeights == null || newWeights.length < super._dimensions)
                throw new IllegalArgumentException(
                        "newWeights must supply at least " + super._dimensions + " weights");
            weights = newWeights;
        }

        @Override
        float pointRectDist(int offset, final float[] location) {
            offset *= (2 * super._dimensions);
            float distance = 0;
            final float[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {
                //per axis: gap to the box, or 0 when the coordinate lies within [low, high]
                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += (diff) * weights[i];
            }
            return distance;
        }

        @Override
        float pointDist(float[] arr, float[] location, int index) {
            float distance = 0;
            int offset = (index + 1) * super._dimensions;
            for (int i = super._dimensions; i-- > 0; ) {
                distance += Math.abs(arr[--offset] - location[i]) * weights[i];
            }
            return distance;
        }
    }

    //NB! This Priority Queue keeps things with the LOWEST priority.
    //If you want highest priority items kept, negate your values
    /**
     * Fixed-capacity sorted buffer holding the entries with the lowest priorities seen so far,
     * in ascending priority order.
     *
     * <p>Only the prefilled mode is exercised by this file. NOTE(review): with
     * {@code prefill == false} the logical size stays 0 and nothing in this class grows it.
     */
    private static class PrioQueue {
        long[] elements;
        float[] priorities;
        /** Largest priority currently kept, i.e. the one a new entry has to beat. */
        private double worstPrio;
        private int size;

        PrioQueue(int size, boolean prefill) {
            elements = new long[size];
            priorities = new float[size];
            Arrays.fill(priorities, Float.POSITIVE_INFINITY);
            if (prefill) {
                worstPrio = Float.POSITIVE_INFINITY;
                this.size = size;
            }
        }

        //uses O(log(n)) comparisons and one big shift of size O(N)
        //and is MUCH simpler than a heap --> faster on small sets, faster JIT
        /**
         * Inserts an entry in sorted position, dropping the current last (worst) entry. Callers
         * are expected to check {@code peekPrio() > priority} first.
         */
        void addNoGrow(long value, float priority) {
            if (size == 0)
                return; //no room at all; the shift below would be given a negative length
            int index = searchFor(priority);
            int nextIndex = index + 1;
            int length = size - index - 1;
            System.arraycopy(elements, index, elements, nextIndex, length);
            System.arraycopy(priorities, index, priorities, nextIndex, length);
            elements[index] = value;
            priorities[index] = priority;
            worstPrio = priorities[size - 1];
        }

        /** Binary search for the first slot whose priority is not below {@code priority}. */
        int searchFor(float priority) {
            int i = size - 1;
            int j = 0;
            while (i >= j) {
                int index = (i + j) >>> 1;
                if (priorities[index] < priority)
                    j = index + 1;
                else
                    i = index - 1;
            }
            return j;
        }

        /** @return the largest priority currently kept */
        double peekPrio() {
            return worstPrio;
        }
    }

    /** One nearest-neighbour hit: the payload of a stored point and its distance to the query. */
    public static class SearchResult {
        public float distance;
        public long payload;

        SearchResult(float dist, long load) {
            distance = dist;
            payload = load;
        }
    }

    /**
     * A tree node. It is a leaf while {@code pointLocations} is non-null; a successful split
     * turns it into a stem that only refers to its two children by index.
     */
    private class Node {
        //for accessing bounding box data
        // - if trees weren't so unbalanced might be better to use an implicit heap?
        int index;
        //keep track of size of subtree
        int entries;
        //leaf
        ContiguousFloatArrayList pointLocations;
        LongList pointPayloads = new LongList();
        //stem
        //Node less, more;
        int lessIndex, moreIndex;
        int splitDim;
        double splitVal;

        Node() {
            this(new float[_bucketSize * _dimensions]);
        }

        /** Creates a leaf that stores its points in {@code pointMemory} and registers it. */
        Node(float[] pointMemory) {
            pointLocations = new ContiguousFloatArrayList(pointMemory);
            index = _nodes++;
            nodeList.add(this);
            nodeMinMaxBounds.add(bounds_template);
        }

        /** Stem step of the nearest-neighbour search: queue both children, nearer side on top. */
        void search(float[] searchLocation, IntStack stack) {
            if (searchLocation[splitDim] < splitVal)
                stack.push(moreIndex).push(lessIndex);//less will be popped first
            else
                stack.push(lessIndex).push(moreIndex);//more will be popped first
        }

        //returns number of points added to results
        int search(float[] searchLocation, PrioQueue results) {
            int updated = 0;
            for (int j = entries; j-- > 0; ) {
                float distance = pointDist(pointLocations.array, searchLocation, j);
                if (results.peekPrio() > distance) {
                    updated++;
                    results.addNoGrow(pointPayloads.get(j), distance);
                }
            }
            return updated;
        }

        /** Appends the payload of every point of this leaf within {@code radius} (inclusive). */
        void searchBall(float[] searchLocation, double radius, ArrayList<Long> results) {
            for (int j = entries; j-- > 0; ) {
                double distance = pointDist(pointLocations.array, searchLocation, j);
                if (radius >= distance) {
                    results.add(pointPayloads.get(j));
                }
            }
        }

        /** Appends the payload of every point of this leaf inside {@code [mins, maxs]}. */
        void searchRect(float[] mins, float[] maxs, ArrayList<Long> results) {
            for (int j = entries; j-- > 0; )
                if (contains(pointLocations.array, mins, maxs, j))
                    results.add(pointPayloads.get(j));
        }

        /** Counts one more point in this subtree and grows the bounding box to include it. */
        void expandBounds(float[] location) {
            entries++;
            final float[] bounds = nodeMinMaxBounds.array;
            int lo = index * 2 * _dimensions;
            for (int i = 0; i < _dimensions; i++, lo += 2) {
                bounds[lo] = Math.min(bounds[lo], location[i]);
                bounds[lo + 1] = Math.max(bounds[lo + 1], location[i]);
            }
        }

        /**
         * Stores a point in this leaf. {@code entries} is maintained by {@link #expandBounds},
         * which callers invoke first.
         *
         * @return the number of points in this leaf
         */
        int add(float[] location, long load) {
            pointLocations.add(location);
            pointPayloads.add(load);
            return entries;
        }

        /**
         * Tries to turn this leaf into a stem with two child leaves, splitting at the mean of the
         * dimension selected by the range/variance comparison below. If every point ends up on one
         * side the attempt is undone and the node stays a leaf.
         */
        void split() {
            final float[] bounds = nodeMinMaxBounds.array;
            final int boundsBase = index * 2 * _dimensions;
            double diff = 0;
            for (int i = 0; i < _dimensions; i++) {
                double min = bounds[boundsBase + 2 * i];
                double max = bounds[boundsBase + 2 * i + 1];
                //cheap pre-filter: skip the variance computation when the extent of this
                //dimension does not exceed the best score found so far
                if (max - min > diff) {
                    double mean = 0;
                    for (int j = 0; j < entries; j++)
                        mean += pointLocations.array[i + _dimensions * j];
                    mean = mean / entries;
                    double varianceSum = 0;
                    for (int j = 0; j < entries; j++)
                        varianceSum += sqr(mean - pointLocations.array[i + _dimensions * j]);
                    if (varianceSum > diff * entries) {
                        diff = varianceSum / entries;
                        splitVal = mean;
                        splitDim = i;
                    }
                }
            }
            //kill all the nasties
            if (splitVal == Double.POSITIVE_INFINITY)
                splitVal = Double.MAX_VALUE;
            else if (splitVal == Double.NEGATIVE_INFINITY)
                //most negative finite double; Double.MIN_VALUE is the smallest POSITIVE value
                splitVal = -Double.MAX_VALUE;
            else if (splitVal == bounds[boundsBase + 2 * splitDim + 1])
                splitVal = bounds[boundsBase + 2 * splitDim];
            //'bounds' must not be used past this point: creating nodes may reallocate the array
            Node less = new Node(mem_recycle);//recycle that memory!
            Node more = new Node();
            lessIndex = less.index;
            moreIndex = more.index;
            //reduce garbage by factor of _bucketSize by recycling this array
            float[] pointLocation = new float[_dimensions];
            for (int i = 0; i < entries; i++) {
                System.arraycopy(pointLocations.array, i * _dimensions, pointLocation, 0, _dimensions);
                long load = pointPayloads.get(i);
                if (pointLocation[splitDim] < splitVal) {
                    less.expandBounds(pointLocation);
                    less.add(pointLocation, load);
                } else {
                    more.expandBounds(pointLocation);
                    more.add(pointLocation, load);
                }
            }
            if (less.entries * more.entries == 0) {
                //one of them was 0, so the split was worthless. throw it away.
                _nodes -= 2;
                nodeList.remove(moreIndex);
                nodeList.remove(lessIndex);
                //recall that bounds memory: the two children were the last boxes appended, so
                //dropping them keeps box offsets in step with node indices and stops the
                //array from growing on every failed attempt
                nodeMinMaxBounds.size -= 4 * _dimensions;
            } else {
                //we won't be needing that now, so keep it for the next split to reduce garbage
                mem_recycle = pointLocations.array;
                pointLocations = null;
                pointPayloads.clear();
                pointPayloads = null;
            }
        }
    }

    /** Growable {@code float} buffer to which whole arrays are appended. */
    private static class ContiguousFloatArrayList {
        float[] array;
        int size;

        ContiguousFloatArrayList() {
            this(300);
        }

        ContiguousFloatArrayList(int size) {
            this(new float[size]);
        }

        /** Wraps {@code data} as the initial backing store; the logical size starts at 0. */
        ContiguousFloatArrayList(float[] data) {
            array = data;
        }

        /** Appends all of {@code da}, reallocating the backing array when it does not fit. */
        ContiguousFloatArrayList add(float[] da) {
            if (size + da.length > array.length)
                array = Arrays.copyOf(array, (array.length + da.length) * 2);
            System.arraycopy(da, 0, array, size, da.length);
            size += da.length;
            return this;
        }
    }

    /** Minimal growable list of primitive {@code long}s. */
    private static class LongList {
        long[] array;
        int size;

        LongList() {
            this(16);
        }

        LongList(int size) {
            array = new long[size];
        }

        void add(long l) {
            if (size + 1 > array.length)
                //grow geometrically; growing by one slot per add made long runs of adds quadratic
                array = Arrays.copyOf(array, Math.max(16, array.length * 2));
            array[size] = l;
            size++;
        }

        /** Reads slot {@code index} of the backing array; not checked against {@code size}. */
        long get(int index) {
            return array[index];
        }

        void clear() {
            size = 0;
        }
    }

    /** Minimal growable stack of primitive {@code int}s with chainable {@code push}. */
    private static class IntStack {
        int[] array;
        int size;

        IntStack() {
            this(64);
        }

        IntStack(int size) {
            this(new int[size]);
        }

        IntStack(int[] data) {
            array = data;
        }

        IntStack push(int i) {
            if (size >= array.length)
                array = Arrays.copyOf(array, (array.length + 1) * 2);
            array[size++] = i;
            return this;
        }

        /** Removes and returns the top element; callers check {@link #size()} first. */
        int pop() {
            return array[--size];
        }

        int size() {
            return size;
        }
    }

    /** @return {@code d * d} */
    static final double sqr(double d) {
        return d * d;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment