Skip to content

Instantly share code, notes, and snippets.

@utdemir
Created March 15, 2019 04:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save utdemir/643f12f7ba394fa1ab70d4cdaa199ed7 to your computer and use it in GitHub Desktop.
Save utdemir/643f12f7ba394fa1ab70d4cdaa199ed7 to your computer and use it in GitHub Desktop.
module Main where
import Data.List
import Data.Tree
-- | A point stored in the tree: its coordinate vector together with an
-- arbitrary payload.
data Elem a = Elem
  { elemVec :: [Double]
    -- ^ coordinates of the point, indexed by dimension
  , elemData :: a
    -- ^ payload carried alongside the coordinates
  }
  deriving (Show)
-- | A k-d tree whose every node caches the number of elements beneath it.
data KDTree a
  = -- | Internal node that splits its elements on a single coordinate.
    Branch
      Int        -- ^ index of the coordinate this node splits on
      Double     -- ^ split value for that coordinate
      Int        -- ^ number of elements stored beneath this node
      (KDTree a) -- ^ left subtree (coordinates below the split value, as built by 'mkKDTree')
      (KDTree a) -- ^ right subtree (coordinates at or above the split value, as built by 'mkKDTree')
  | -- | Bucket holding its elements directly, with a cached count.
    Leaf
      Int        -- ^ number of elements in the bucket
      [Elem a]   -- ^ the elements themselves
  deriving (Show)
-- | Number of elements stored in a tree, read from the count cached at its
-- root rather than by traversing the tree.
getCount :: KDTree a -> Int
getCount tree = case tree of
  Branch _ _ n _ _ -> n
  Leaf n _ -> n
-- | Convert a 'KDTree' into a rose tree of labels, suitable for 'drawTree'.
--
-- A branch is labelled with its element count and its (dimension, split
-- value) pair, and its children are the left then the right subtree. A leaf
-- is labelled with its element count and has one childless node per element.
toTree :: Show a => KDTree a -> Tree String
toTree tree = case tree of
  Branch dim med c l r ->
    Node (unwords ["Branch", show c, show (dim, med)]) (map toTree [l, r])
  Leaf c elems ->
    Node ("Leaf " ++ show c) (map (\e -> Node (show e) []) elems)
-- | @belowUpperBound n xs@ is 'True' exactly when @length xs <= n@.
--
-- At most the first @n + 1@ elements are inspected, so the check is cheap on
-- long lists and terminates on infinite ones. A negative bound is never
-- satisfied. The previous version returned 'True' for every finite list and
-- never terminated on an infinite list when given a negative bound.
belowUpperBound :: Int -> [a] -> Bool
belowUpperBound n xs = n >= 0 && null (drop n xs)
-- | Build a k-d tree from a list of elements.
--
-- The number of coordinates @k@ is taken from the first element's vector.
-- Every element is assumed to have at least that many coordinates, because
-- indexing uses '!!' and fails otherwise. Groups of at most 3 elements become
-- leaves. Larger groups split on the 'median' of one coordinate: elements
-- strictly below it go left, the rest go right. The split coordinate cycles
-- through @0 .. k-1@ as the tree deepens.
--
-- A coordinate whose split would leave the left side empty is skipped in
-- favour of the next one. Duplicate points are the typical cause. When no
-- coordinate separates the elements, they are kept together in one leaf.
-- Previously such inputs, for example several identical points, recursed
-- forever.
mkKDTree :: [Elem a] -> KDTree a
mkKDTree [] = Leaf 0 []
mkKDTree xs@(x : _) = go 0 xs
  where
    numDims = length (elemVec x)

    -- Coordinate used below a node that split on coordinate @d@.
    nextDim :: Int -> Int
    nextDim d = (d + 1) `mod` numDims

    -- Build the subtree for @ys@, preferring to split on coordinate @cur@.
    go :: Int -> [Elem b] -> KDTree b
    go cur ys
      | belowUpperBound 3 ys = Leaf (length ys) ys
      | otherwise = split numDims cur ys

    -- Split @ys@ on coordinate @cur@, moving on to the following coordinate
    -- if the split is degenerate. @tries@ bounds how many coordinates remain
    -- to be tried, which guarantees termination. With zero coordinates the
    -- guard also stops before 'nextDim' could divide by zero.
    split :: Int -> Int -> [Elem b] -> KDTree b
    split tries cur ys
      | tries <= 0 = Leaf (length ys) ys
      | null below = split (tries - 1) (nextDim cur) ys
      | otherwise =
          let left = go (nextDim cur) below
              right = go (nextDim cur) atOrAbove
           in Branch cur m (getCount left + getCount right) left right
      where
        coord e = elemVec e !! cur
        m = median (map coord ys)
        -- 'median' returns an element of the list, so 'atOrAbove' is never
        -- empty; only 'below' can be.
        (below, atOrAbove) = partition ((< m) . coord) ys
-- | All elements whose vector lies within Euclidean distance @radii@
-- (inclusive) of @origin@.
--
-- At a branch, the search descends into only one child when the ball around
-- @origin@ lies strictly on one side of the split value. Otherwise both
-- children are searched and their results are concatenated, left first.
lookupKDT :: [Double] -> Double -> KDTree a -> [Elem a]
lookupKDT origin radii (Leaf _ xs) =
  filter (\e -> dist origin (elemVec e) <= radii) xs
lookupKDT origin radii (Branch dim med _ left right)
  -- The comparison must be strict. When the offset equals the radius, a
  -- point lying exactly on the split plane sits in the right subtree and can
  -- be at distance exactly @radii@, which the inclusive test at the leaves
  -- accepts. The previous '>=' wrongly pruned it.
  | abs offset > radii =
      lookupKDT origin radii (if offset < 0 then left else right)
  | otherwise =
      lookupKDT origin radii left ++ lookupKDT origin radii right
  where
    -- Signed distance from the query point to the split plane. It is named
    -- 'offset' so that it no longer shadows the top-level 'dist'.
    offset = origin !! dim - med
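-- Disabled draft: counts the elements within @radii@ of @origin@ without
-- building the result list. The recursive case adds the two subtree counts.
-- (An earlier draft used (++) with lookupKDT there, which does not
-- type-check for an Int result.)
-- aggrKDT :: [Double] -> Double -> KDTree a -> Int
-- aggrKDT origin radii (Leaf _ xs)
--   = length [ Elem vec d
--            | Elem vec d <- xs
--            , dist origin vec <= radii
--            ]
-- aggrKDT origin radii (Branch dim med _ left right)
--   = let offset = origin !! dim - med
--     in if abs offset > radii
--          then aggrKDT origin radii (if offset < 0 then left else right)
--          else aggrKDT origin radii left + aggrKDT origin radii right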
-- | Euclidean distance between two points given as coordinate lists.
--
-- Coordinates are paired by position. If the lists differ in length, the
-- extra coordinates of the longer list are ignored.
dist :: [Double] -> [Double] -> Double
dist p1 p2 = sqrt (sum squaredDiffs)
  where
    squaredDiffs = [d * d | (a, b) <- zip p1 p2, let d = a - b]
-- | The element at index @length xs `div` 2@ of the sorted input.
--
-- For odd lengths this is the middle value. For even lengths it is the
-- upper of the two middle values; no averaging is done, so the result is
-- always an element of the list. Calls 'error' (via '!!') on an empty list.
median :: [Double] -> Double
median xs = sort xs !! (length xs `div` 2)
-- | Demo: build a tree over the 4x4 integer grid, print the points and the
-- tree, then print every point within distance 1 of (2.5, 2.5).
main :: IO ()
main = do
  let grid = [Elem [x, y] () | x <- [0 .. 3], y <- [0 .. 3]]
      tree = mkKDTree grid
  print grid
  putStrLn (drawTree (toTree tree))
  mapM_ print (lookupKDT [2.5, 2.5] 1 tree)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment