Created
November 9, 2015 14:46
-
-
Save KristofferC/15beed394361a325ff33 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using NearestNeighbors | |
import NearestNeighbors: TreeData, find_split, getleft, getright, | |
HyperRectangle, compute_bbox, select_spec!, KDNode, Euclidean | |
using Base.Threads | |
"""
    build_KDTree_start(index, data, tree_data, indices, low, high, lows, highs,
                       hyper_rec, hyper_recs, nodes, idxs)

Build the upper part of the KD tree serially, stopping at the nodes listed in `idxs`.

Starting at node `index`, which owns the points `indices[low:high]` and is bounded by
`hyper_rec`, the node is split along the dimension of largest spread and the two
children are processed recursively. When a node whose index is contained in `idxs` is
reached, no split is made; instead its point range and bounding box are written to
`lows`, `highs` and `hyper_recs` at slot `index - idxs[1] + 1`, so the subtree rooted
there can be built later.

# Arguments
- `index`: index of the current node in `nodes`.
- `data`: point matrix; `data[d, i]` is coordinate `d` of point `i`.
- `tree_data`: supplies `leafsize`, forwarded to `find_split`.
- `indices`: permutation of point indices, reordered in place by `select_spec!`.
- `low`, `high`: inclusive range of `indices` owned by this node.
- `lows`, `highs`, `hyper_recs`: outputs, one slot per entry of `idxs`.
- `hyper_rec`: bounding box of this node; it is not modified.
- `nodes`: output vector of internal nodes, written at `index`.
- `idxs`: node indices at which recursion stops. The slot computation assumes
  `idxs[1]` is the smallest entry and the entries are consecutive.

Returns `nothing`.
"""
function build_KDTree_start{T}(index::Int,
                               data::Matrix{T},
                               tree_data::TreeData,
                               indices::Vector{Int},
                               low::Int,
                               high::Int,
                               lows::Vector{Int},
                               highs::Vector{Int},
                               hyper_rec::HyperRectangle{T},
                               hyper_recs::Vector{HyperRectangle{T}},
                               nodes::Vector{KDNode{T}},
                               idxs::Vector{Int})
    # This node is the root of a deferred subtree: record its work item and stop.
    if index in idxs
        slot = index - (idxs[1] - 1)
        lows[slot] = low
        highs[slot] = high
        hyper_recs[slot] = hyper_rec
        return
    end

    n_p = high - low + 1 # Points left
    mid_idx = find_split(low, tree_data.leafsize, n_p)

    # Find the dimension where the spread of the bounding box is maximal
    split_dim = 1
    max_spread = zero(T)
    for d in 1:size(data, 1)
        spread = hyper_rec.maxes[d] - hyper_rec.mins[d]
        if spread > max_spread
            max_spread = spread
            split_dim = d
        end
    end

    # Reorder the indices BEFORE reading the split value, exactly as build_KDTree
    # does: only after select_spec! does indices[mid_idx] refer to the point that
    # separates the two halves.
    select_spec!(indices, mid_idx, low, high, data, split_dim)
    split_val = data[split_dim, indices[mid_idx]]

    lo = hyper_rec.mins[split_dim]
    hi = hyper_rec.maxes[split_dim]
    nodes[index] = KDNode{T}(lo, hi, split_val, split_dim)

    # Each child gets its own rectangle because the stored rectangles outlive this
    # call. deepcopy keeps mins/maxes in their own fields regardless of the
    # positional order of the HyperRectangle constructor.
    h_rect_left = deepcopy(hyper_rec)
    h_rect_left.maxes[split_dim] = split_val
    build_KDTree_start(getleft(index), data, tree_data, indices, low, mid_idx - 1,
                       lows, highs, h_rect_left, hyper_recs, nodes, idxs)

    h_rect_right = deepcopy(hyper_rec)
    h_rect_right.mins[split_dim] = split_val
    build_KDTree_start(getright(index), data, tree_data, indices, mid_idx, high,
                       lows, highs, h_rect_right, hyper_recs, nodes, idxs)
    return
end
"""
    build_KDTree(index, data, hyper_rec, nodes, indices, low, high, tree_data)

Recursively build the subtree rooted at node `index`, which owns `indices[low:high]`.

If the range holds at most `tree_data.leafsize` points nothing is done. Otherwise the
node is split along the dimension in which `hyper_rec` is widest, the node is stored
in `nodes[index]`, and both children are built. `hyper_rec` is narrowed in place for
each child and put back to its incoming bounds before returning, so no rectangle is
allocated during the recursion.
"""
function build_KDTree{T <: AbstractFloat}(index::Int,
                                          data::Matrix{T},
                                          hyper_rec::HyperRectangle{T},
                                          nodes::Vector{KDNode{T}},
                                          indices::Vector{Int},
                                          low::Int,
                                          high::Int,
                                          tree_data::TreeData)
    npoints = high - low + 1
    # Few enough points for a leaf: there is no internal node to create.
    npoints <= tree_data.leafsize && return

    mid_idx = find_split(low, tree_data.leafsize, npoints)

    # Pick the axis along which the bounding box is widest (first one wins on ties).
    widest_dim = 1
    widest = zero(T)
    for dim in 1:size(data, 1)
        extent = hyper_rec.maxes[dim] - hyper_rec.mins[dim]
        if extent > widest
            widest = extent
            widest_dim = dim
        end
    end

    # Reorder the indices first, then read the splitting coordinate at mid_idx.
    select_spec!(indices, mid_idx, low, high, data, widest_dim)
    pivot = data[widest_dim, indices[mid_idx]]

    # Remember the incoming bounds so they can be put back after each child.
    old_min = hyper_rec.mins[widest_dim]
    old_max = hyper_rec.maxes[widest_dim]
    nodes[index] = KDNode{T}(old_min, old_max, pivot, widest_dim)

    # Left child: upper bound clipped to the pivot, then restored.
    hyper_rec.maxes[widest_dim] = pivot
    build_KDTree(getleft(index), data, hyper_rec, nodes,
                 indices, low, mid_idx - 1, tree_data)
    hyper_rec.maxes[widest_dim] = old_max

    # Right child: lower bound raised to the pivot, then restored.
    hyper_rec.mins[widest_dim] = pivot
    build_KDTree(getright(index), data, hyper_rec, nodes,
                 indices, mid_idx, high, tree_data)
    hyper_rec.mins[widest_dim] = old_min
end
"""
    KDTree(data::Matrix{T}, leafsize::Int = 10)

Build a `NearestNeighbors.KDTree` over the columns of `data` using multiple threads.

The upper levels of the tree are built serially by `build_KDTree_start`, which stops at
`nthreads_active` nodes and records the point range and bounding box of each. Those
subtrees are then built concurrently by `build_KDTree`, one per loop iteration.

# Arguments
- `data`: point matrix with one point per column.
- `leafsize`: passed to `TreeData`; a range with at most this many points is not split.

The tree is built with the `Euclidean()` metric and `reorder = false`.
"""
function KDTree{T <: AbstractFloat}(data::Matrix{T}, leafsize::Int = 10)
    # Largest power of two not exceeding the number of threads: the start nodes are
    # the node indices nthreads_active : 2 * nthreads_active - 1.
    # NOTE(review): this presumes the tree is deep enough for that many start
    # nodes - confirm the behaviour for very small inputs.
    nthreads_active = 2^trunc(Int, log2(nthreads()))

    n_p = size(data, 2)

    lows = zeros(Int, nthreads_active)
    highs = zeros(Int, nthreads_active)

    metric = Euclidean()
    reorder = false

    # Use the caller's leafsize (it must not be overwritten here).
    tree_data = TreeData(data, leafsize)

    hyper_rec = compute_bbox(data)
    hyper_recs = Array(HyperRectangle{T}, nthreads_active)
    nodes = Array(KDNode{T}, tree_data.n_internal_nodes)
    indices = collect(1:n_p)
    idxs = collect(nthreads_active : 2 * nthreads_active - 1)

    build_KDTree_start(1, data, tree_data, indices, 1, n_p, lows, highs,
                       hyper_rec, hyper_recs, nodes, idxs)

    # Nothing is assigned to an outer variable inside the loop body: `hyper_rec`
    # must keep the full bounding box for the returned tree, and writing to it from
    # several iterations at once would be a race.
    @threads all for i in 1:nthreads_active
        build_KDTree(idxs[i], data, hyper_recs[i], nodes, indices,
                     lows[i], highs[i], tree_data)
    end

    return NearestNeighbors.KDTree(data, hyper_rec, indices, metric, nodes, tree_data, reorder)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment