# Created
# March 17, 2014 16:55
# -
# -
# Save goretkin/9603371 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
"""
    BallTree

A small binary ball tree for exact nearest-neighbour queries over a `Set` of
points. Points must support subtraction and `norm` (e.g. `Vector{Float64}`).

Nothing is exported; call everything qualified, e.g. `BallTree.make_ball_tree`.
"""
module BallTree

using LinearAlgebra: norm
using Statistics: mean

"""
    dist(p, q)

Return the norm of `p - q`, the metric used throughout the tree.
"""
dist(p, q) = norm(p - q)

"""
    argbest(metric, compare, candidates, query)

Return the index into the indexable collection `candidates` whose element has
the best `metric(element, query)`, where "better" is decided by
`compare(new_value, best_value_so_far)`.

Throws `ArgumentError` if `candidates` is empty.
"""
function argbest(metric, compare, candidates, query)
    isempty(candidates) && throw(ArgumentError("candidates must not be empty"))
    best_index = first(eachindex(candidates))
    best_value = metric(candidates[best_index], query)
    for i in eachindex(candidates)
        m = metric(candidates[i], query)
        if compare(m, best_value)
            best_index, best_value = i, m
        end
    end
    return best_index
end

"""
    @argbest(metric, compare, candidates, query)

Macro form kept so the historical name still exists; it simply forwards to the
function [`argbest`](@ref) with the caller's arguments.
"""
macro argbest(metric, compare, candidates, query)
    return :(argbest($(esc(metric)), $(esc(compare)), $(esc(candidates)), $(esc(query))))
end

"""
    argbest_set(metric, compare, candidates, query)

Return the *element* of the iterable `candidates` with the best
`metric(element, query)`, where "better" is decided by
`compare(new_value, best_value_so_far)`. Works on non-indexable collections
such as `Set`.

Throws `ArgumentError` if `candidates` is empty.
"""
function argbest_set(metric, compare, candidates, query)
    isempty(candidates) && throw(ArgumentError("candidates must not be empty"))
    best = first(candidates)
    best_value = metric(best, query)
    # The first element is already the incumbent, so start from the second.
    for c in Iterators.drop(candidates, 1)
        m = metric(c, query)
        if compare(m, best_value)
            best, best_value = c, m
        end
    end
    return best
end

"""
    BallTreeNode{P}

One ball of the tree; `P` is the type of the points.

# Fields
- `data::Set{P}`: the points contained in this ball.
- `radius::Float64`: largest distance from `center` to a point in `data`.
- `center::P`: mean of the points in `data`.
- `children::Set{BallTreeNode{P}}`: sub-balls; empty for a leaf.

`BallTreeNode{P}()` creates a node with all fields uninitialised; they are
filled in by [`make_ball_tree`](@ref).
"""
mutable struct BallTreeNode{P}
    data::Set{P}
    radius::Float64
    center::P
    children::Set{BallTreeNode{P}}
    BallTreeNode{P}() where {P} = new{P}()
end

"""
    randsample_set(s::Set)

Return a uniformly random element of `s`.

Throws `ArgumentError` if `s` is empty.
"""
function randsample_set(s::Set{P}) where {P}
    isempty(s) && throw(ArgumentError("cannot sample from an empty set"))
    return rand(s)
end

"""
    split(data::Set)

Pick two points of `data` that make good exemplars for the two subtrees:
start from a random point `x`, take the point `a` farthest from `x`, then the
point `b` farthest from `a`. Returns the tuple `(a, b)`.

This is a module-local function, unrelated to `Base.split`.
"""
function split(data::Set{P}) where {P}
    x = randsample_set(data)
    a = argbest_set(dist, >, data, x)
    b = argbest_set(dist, >, data, a)
    return (a, b)
end

"""
    make_ball_tree(data::Set{P}; leafsize=20)

Build a ball tree over `data` and return its root `BallTreeNode{P}`.

A node holding more than `leafsize` points is divided in two by assigning each
point to the nearer of the two exemplars returned by [`split`](@ref). If that
partition is degenerate (one side empty, e.g. all points at distance zero from
each other) the node is kept as a leaf so the recursion always terminates.

The node stores `data` itself, not a copy. The centre is `mean(data)` and must
be convertible to `P`.

Throws `ArgumentError` if `data` is empty.
"""
function make_ball_tree(data::Set{P}; leafsize::Integer=20) where {P}
    isempty(data) && throw(ArgumentError("cannot build a ball tree from an empty set"))

    tree = BallTreeNode{P}()
    tree.data = data
    tree.center = mean(data)
    center = tree.center
    tree.radius = maximum(p -> dist(p, center), data)
    tree.children = Set{BallTreeNode{P}}()

    length(data) > leafsize || return tree

    exemplar_a, exemplar_b = split(data)
    near_a = Set{P}()
    near_b = Set{P}()
    for p in data
        if dist(p, exemplar_a) < dist(p, exemplar_b)
            push!(near_a, p)
        else
            push!(near_b, p)
        end
    end

    # A one-sided partition would recurse on the same set forever.
    if isempty(near_a) || isempty(near_b)
        return tree
    end

    push!(tree.children, make_ball_tree(near_a; leafsize=leafsize))
    push!(tree.children, make_ball_tree(near_b; leafsize=leafsize))
    return tree
end

"""
    nearest_neighbor_linsearch(data::Set{P}, query::P)

Return the point of `data` closest to `query` by exhaustive search.
"""
function nearest_neighbor_linsearch(data::Set{P}, query::P) where {P}
    return argbest_set(dist, <, data, query)
end

"""
    UpperBound{P}

Mutable search state for a tree query: `value` is the smallest distance found
so far (initially `Inf`) and `best_point` the point achieving it.
`best_point` is undefined until the first candidate is recorded.
"""
mutable struct UpperBound{P}
    value::Float64
    best_point::P
    UpperBound{P}() where {P} = new{P}(Inf)
end

"""
    query_ball_tree(tree::BallTreeNode{P}, query::P)

Return the point stored in `tree` that is nearest to `query`.
"""
function query_ball_tree(tree::BallTreeNode{P}, query::P) where {P}
    bound = UpperBound{P}()
    query_ball_tree!(tree, query, bound)
    return bound.best_point
end

"""
    query_ball_tree!(tree, query, bound)

Depth-first search of `tree` for the nearest neighbour of `query`, updating
`bound` in place whenever a closer point is found. Branches whose ball cannot
contain anything closer than `bound.value` are skipped. Returns `nothing`.
"""
function query_ball_tree!(tree::BallTreeNode{P}, query::P, bound::UpperBound) where {P}
    # Lower bound on the distance from `query` to any point in this ball;
    # negative when the query lies inside the ball.
    best_possible = dist(query, tree.center) - tree.radius
    if best_possible > bound.value
        return nothing
    end

    if isempty(tree.children)
        # Leaf: scan its points directly.
        p = nearest_neighbor_linsearch(tree.data, query)
        d = dist(p, query)
        if bound.value > d
            bound.value = d
            bound.best_point = p
        end
        return nothing
    end

    # Visit the child with the closer centre first so the bound tightens early
    # and the remaining children are more likely to be pruned.
    ordered = sort!(collect(tree.children); by = child -> dist(child.center, query))
    for child in ordered
        query_ball_tree!(child, query, bound)
    end
    return nothing
end

"""
    depth(tree::BallTreeNode)

Return the number of edges on the longest root-to-leaf path (0 for a leaf).
"""
function depth(tree::BallTreeNode)
    isempty(tree.children) && return 0
    return 1 + maximum(depth(child) for child in tree.children)
end

"""
    nodes(tree::BallTreeNode)

Return the total number of nodes in `tree`, including `tree` itself.
"""
function nodes(tree::BallTreeNode)
    # `init=0` keeps this well-defined for leaves, whose child set is empty.
    return 1 + mapreduce(nodes, +, tree.children; init=0)
end

end
# Sanity check / timing script: build a ball tree over random points and
# compare its nearest-neighbour answers against exhaustive linear search.
# Note: the first `@time` of each call also includes compilation time.
d = 2    # dimension of each point
n = 800  # number of points in the tree
data = Set{Vector{Float64}}()
data_array = rand(d, n)
for i in 1:n
    push!(data, data_array[:, i])
end
@time tree = BallTree.make_ball_tree(data);
n_queries = 1000
queries = [rand(d) for i in 1:n_queries]
b = @time [BallTree.nearest_neighbor_linsearch(data, q) for q in queries]
a = @time [BallTree.query_ball_tree(tree, q) for q in queries]
println("depth: $(BallTree.depth(tree))")
println("nodes: $(BallTree.nodes(tree))")
# Prints `true` when the tree search agrees with linear search on every query.
println(a == b)
# Sign up for free
# to join this conversation on GitHub.
# Already have an account?
# Sign in to comment