Skip to content

Instantly share code, notes, and snippets.

@goretkin
Created March 17, 2014 16:55
Show Gist options
  • Save goretkin/9603371 to your computer and use it in GitHub Desktop.
module BallTree
# Euclidean distance between two points: the norm of their difference.
# Defined as a named function instead of a non-constant global bound to an
# anonymous function; it is still called as dist(p, q) and can still be passed
# around as a value (e.g. as the `metric` argument of argbest_set).
dist(p, q) = norm(p - q)
# Return the index into `candidates` of the element with the best metric.
#
# metric     - called as metric(candidate, query); returns a comparable score.
# compare    - called as compare(a, b); returns true when score `a` is better
#              than score `b` (e.g. `<` for nearest, `>` for farthest).
# candidates - indexable collection (1-based) of candidate elements.
# query      - second argument passed to `metric` for every candidate.
#
# When several candidates tie under a strict `compare`, the lowest index wins.
# Throws ArgumentError when `candidates` is empty.
#
# This is an ordinary function: it needs the runtime values of its arguments,
# which a macro (operating on unevaluated expressions) cannot provide.
function argbest(metric, compare, candidates, query)
    if length(candidates) == 0
        throw(ArgumentError("candidates must not be empty"))
    end
    best_so_far_index = 1
    best_so_far_value = metric(candidates[1], query)
    # The first candidate is already the incumbent, so start comparing at 2.
    for i in 2:length(candidates)
        m = metric(candidates[i], query)
        if compare(m, best_so_far_value)
            best_so_far_index = i
            best_so_far_value = m
        end
    end
    return best_so_far_index
end
# Return the element of `candidates` with the best metric.
#
# metric     - called as metric(candidate, query); returns a comparable score.
# compare    - called as compare(a, b); returns true when score `a` is better
#              than score `b` (e.g. `<` for nearest, `>` for farthest).
# candidates - any iterable collection (it does not need to be indexable).
# query      - second argument passed to `metric` for every candidate.
#
# `metric` is evaluated exactly once per candidate, in a single pass.
# Throws ArgumentError when `candidates` is empty.
function argbest_set(metric, compare, candidates, query)
    local best_so_far
    local best_so_far_value
    found = false
    for c in candidates
        m = metric(c, query)
        # The first candidate seen becomes the incumbent unconditionally.
        if !found || compare(m, best_so_far_value)
            best_so_far = c
            best_so_far_value = m
            found = true
        end
    end
    if !found
        throw(ArgumentError("candidates must not be empty"))
    end
    return best_so_far
end
# Node of a ball tree over points of type P.
# A node is created empty and make_ball_tree then assigns every field.
type BallTreeNode{P}
#P is type of points in BallTree
data::Set{P} #set of points contained within this ball.
radius::Float64 #make_ball_tree sets this to the distance from center to the farthest point in data.
center::P #make_ball_tree sets this to mean(data).
children::Set{BallTreeNode} #empty for a leaf; make_ball_tree adds two subtrees when it splits a node.
BallTreeNode() = new() #leaves every field uninitialized; the caller must assign them.
end
# Return one element of the set `s`, chosen by drawing a random position
# with rand(1:length(s)) and walking the set's iteration order to it.
#
# Set is not indexable, so s[r] is emulated by counting during iteration.
# Throws ArgumentError when `s` is empty.
function randsample_set(s::Set)
    if isempty(s)
        throw(ArgumentError("cannot sample from an empty Set"))
    end
    r = rand(1:length(s))
    for (i, y) in enumerate(s)
        if i == r
            return y
        end
    end
    # Not reachable for a non-empty set: r always lies in 1:length(s).
    error("randsample_set: sampled position was not reached")
end
# Pick two points of `data` to act as exemplars for the two subtrees.
# Starting from a randomly sampled point, take the point farthest from it,
# then the point farthest from that one. Returns the pair as a tuple.
function split{P}(data::Set{P})
    seed = randsample_set(data)
    first_exemplar = argbest_set(dist, >, data, seed)
    second_exemplar = argbest_set(dist, >, data, first_exemplar)
    return (first_exemplar, second_exemplar)
end
# Recursively build a ball tree over the points in `data`.
#
# data      - non-empty set of points of type P.
# leaf_size - a node holding more than this many points is split into two
#             subtrees; must be at least 1. Defaults to 20.
#
# Each node stores its points, their mean as the center, and the distance from
# the center to its farthest point as the radius. Returns the root node.
# Throws ArgumentError when `data` is empty or `leaf_size` is below 1.
function make_ball_tree{P}(data::Set{P}, leaf_size::Integer=20)
    if isempty(data)
        throw(ArgumentError("cannot build a ball tree from an empty set of points"))
    end
    if leaf_size < 1
        throw(ArgumentError("leaf_size must be at least 1"))
    end

    tree = BallTreeNode{P}()
    tree.data = data
    tree.center = mean(tree.data)
    tree.radius = dist(argbest_set(dist, >, data, tree.center), tree.center)
    tree.children = Set{BallTreeNode}()

    if length(data) > leaf_size
        exemplars = split(data)
        @assert(length(exemplars) == 2) #binary tree
        @assert(exemplars[1] != exemplars[2]) #must be unique

        # Assign every point to the exemplar it is closer to; ties go to the
        # second exemplar.
        subtree1_data = Set{P}()
        subtree2_data = Set{P}()
        for d in data
            if dist(d, exemplars[1]) < dist(d, exemplars[2])
                add!(subtree1_data, d)
            else
                add!(subtree2_data, d)
            end
        end

        # Recurse on the two partitions built above.
        add!(tree.children, make_ball_tree(subtree1_data, leaf_size))
        add!(tree.children, make_ball_tree(subtree2_data, leaf_size))
    end
    return tree
end
# Brute-force nearest neighbor: scan every point in `data` and return the one
# with the smallest dist to `query`.
nearest_neighbor_linsearch{P}(data::Set{P},query::P) = argbest_set(dist, <, data, query)
# Mutable record of the smallest distance found so far during a tree query,
# together with the point that achieved it. query_ball_tree! updates it in place.
type UpperBound{P}
value::Float64 #smallest distance found so far; starts at Inf.
best_point::P #point that achieved value; uninitialized until the first update.
UpperBound() = (new(Inf)) #sets only value; best_point stays uninitialized.
end
# Find the nearest neighbor of `query` among the points stored in `tree`.
# Starts with a fresh bound (distance Inf), lets query_ball_tree! tighten it
# while searching, and returns the point recorded in the bound.
function query_ball_tree{P}(tree::BallTreeNode{P},query::P)
    bound = UpperBound{P}()
    query_ball_tree!(tree, query, bound)
    return bound.best_point
end
# Depth-first nearest-neighbor search below `tree`, updating `bound` in place.
# A subtree is skipped when even its closest possible point could not beat the
# distance already stored in `bound`. Returns nothing.
function query_ball_tree!{P}(tree::BallTreeNode{P},query::P,bound::UpperBound)
    # Smallest distance any point inside this ball could have to the query;
    # negative when the query itself lies inside the ball.
    lower_bound = dist(query, tree.center) - tree.radius
    if lower_bound > bound.value
        # Nothing in this ball can improve on the current best.
        return
    end

    if length(tree.children) == 0
        # Leaf: scan its points directly and keep the result if it is better.
        candidate = nearest_neighbor_linsearch(tree.data, query)
        candidate_dist = dist(candidate, query)
        if candidate_dist < bound.value
            bound.value = candidate_dist
            bound.best_point = candidate
        end
        return
    end

    @assert(length(tree.children)==2)
    kids = collect(tree.children)
    dist_first = dist(kids[1].center, query)
    dist_second = dist(kids[2].center, query)
    # Visit the child whose center is closer first, so the bound tightens
    # early and the other child is more likely to be pruned.
    (near, far) = dist_first < dist_second ? (kids[1], kids[2]) : (kids[2], kids[1])
    query_ball_tree!(near, query, bound)
    query_ball_tree!(far, query, bound)
    return
end
# Return the depth of `tree`: 0 for a leaf (no children), otherwise one more
# than the depth of its deepest child.
function depth(tree::BallTreeNode)
    if isempty(tree.children)
        return 0
    end
    # Fold with the two-argument `max` instead of calling `max` on a collection;
    # the collection is non-empty here because of the leaf check above.
    return 1 + reduce(max, [depth(child) for child in tree.children])
end
# Return the number of nodes in `tree`, counting `tree` itself.
function nodes(tree::BallTreeNode)
    # Accumulate explicitly so that a leaf (no children) simply yields 1,
    # without reducing over an empty collection.
    total = 1
    for child in tree.children
        total += nodes(child)
    end
    return total
end
end
# Benchmark / sanity check: build a ball tree over random points, answer the
# same random queries by linear search and by tree search, time both, and
# print whether the two methods returned identical results.
d = 2      # dimension of each point
n = 800    # number of points in the data set
data = Set{Vector{Float64}}()
data_array = rand(d, n)
for i = 1:n
    # Each column of data_array is one point.
    add!(data, data_array[:, i])
end
@time tree = BallTree.make_ball_tree(data);
n_queries = 1000
queries = [rand(d) for i in 1:n_queries]
b = @time [BallTree.nearest_neighbor_linsearch(data, q) for q in queries]
a = @time [BallTree.query_ball_tree(tree, q) for q in queries]
println("depth: $(BallTree.depth(tree))")
println("nodes: $(BallTree.nodes(tree))")
# true when the tree search agrees with the brute-force search on every query.
println(a == b)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment