Created
December 13, 2020 16:43
-
-
Save suzusuzu/623321e17ad93a190b0dac0cc9e4083e to your computer and use it in GitHub Desktop.
Navigable Small World (NSW)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Random | |
using LinearAlgebra | |
using DataStructures | |
using Base | |
"""
    Node(data, friend)

A vertex of the navigable small-world graph.

- `data`: the point stored at this vertex. Distances elsewhere are computed as
  `norm(node.data .- q)`, so it is presumably a numeric vector — TODO confirm.
- `friend`: the vertices adjacent to this one (the graph's edges).

`Node` is a mutable struct, so by default it compares (`==`/`===`) and hashes
by object identity; that is what allows nodes to live in a `Set` and act as
priority-queue keys.
"""
mutable struct Node
    data::Any           # stored point (untyped on purpose: any `norm`-able value)
    friend::Set{Node}   # adjacency set
end
"""
    show(io::IO, n::Node)

Print `n` on a single line as `data: {<data>}, friend: {<d1>, <d2>, ...}`.
The friend part lists only the `data` of each linked node (in `Set` iteration
order), not the friends' own friend lists.

No trailing newline is written: 2-argument `show` is also used when a `Node`
is printed inside containers or via `println`, which add their own line breaks.
"""
function Base.show(io::IO, n::Node)
    print(io, "data: {", n.data, "}, friend: {")
    # Stream the friends' data straight into `io` instead of building a string.
    join(io, (f.data for f in n.friend), ", ")
    print(io, "}")
    return nothing
end
"""
    knn_search(nodes::Vector{Node}, q, m::Int, k::Int) -> Vector{Node}

Approximate k-nearest-neighbour search over the graph formed by `nodes` and
their `friend` links (multi-restart NSW search).

Up to `m` restarts are made. Each restart starts from a random member of
`nodes` that no earlier restart has visited. It then expands best-first
(closest to `q` first, distance `norm(x.data .- q)`) through unvisited friends.
A restart stops expanding once its closest unexpanded candidate is no closer
than the `k`-th best node found by *previous* restarts. Restarting ends early
once every member of `nodes` has been visited.

Returns at most `k` nodes, ordered by increasing distance to `q`. Fewer are
returned when fewer nodes were reached. Returns an empty vector when `nodes`
is empty, `m <= 0`, or `k <= 0`.

NOTE(review): because the stop test only considers earlier restarts, the
first restart never stops early. It explores the whole connected component
of its entry point. This appears to mirror the published K-NNSearch
pseudocode. Comparing against the current restart's results instead would
make searches much cheaper, but would change the results.
"""
function knn_search(nodes::Vector{Node}, q, m::Int, k::Int)
    (isempty(nodes) || k <= 0) && return Node[]
    dist(x) = norm(x.data .- q)

    visited = Set{Node}()
    # Shared by all restarts: nodes left unexpanded when one restart stops
    # remain candidates for the next one.
    candidates = PriorityQueue{Node,Real}()
    # Every node reached so far, paired with its distance to `q`, kept sorted
    # by distance. Sorting explicitly avoids depending on PriorityQueue
    # iteration order.
    result = Pair{Node,Real}[]

    for _ in 1:m
        # Pick the entry point only among still-unvisited members of `nodes`.
        # Counting `visited` is not enough: friend links may lead outside
        # `nodes`, which previously could make the random pick loop forever.
        unvisited = [x for x in nodes if x ∉ visited]
        isempty(unvisited) && break
        entry = rand(unvisited)
        push!(visited, entry)
        enqueue!(candidates, entry => dist(entry))
        reached = Node[entry]

        # `result` only changes between restarts, so the stop threshold
        # (k-th best distance so far) is computed once per restart.
        threshold = length(result) >= k ? last(result[k]) : nothing

        while !isempty(candidates)
            c = dequeue!(candidates)
            if threshold !== nothing && dist(c) >= threshold
                break
            end
            for nb in c.friend
                if nb ∉ visited
                    push!(visited, nb)
                    enqueue!(candidates, nb => dist(nb))
                    push!(reached, nb)
                end
            end
        end

        append!(result, [x => dist(x) for x in reached])
        sort!(result; by = last)
    end

    # Clamp to the number actually reached instead of throwing a BoundsError.
    return Node[first(p) for p in result[1:min(k, end)]]
end
"""
    nearest_neighbor_insert(nodes, new_node, f, w)

Link `new_node` into the graph. It calls `knn_search(nodes, new_node.data, w, f)`
to find up to `f` approximate nearest neighbours (using `w` random restarts).
Each neighbour found gets a bidirectional friend edge to `new_node`.

`new_node` is never linked to itself, even when it is a member of `nodes` or
reachable from them. Mutates the `friend` sets of `new_node` and its neighbours.
Returns `nothing`.
"""
function nearest_neighbor_insert(nodes, new_node, f, w)
    for neighbor in knn_search(nodes, new_node.data, w, f)
        # The search can return `new_node` itself (distance 0) if it is reachable;
        # linking it would create a self-loop.
        neighbor === new_node && continue
        push!(neighbor.friend, new_node)
        push!(new_node.friend, neighbor)
    end
    return nothing
end
"""
    nsw_build(nodes, f, w)

Build a navigable small-world graph over `nodes` in place, inserting the nodes
one at a time in vector order.

Node `i` is linked (bidirectionally, via `nearest_neighbor_insert`) to up to
`f` approximate nearest neighbours. They are found by a `w`-restart
`knn_search` over the already-inserted prefix `nodes[1:i-1]`. `nodes[1]`
seeds the graph and needs no search of its own.

Only the prefix is searched, so a node is never searched for among itself
(assuming each `Node` object appears once in `nodes`). Likewise, a node is
never linked to isolated, not-yet-inserted nodes (assuming the nodes start
with empty friend sets). Empty and single-element inputs are no-ops. Returns
`nothing`.
"""
function nsw_build(nodes, f, w)
    for i in 2:length(nodes)
        # Slice (copy) rather than view: `knn_search` is declared with
        # `nodes::Vector{Node}`.
        inserted = nodes[1:i-1]
        # Never request more neighbours than there are inserted nodes.
        nearest_neighbor_insert(inserted, nodes[i], min(f, i - 1), w)
    end
    return nothing
end
"""
    greedy_searh(nodes, q)

Greedy routing on the NSW graph towards the query point `q`.

Starts at a uniformly random element of `nodes`. At each hop it scans the
current node's `friend`s and moves to the one closest to `q`, but only if it
is strictly closer than the current node (distance `norm(x.data .- q)` for the
start, `norm(x.data - q)` for friends). Stops when no friend is closer and
returns that node: a local minimum, not necessarily the true nearest neighbour.
"""
function greedy_searh(nodes, q)
    current = nodes[first(rand(1:length(nodes), 1))]
    best_d = norm(current.data .- q)
    moved = true
    while moved
        moved = false
        next = current
        # Scan the current node's friends and remember the closest improvement.
        for cand in current.friend
            d = norm(cand.data - q)
            if d < best_d
                best_d = d
                next = cand
                moved = true
            end
        end
        current = next
    end
    return current
end
# Demo: build an NSW graph over `n` random points in [0, 1)^dim, then run one
# greedy nearest-neighbour query and report its distance to the query point.
Random.seed!(1234)
n = 1000           # number of nodes
dim = 2            # dimension of each point
num_friends = 10   # `f` for nsw_build: neighbours linked per inserted node
num_restarts = 10  # `w` for nsw_build: random restarts per knn_search
# Use `n` rather than a hard-coded count so the two cannot drift apart.
nodes = [Node(rand(dim), Set{Node}()) for _ in 1:n]
nsw_build(nodes, num_friends, num_restarts)
q = rand(dim)
println("query: ", q)
res = greedy_searh(nodes, q)
println("result: ", res.data)
l2 = norm(res.data .- q)
println("l2 distance: ", l2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment