Created
March 15, 2019 04:25
-
-
Save utdemir/643f12f7ba394fa1ab70d4cdaa199ed7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Main where | |
import Data.List | |
import Data.Tree | |
-- | A point in space (its coordinate list) tagged with an arbitrary payload.
data Elem a = Elem
  { elemVec  :: [Double] -- ^ coordinates of the point
  , elemData :: a        -- ^ payload carried alongside the point
  }
  deriving (Show)
-- | A k-d tree over 'Elem' values, as built by 'mkKDTree'.
data KDTree a
  = -- | Interior node splitting on one coordinate.
    Branch
      Int        -- ^ index into 'elemVec' of the coordinate this node splits on
      Double     -- ^ split value ('mkKDTree' sends coordinates below it left, the rest right)
      Int        -- ^ number of elements stored beneath this node
      (KDTree a) -- ^ left subtree
      (KDTree a) -- ^ right subtree
  | -- | Bucket holding its element count and the elements themselves.
    Leaf Int [Elem a]
  deriving (Show)
-- | The element count cached in the root node of a tree.
getCount :: KDTree a -> Int
getCount tree = case tree of
  Branch _ _ count _ _ -> count
  Leaf count _         -> count
-- | Convert a 'KDTree' into a "Data.Tree" rose tree of labels for use with
-- 'drawTree'.
--
-- Branch nodes are labelled with their count and @(dimension, split value)@.
-- Leaf nodes are labelled with their count and get one child per element,
-- rendered with 'show'.
toTree :: Show a => KDTree a -> Tree String
toTree tree = case tree of
  Branch dim med count lhs rhs ->
    Node (unwords ["Branch", show count, show (dim, med)]) (map toTree [lhs, rhs])
  Leaf count elems ->
    Node ("Leaf " ++ show count) (map elemNode elems)
  where
    elemNode e = Node (show e) []
-- | @belowUpperBound n xs@ is 'True' exactly when @length xs <= n@.
--
-- It inspects at most @n + 1@ cons cells of @xs@, so it is cheap on long
-- lists and terminates on infinite ones. A negative bound admits no list,
-- not even the empty one. The original version returned 'True' for any
-- finite list in that case and diverged on infinite lists.
belowUpperBound :: Int -> [a] -> Bool
belowUpperBound n _ | n < 0 = False
belowUpperBound _ [] = True
belowUpperBound n (_ : rest) = belowUpperBound (n - 1) rest
-- | Build a k-d tree from a list of points.
--
-- * The number of dimensions is taken from the first element's 'elemVec'.
--   NOTE(review): other elements are assumed to have at least that many
--   coordinates; shorter vectors make '!!' fail. This is not validated here.
-- * Lists of at most 3 elements become a 'Leaf'. Otherwise the elements are
--   split on the current dimension at the 'median' of their coordinates.
--   Coordinates strictly below it go left and the rest go right. The next
--   level cycles to the following dimension.
-- * A split is degenerate when nothing is strictly below the median, so the
--   right child receives every element again. If that happens on every
--   dimension in a row, the set can never be divided (e.g. four or more
--   identical points), and the old code recursed forever. Such a set is now
--   stored in a 'Leaf'.
-- * Zero-dimensional points likewise go into a single 'Leaf'.
mkKDTree :: [Elem a] -> KDTree a
mkKDTree [] = Leaf 0 []
mkKDTree xs@(firstElem : _) = go 0 0 xs
  where
    numDims = length (elemVec firstElem)

    -- Arguments: count of consecutive degenerate splits of this same element
    -- set, the dimension to split on, and the elements to store.
    go :: Int -> Int -> [Elem a] -> KDTree a
    go stuck cur ys
      | belowUpperBound 3 ys || numDims == 0 || stuck >= numDims =
          Leaf (length ys) ys
      | otherwise =
          let m = median (map (coord cur) ys)
              -- partition keeps the original relative order on both sides
              (lows, highs) = partition (\e -> coord cur e < m) ys
              next = (cur + 1) `mod` numDims
              -- An empty `lows` means `highs` is all of `ys`; count it so a
              -- set that no dimension can divide terminates as a Leaf.
              stuck' = if null lows then stuck + 1 else 0
              left = go 0 next lows
              right = go stuck' next highs
           in Branch cur m (getCount left + getCount right) left right

    -- The i-th coordinate of an element.
    coord :: Int -> Elem b -> Double
    coord i e = elemVec e !! i
-- | Every element whose Euclidean distance ('dist') from @origin@ is at most
-- @radii@, in left-to-right tree order.
--
-- At a 'Branch' splitting dimension @dim@ at @med@, the left subtree holds
-- coordinates strictly below @med@ and the right subtree the rest (as built
-- by 'mkKDTree'). A subtree is skipped only when its gap along that axis
-- alone already exceeds @radii@:
--
-- * Left is skipped when @origin !! dim - med >= radii@. Left coordinates are
--   strictly below @med@, so their gap is strictly larger than @radii@.
-- * Right is skipped only when @med - origin !! dim > radii@. Right
--   coordinates may equal @med@, so a gap of exactly @radii@ must still be
--   searched. The previous @>=@ test here dropped points at exactly distance
--   @radii@.
lookupKDT :: [Double] -> Double -> KDTree a -> [Elem a]
lookupKDT origin radii (Leaf _ xs) =
  filter (\e -> dist origin (elemVec e) <= radii) xs
lookupKDT origin radii (Branch dim med _ left right)
  | delta >= radii = search right
  | negate delta > radii = search left
  | otherwise = search left ++ search right
  where
    -- signed offset of the query point from the split along `dim`
    delta = origin !! dim - med
    search = lookupKDT origin radii
-- aggrKDT :: [Double] -> Double -> KDTree a -> Int | |
-- aggrKDT origin radii (Leaf _ xs) | |
-- = length [ Elem vec d | |
-- | Elem vec d <- xs | |
-- , dist origin vec <= radii | |
-- ] | |
-- aggrKDT origin radii (Branch dim med c left right) | |
-- = let dist = origin !! dim - med | |
-- in if abs dist >= radii | |
-- then aggrKDT origin radii (if dist < 0 then left else right) | |
-- else aggrKDT origin radii left ++ lookupKDT origin radii right | |
-- | Euclidean distance between two points.
--
-- Coordinates are paired with 'zipWith'. If the lists differ in length, the
-- extra coordinates of the longer one are ignored.
dist :: [Double] -> [Double] -> Double
dist p1 p2 = sqrt (sum squaredGaps)
  where
    squaredGaps = zipWith (\a b -> (a - b) ^ (2 :: Int)) p1 p2
-- | The element at index @length xs `div` 2@ of the sorted list.
--
-- For odd lengths this is the middle value. For even lengths it is the upper
-- of the two middle values; no averaging is done.
--
-- Partial: an empty list makes the indexing fail.
median :: [Double] -> Double
median xs = sort xs !! middle
  where
    middle = length xs `div` 2
-- | Demo: build a tree over the 4x4 grid of points with coordinates in
-- @[0 .. 3]@ and print the points. Then draw the tree and print the elements
-- within distance 1 of @(2.5, 2.5)@.
main :: IO ()
main = do
  let points = [Elem [x, y] () | x <- [0 .. 3], y <- [0 .. 3]]
      tree = mkKDTree points
  print points
  putStrLn (drawTree (toTree tree))
  mapM_ print (lookupKDT [2.5, 2.5] 1 tree)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment