Skip to content

Instantly share code, notes, and snippets.

@dergraf
Last active December 19, 2015 23:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dergraf/6038721 to your computer and use it in GitHub Desktop.
Save dergraf/6038721 to your computer and use it in GitHub Desktop.
Incomplete k-d tree implementation in Erlang.
-module(kdtree).
-export([new/1, search/2, distance/2, test/0, test/3]).
%% Build a k-d tree from a list of point tuples.
%% The dimensionality K is taken from the first point; all points are
%% expected to have (at least) that many elements.
%% Returns [] for an empty list, otherwise a {Point, Left, Right} node.
new([]) ->
    [];
new([FirstPoint | _] = PointList) ->
    kdtree(tuple_size(FirstPoint), PointList, 0).
%% Recursively build a k-d tree node from PointList.
%% K is the number of dimensions. Depth selects the splitting axis,
%% which cycles through 1..K as the tree gets deeper.
%% Returns [] for an empty list, otherwise {Pivot, LeftSubtree, RightSubtree}:
%%   - Pivot is the median point on the splitting axis;
%%   - LeftSubtree is built from the points sorted before Pivot;
%%   - RightSubtree is built from the points sorted after Pivot.
kdtree(_, [], _) ->
    [];
kdtree(K, PointList, Depth) ->
    Dimension = (Depth rem K) + 1,
    %% Sort the tuples on the splitting coordinate (keysort is stable).
    SortedPointList = lists:keysort(Dimension, PointList),
    Median = (length(SortedPointList) div 2) + 1,
    %% A single split replaces lists:nth/2 plus two lists:sublist calls.
    %% The first Median-1 points go left, the median becomes the pivot,
    %% and the remainder go right.
    {LeftPoints, [Pivot | RightPoints]} = lists:split(Median - 1, SortedPointList),
    NewDepth = Depth + 1,
    {Pivot,
     kdtree(K, LeftPoints, NewDepth),
     kdtree(K, RightPoints, NewDepth)}.
%% Return the point stored in Tree that is nearest to Point
%% (smallest squared Euclidean distance).
%% Returns undefined when Tree is empty ([]).
%% The number of dimensions is taken from Point.
search(Tree, Point) when is_tuple(Point) ->
    search(tuple_size(Point), Tree, Point, 0, undefined).
%% Nearest-neighbour descent through a k-d tree.
%% Arguments:
%%   - K: number of dimensions.
%%   - Depth: current depth; it selects the splitting axis.
%%   - Best: closest point found so far, or undefined before any node is visited.
%% At each node the search first visits the child on the same side of the
%% splitting plane as Point. It visits the other child only when the squared
%% distance from Point to the plane is smaller than the best squared
%% distance found so far.
search(_K, [], _Point, _Depth, Best) ->
    Best;
search(K, {Location, _Left, _Right} = Node, Point, Depth, undefined) ->
    %% The first node reached seeds the best candidate.
    search(K, Node, Point, Depth, Location);
search(K, {Location, Left, Right}, Point, Depth, Best) ->
    Candidate =
        case distance(Point, Location) < distance(Point, Best) of
            true -> Location;
            false -> Best
        end,
    Axis = (Depth rem K) + 1,
    {Near, Far} =
        case element(Axis, Point) < element(Axis, Location) of
            true -> {Left, Right};
            false -> {Right, Left}
        end,
    NextDepth = Depth + 1,
    NearBest = search(K, Near, Point, NextDepth, Candidate),
    case distance_axis(Axis, Location, Point) < distance(NearBest, Point) of
        true -> search(K, Far, Point, NextDepth, NearBest);
        false -> NearBest
    end.
%% Squared Euclidean distance between points A and B (tuples of numbers).
%% The square root is never taken: callers only compare distances.
%% The dimensionality is taken from A, so B must have at least as many
%% elements (element/2 crashes otherwise).
%% The result is always a float; a 0-dimensional point yields 0.0.
distance(A, B) ->
    distance(A, B, tuple_size(A), 0.0).

%% Accumulate squared per-axis differences, from axis Dim down to axis 1.
distance(_, _, 0, Distance) ->
    Distance;
distance(A, B, Dim, Distance) ->
    Diff = element(Dim, A) - element(Dim, B),
    distance(A, B, Dim - 1, Distance + math:pow(Diff, 2)).
%% Squared distance from Point to the splitting plane through Location on
%% axis Dim, i.e. the squared difference of their Dim-th coordinates.
%% The original built a temporary tuple and summed K squared terms (all
%% zero except one). This computes the same float directly.
distance_axis(Dim, Location, Point) ->
    math:pow(element(Dim, Location) - element(Dim, Point), 2).
%% Run the benchmark for every dimensionality from 1 to 8 combined with
%% tree sizes of 10, 100, 1000 and 10000 points, using 100 random queries
%% each. Returns the list of test/3 results, outer loop over dimensions.
test() ->
    Iterations = 100,
    lists:flatmap(
      fun(Dim) ->
              [test(Dim, Count, Iterations) || Count <- [10, 100, 1000, 10000]]
      end,
      lists:seq(1, 8)).
%% Benchmark and cross-check.
%% Builds a tree from N random Dim-dimensional points with integer
%% coordinates in 1..1000. It then runs Iters random queries through both a
%% brute-force linear scan and search/2, timing each with timer:tc
%% (microseconds).
%% Crashes with badmatch if the two answers lie at different distances.
%% Prints the build time and the average query times, and returns the
%% result of io:format/2.
test(Dim, N, Iters) ->
    %% rand replaces the deprecated random module; it seeds itself per process.
    RandomPoint = fun() ->
                          list_to_tuple([rand:uniform(1000) || _ <- lists:seq(1, Dim)])
                  end,
    PointList = [RandomPoint() || _ <- lists:seq(1, N)],
    {TimeTreeSetup, Tree} = timer:tc(fun new/1, [PointList]),
    %% Brute-force nearest neighbour over the unsorted list. Sorting is not
    %% needed, because only distances are compared against the tree's
    %% answer, and it would add an O(N log N) cost to every timed query.
    [FirstPoint | _] = PointList,
    LinearSearch = fun(P) ->
                           lists:foldl(fun(Point, CurrentClosest) ->
                                               case distance(P, Point) < distance(P, CurrentClosest) of
                                                   true -> Point;
                                                   false -> CurrentClosest
                                               end
                                       end, FirstPoint, PointList)
                   end,
    Results = lists:map(fun(_) ->
                                RandomSample = RandomPoint(),
                                {TimeLinear, P1} = timer:tc(LinearSearch, [RandomSample]),
                                {TimeKDTree, P2} = timer:tc(fun search/2, [Tree, RandomSample]),
                                D1 = distance(P1, RandomSample),
                                %% Assertive match: badmatch if the two neighbours'
                                %% distances differ.
                                D1 = distance(P2, RandomSample),
                                {TimeLinear, TimeKDTree}
                        end, lists:seq(1, Iters)),
    {LinearTimes, KDTreeTimes} = lists:unzip(Results),
    AvgTimeLinear = lists:sum(LinearTimes) / Iters,
    AvgTimeKDTree = lists:sum(KDTreeTimes) / Iters,
    io:format("~p-dimensional, ~p items, buildtime ~pus, NNS avg time linear ~pus, NNS avg time kdtree ~pus~n",
              [Dim, N, TimeTreeSetup, AvgTimeLinear, AvgTimeKDTree]).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment