Skip to content

Instantly share code, notes, and snippets.

@suzusuzu
Created December 13, 2020 16:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save suzusuzu/623321e17ad93a190b0dac0cc9e4083e to your computer and use it in GitHub Desktop.
Save suzusuzu/623321e17ad93a190b0dac0cc9e4083e to your computer and use it in GitHub Desktop.
Navigable Small World (NSW)
using Random
using LinearAlgebra
using DataStructures
using Base
"""
    Node(data, friend)

A vertex of the navigable small world graph.

# Fields
- `data`: the payload stored in the vertex (left untyped; the rest of this file
  passes it to `norm(node.data .- q)`).
- `friend::Set{Node}`: the vertices directly linked to this one.
"""
mutable struct Node
    data
    friend::Set{Node}
end
# Compact one-line rendering of a node: its own payload followed by the payloads
# of its direct friends (in the iteration order of the `friend` set).
#
# `print` is used rather than `println`: a `show` method must not emit a trailing
# newline, otherwise output is broken when nodes are displayed inside containers
# or by the REPL.
function Base.show(io::IO, n::Node)
    friend_str = join((string(f.data) for f in n.friend), ", ")
    print(io, "data: {", n.data, "}, friend: {", friend_str, "}")
    return nothing
end
"""
    knn_search(nodes::Vector{Node}, q, m::Int, k::Int)

Approximate k-nearest-neighbour search for the query `q` over the graph whose
vertices are `nodes`.

The search is restarted up to `m` times, each time from a random node that has not
been visited yet. Within a restart the closest pending candidate is expanded
repeatedly; the restart ends when no candidate is left or when the closest candidate
is no closer to `q` than the `k`-th best node collected by *earlier* restarts
(during the first restart no such bound exists yet).

# Arguments
- `nodes`: the graph vertices; entry points are drawn uniformly from this vector.
- `q`: the query; distances are computed as `norm(node.data .- q)`.
- `m`: maximum number of restarts.
- `k`: number of neighbours requested.

# Returns
A `Vector{Node}` with at most `k` nodes, ordered by increasing distance to `q`.
Fewer than `k` nodes are returned when the search reached fewer than `k` nodes
(previously this case raised a `BoundsError`).
"""
function knn_search(nodes::Vector{Node}, q, m::Int, k::Int)
    distance(node) = norm(node.data .- q)

    visited = Set{Node}()
    # Pending candidates, shared across restarts exactly like the visited set.
    candidates = PriorityQueue{Node,Float64}()
    # Every node reached so far, kept sorted by ascending distance to `q`.
    result = Pair{Node,Float64}[]

    for _ in 1:m
        # No unvisited entry point is left once every node has been visited.
        length(visited) == length(nodes) && break

        # Draw a random entry point that has not been visited yet.
        entry = rand(nodes)
        while entry in visited
            entry = rand(nodes)
        end
        push!(visited, entry)
        enqueue!(candidates, entry => distance(entry))
        found = Node[entry]

        # `result` only changes after the expansion loop, so the stop bound is
        # loop-invariant and is computed once per restart.
        bound = 1 <= k <= length(result) ? last(result[k]) : Inf

        while !isempty(candidates)
            c = dequeue!(candidates)
            distance(c) >= bound && break
            for friend in c.friend
                friend in visited && continue
                push!(visited, friend)
                enqueue!(candidates, friend => distance(friend))
                push!(found, friend)
            end
        end

        # Merge what this restart discovered and restore the distance ordering.
        append!(result, (node => distance(node) for node in found))
        sort!(result; by=last)
    end

    # Clamp to what was actually found instead of indexing past the end.
    return Node[first(p) for p in result[1:min(k, length(result))]]
end
"""
    nearest_neighbor_insert(nodes, new_node, f, w)

Link `new_node` bidirectionally with the neighbours returned by
`knn_search(nodes, new_node.data, w, f)`.

# Arguments
- `nodes`: the vertices to search for neighbours.
- `new_node`: the node being connected.
- `f`: number of neighbours requested from the search.
- `w`: number of search restarts passed to `knn_search`.

Returns `nothing`; the `friend` sets of the affected nodes are mutated in place.
"""
function nearest_neighbor_insert(nodes, new_node, f, w)
    for neighbor in knn_search(nodes, new_node.data, w, f)
        # `nodes` may already contain `new_node` (it does when called from
        # `nsw_build`), in which case the search can return the node itself.
        # A node must never become its own friend.
        neighbor === new_node && continue
        push!(neighbor.friend, new_node)
        push!(new_node.friend, neighbor)
    end
    return nothing
end
"""
    nsw_build(nodes, f, w)

Build the navigable small world graph by inserting `nodes` one at a time.

Each node is linked to its (approximate) nearest neighbours among the nodes that
were inserted *before* it, so the neighbour search never runs over nodes that are
not part of the graph yet, nor over the node being inserted itself.

# Arguments
- `nodes`: the vertices to connect; their `friend` sets are mutated in place.
- `f`: number of neighbours each node is linked to on insertion.
- `w`: number of search restarts used for each insertion.

Returns `nothing`.
"""
function nsw_build(nodes, f, w)
    inserted = empty(nodes)
    for new_node in nodes
        if !isempty(inserted)
            # Never request more neighbours than the graph currently holds.
            nearest_neighbor_insert(inserted, new_node, min(f, length(inserted)), w)
        end
        push!(inserted, new_node)
    end
    return nothing
end
"""
    greedy_searh(nodes, q)

Greedy walk over the graph towards the query `q`.

Starting from a uniformly random element of `nodes`, the walk repeatedly scans the
friends of the current node and moves to the closest friend that improves on the
current distance `norm(node.data .- q)`. It stops at a node none of whose friends
is strictly closer to `q` and returns that node (a local minimum, not necessarily
the true nearest neighbour).

# Throws
- `ArgumentError` if `nodes` is empty.
"""
function greedy_searh(nodes, q)
    isempty(nodes) && throw(ArgumentError("nodes must be non-empty"))

    # One distance definition for both the start node and its friends, so every
    # comparison uses the same (broadcasting) subtraction.
    distance(node) = norm(node.data .- q)

    near_node = rand(nodes)
    min_d = distance(near_node)
    while true
        improved = false
        # The friend set being scanned is the one of the node that was current
        # when this pass started, even if `near_node` is updated along the way.
        for node in near_node.friend
            d = distance(node)
            if d < min_d
                min_d = d
                near_node = node
                improved = true
            end
        end
        improved || break
    end
    return near_node
end
# Demo: build an NSW graph over `n` random points and run one greedy search.
Random.seed!(1234)
n = 1000 # number of nodes
dim = 2 # dimension of each data point
# Use `n` rather than a repeated literal, and give each node a correctly typed,
# initially empty friend set.
nodes = [Node(rand(dim), Set{Node}()) for _ in 1:n]
nsw_build(nodes, 10, 10)
q = rand(dim)
println("query: ", q)
res = greedy_searh(nodes, q)
println("result: ", res.data)
l2 = norm(res.data .- q)
println("l2 distance: ", l2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment