Skip to content

Instantly share code, notes, and snippets.

@Cedev
Created January 6, 2015 18:35
Show Gist options
  • Save Cedev/745e9c58edae4d546cae to your computer and use it in GitHub Desktop.
Save Cedev/745e9c58edae4d546cae to your computer and use it in GitHub Desktop.
General lazy tree unfold with no MonadFix and delayed path compression
import Data.Tree hiding (unfoldTreeM_BF, unfoldForestM_BF)
import Data.Traversable
import Prelude hiding (sequence)
import Control.Monad.Free
import Data.Functor.Identity
-- | Monadic breadth-first unfold of a single tree.
--
-- The seed is wrapped as @Pure seed@ and grown level by level by
-- 'unfoldFreeM_BF'. The single resulting tree is then unwrapped again.
unfoldTreeM_BF :: Monad m => (b->m (a, [b])) -> b -> m (Tree a)
unfoldTreeM_BF f seed =
  unfoldFreeM_BF f (Pure seed) >>= \grown -> return (getPure grown)
-- | Monadic breadth-first unfold of a whole forest.
--
-- Each seed becomes a @Pure@ leaf under one @Free@ root. That structure is
-- grown by 'unfoldFreeM_BF', and the root's children are unwrapped back into
-- a plain list of trees.
unfoldForestM_BF :: Monad m => (b->m (a, [b])) -> [b] -> m [Tree a]
unfoldForestM_BF f seeds = do
  grown <- unfoldFreeM_BF f (Free (map Pure seeds))
  return (map getPure (getFree grown))
-- | Grow every @Pure@ seed in a 'Free' frontier into a tree, one level at a
-- time.
--
-- All seeds of the current level are expanded in a single monadic pass
-- (@sequence . fmap@). Their child lists are then reshaped by 'compress' into
-- the next frontier, and that frontier is unfolded recursively. Finally the
-- labels are paired with the decompressed subforests via
-- 'zipWithIrrefutable'. The recursion ends when the frontier is @Free []@.
unfoldFreeM_BF :: (Monad m) => (b->m (a, [b])) -> Free [] b -> m (Free [] (Tree a))
unfoldFreeM_BF _ (Free []) = return (Free [])
unfoldFreeM_BF expand frontier =
  sequence (fmap expand frontier) >>= \expanded ->
    let labels = fmap fst expanded
        -- Lazy binding: the compressed frontier is only inspected by the
        -- recursive call below.
        (nextFrontier, restore) = compress (fmap snd expanded)
    in unfoldFreeM_BF expand nextFrontier >>= \subtrees ->
         return (zipWithIrrefutable Node labels (restore subtrees))
-- | Two-stage compression. Path compression of the nodes created by a bind is
-- delayed until just before the next bind.
--
-- Stage 1: 'compressFreeList' prunes the structure from the previous round.
-- Stage 2: 'bindFreeInvertible' turns each @Pure@ list of child seeds into a
-- @Free@ node of @Pure@ seeds.
-- The returned function undoes the bind first, then the path compression.
compress :: Free [] [b] -> (Free [] b, Free [] a -> Free [] [a])
compress nested = (bound, undoPaths . undoBind)
  where
    (pruned, undoPaths) = compressFreeList nested
    (bound, undoBind) = bindFreeInvertible pruned
{-
-- Strict path compression isn't lazy enough
compressFreeList :: Free [] b -> (Free [] b, Free [] a -> Free [] a)
compressFreeList (Pure x) = (Pure x, id)
compressFreeList (Free xs) = wrapList . compressList . map compressFreeList $ xs
where
compressList = foldr k ([], const [])
k (Free [], dx) (xs, dxs) = ( xs, \ xs -> dx (Free []):dxs xs)
k (x , dx) (xs, dxs) = (x:xs, \(x:xs) -> dx x :dxs xs)
wrapList ([x], dxs) = (x, \x -> Free (dxs [x]))
wrapList (xs , dxs) = (Free xs, \(Free xs) -> Free (dxs xs ))
-}
-- | Extremely lazy path compression.
--
-- The structure is compressed in three ways:
--
-- * @Free []@ children are removed from their parent.
-- * A @Free@ node left with exactly one child is replaced by that child.
-- * @Pure@ leaves are kept unchanged.
--
-- The second component maps a value shaped like the compressed structure back
-- to the original shape. It re-inserts @Free []@ where children were removed.
compressFreeList :: Free [] b -> (Free [] b, Free [] a -> Free [] a)
compressFreeList (Pure x) = (Pure x, id)
compressFreeList (Free xs) = wrapList . compressList . map compressFreeList $ xs
  where
    compressList = foldr k ([], const [])
    -- Both arguments use lazy patterns, so the rest of the fold is not
    -- evaluated until its children or decompressor are actually demanded.
    k ~(x, dx) ~(xs', dxs) = (x', dxs')
      where
        -- Empty children are dropped from the compressed list.
        x' = case x of
          Free [] -> xs'
          _ -> x : xs'
        -- Decompression: an empty child consumes nothing from the compressed
        -- list; any other child consumes one element.
        -- 'head'/'tail' are used instead of a cons pattern so that
        -- the compressed list is only deconstructed when an element is needed.
        dxs' cxs = dx x'' : dxs xs''
          where
            x'' = case x of
              Free [] -> Free []
              _ -> head cxs
            xs'' = case x of
              Free [] -> cxs
              _ -> tail cxs
    -- `wrapList` and the decompression functions could be one step lazier.
    -- A node with a single surviving child collapses into that child.
    wrapList ([y], dxs) = (y, \c -> Free (dxs [c]))
    wrapList (ys, dxs) = (Free ys, \(Free cs) -> Free (dxs cs))
-- could be written in general for any Traversable t instead of for []
-- The decompression functions could all be lazier.
--
-- | Invertible bind. Every @Pure xs@ leaf becomes @Free (map Pure xs)@ and
-- @Free@ nodes are rebuilt child by child. The second component undoes this:
-- each @Free@ of @Pure@ leaves (under a former leaf) is turned back into one
-- @Pure@ list.
bindFreeInvertible :: Free [] ([] b) -> (Free [] b, Free [] a -> Free [] ([] a))
bindFreeInvertible node = case node of
  Pure leaves ->
    ( Free (map Pure leaves)
    , \(Free wrapped) -> Pure (map getPure wrapped)
    )
  Free children ->
    -- Strict match on the fold result, so the children list is forced to
    -- the same extent as before.
    case foldr step ([], const []) (map bindFreeInvertible children) of
      (bound, restore) -> (Free bound, \(Free parts) -> Free (restore parts))
  where
    step ~(child, undo) ~(siblings, undoRest) =
      (child : siblings, \(part : parts) -> undo part : undoRest parts)
-- | Single-stage compression, meant to be equivalent to running
-- 'bindFreeInvertible' before 'compressFreeList'.
--
-- * @Pure [x]@ becomes @Pure x@ directly, because a one-element bind needs no
--   node.
-- * @Pure xs@ becomes @Free (map Pure xs)@. For an empty @xs@ this is
--   @Free []@, which the parent then drops.
-- * @Free xs@ compresses each child with 'compress'' itself and drops
--   children that became @Free []@. A node left with one child is replaced
--   by that child.
--
-- The second component maps a value shaped like the compressed structure back
-- to the original shape.
compress' :: Free [] [b] -> (Free [] b, Free [] a -> Free [] [a])
compress' (Pure [x]) = (Pure x, \(Pure y) -> Pure [y])
compress' (Pure xs) = (Free (map Pure xs), \(Free ps) -> Pure (map getPure ps))
-- Recurse with compress' (previously this called the two-stage 'compress'),
-- so that every level uses the same single-stage scheme.
compress' (Free xs) = wrapList . compressList . map compress' $ xs
  where
    compressList = foldr k ([], const [])
    -- Both arguments use lazy patterns, so the rest of the fold is not
    -- evaluated until its children or decompressor are actually demanded.
    k ~(x, dx) ~(xs', dxs) = (x', dxs')
      where
        -- Empty children are dropped from the compressed list.
        x' = case x of
          Free [] -> xs'
          _ -> x : xs'
        -- An empty child consumes nothing from the compressed list; any other
        -- child consumes one element. head/tail defer the deconstruction.
        dxs' cxs = dx x'' : dxs xs''
          where
            x'' = case x of
              Free [] -> Free []
              _ -> head cxs
            xs'' = case x of
              Free [] -> cxs
              _ -> tail cxs
    {-
    -- Strict path compression may not be lazy enough
    compressList = foldr k ([], const [])
    k (Free [],dx) (xs', dxs) = (xs' , \xs -> dx (Free []):dxs xs)
    k (x,dx) (xs', dxs) = (x:xs', \(x:xs) -> dx x :dxs xs)
    -}
    -- A node with a single surviving child collapses into that child.
    wrapList ([y], dxs) = (y, \c -> Free (dxs [c]))
    wrapList (ys, dxs) = (Free ys, \(Free cs) -> Free (dxs cs))
-- | Children of a 'Free' node. The lazy pattern means the constructor is not
-- checked until the result is used; a @Pure@ argument fails only at that point.
getFree :: Free f a -> f (Free f a)
getFree ~(Free xs) = xs

-- | Payload of a 'Pure' leaf. The lazy pattern means the constructor is not
-- checked until the result is used; a @Free@ argument fails only at that point.
getPure :: Free f a -> a
getPure ~(Pure x) = x
-- | Functors supporting a zip whose result shape is taken from the first
-- structure. The instances in this file match the second structure with lazy
-- patterns, so it is only deconstructed when an element of the result is
-- demanded. It is assumed to have the same shape as the first.
class (Functor f) => Traceable f where
  zipWithIrrefutable
    :: (a -> b -> c) -- ^ combines corresponding elements
    -> f a           -- ^ structure that determines the result's shape
    -> f b           -- ^ structure assumed to mirror the first
    -> f c
-- | List zip driven only by the first list. The second list's cons cells are
-- pulled apart by a lazy @let@ pattern, i.e. only when an element is used.
instance Traceable [] where
  zipWithIrrefutable combine = go
    where
      go [] _ = []
      go (l : ls) rights =
        let (r : rs) = rights
        in combine l r : go ls rs
{-
instance (Traceable f, Traceable g) => Traceable (Compose f g) where
zipWithIrrefutable f (Compose xs) (Compose ys) =
Compose (zipWithIrrefutable (zipWithIrrefutable f) xs ys)
-}
-- | Node-for-node zip of two 'Free' structures, following the first one's
-- constructors. The second structure is deconstructed through lazy @let@
-- patterns.
instance (Traceable f) => Traceable (Free f) where
  zipWithIrrefutable combine lhs rhs = case lhs of
    Pure l ->
      let Pure r = rhs
      in Pure (combine l r)
    Free ls ->
      let Free rs = rhs
      in Free (zipWithIrrefutable (zipWithIrrefutable combine) ls rs)
{-
isEmpty :: Foldable f => f a -> Bool
isEmpty = foldr (\_ _ -> False) True
-}
{-
trace :: [[a]] -> [b] -> [[b]]
trace [] ys = []
trace (xs:xxs) ys =
let (ys', rem) = takeRemainder xs ys
in ys':trace xxs rem
where
takeRemainder [] ys = ([], ys)
takeRemainder (x:xs) ~(y:ys) =
let ( ys', rem) = takeRemainder xs ys
in (y:ys', rem)
-}
--------------------------------- Examples
-- | Tree generator: label the node with its seed @x@ and give it @n@ children
-- seeded @x+1 .. x+n@.
mkNNary :: Int -> Int -> (Int, [Int])
mkNNary n x = (x, [x + k | k <- [1 .. n]])

-- | One child, seeded @x+1@.
mkUnary :: Int -> (Int, [Int])
mkUnary = mkNNary 1

-- | Two children, seeded @x+1@ and @x+2@.
mkBinary :: Int -> (Int, [Int])
mkBinary = mkNNary 2

-- | Three children, seeded @x+1 .. x+3@.
mkTrinary :: Int -> (Int, [Int])
mkTrinary = mkNNary 3

-- | Infinitely many children, seeded @x+1, x+2, ...@ (enumerated up to
-- 'maxBound').
mkInfinitary :: Int -> (Int, [Int])
mkInfinitary x = (x, enumFrom (x + 1))
-- | Effectful tree generator: prints the depth it is expanding, then labels
-- the node with that depth and gives it two children one level deeper.
mkDepths :: Int -> IO (Int, [Int])
mkDepths depth =
  print depth >> return (depth, replicate 2 (depth + 1))
-- | Wraps a monadic generator so that only child seeds satisfying the
-- predicate are kept. The label is passed through unchanged.
mkFiltered :: (Monad m) => (b -> Bool) -> (b -> m (a, [b])) -> (b -> m (a, [b]))
mkFiltered keep expand seed =
  expand seed >>= \(label, children) -> return (label, filter keep children)
-- | Pure unfold via "Data.Tree"'s monadic 'unfoldTreeM', run in 'Identity'.
unfoldTreeDF :: (b -> (a, [b])) -> b -> Tree a
unfoldTreeDF f = runIdentity . unfoldTreeM (Identity . f)

-- | Pure breadth-first unfold via this file's 'unfoldTreeM_BF', run in
-- 'Identity'.
unfoldTreeBF :: (b -> (a, [b])) -> b -> Tree a
unfoldTreeBF f = runIdentity . unfoldTreeM_BF (Identity . f)
-- | Prune a tree: in every child list, keep children only up to the first one
-- whose label fails the predicate, and prune the kept children recursively.
-- The root label itself is always kept.
takeWhileTree :: (a -> Bool) -> Tree a -> Tree a
takeWhileTree keep (Node label kids) = Node label (takeWhileForest keep kids)

-- | Keep the leading trees whose root labels satisfy the predicate, pruning
-- each one with 'takeWhileTree'. This stops at the first failing root.
takeWhileForest :: (a -> Bool) -> [Tree a] -> [Tree a]
takeWhileForest keep (t : ts)
  | keep (rootLabel t) = takeWhileTree keep t : takeWhileForest keep ts
takeWhileForest _ _ = []
-- | Example trees rooted at seed 0 and pruned to labels @<= 3@ with
-- 'takeWhileTree'. The plain, @DF@ and @BF@ variants build the same
-- generator's tree with 'unfoldTree', 'unfoldTreeDF' and 'unfoldTreeBF'
-- respectively.
unary, unaryDF, unaryBF :: Tree Int
unary = takeWhileTree (<= 3) (unfoldTree mkUnary 0)
unaryDF = takeWhileTree (<= 3) (unfoldTreeDF mkUnary 0)
unaryBF = takeWhileTree (<= 3) (unfoldTreeBF mkUnary 0)

binary, binaryDF, binaryBF :: Tree Int
binary = takeWhileTree (<= 3) (unfoldTree mkBinary 0)
binaryDF = takeWhileTree (<= 3) (unfoldTreeDF mkBinary 0)
binaryBF = takeWhileTree (<= 3) (unfoldTreeBF mkBinary 0)

infinitary, infinitaryDF, infinitaryBF :: Tree Int
infinitary = takeWhileTree (<= 3) (unfoldTree mkInfinitary 0)
infinitaryDF = takeWhileTree (<= 3) (unfoldTreeDF mkInfinitary 0)
infinitaryBF = takeWhileTree (<= 3) (unfoldTreeBF mkInfinitary 0)

trinaryBF, tenaryBF :: Tree Int
trinaryBF = takeWhileTree (<= 3) (unfoldTreeBF mkTrinary 0)
tenaryBF = takeWhileTree (<= 3) (unfoldTreeBF (mkNNary 10) 0)

-- | Effectful breadth-first unfold from seed 0. 'mkDepths' prints each seed
-- as it is expanded, and child seeds greater than 2 are filtered out.
binaryDepths :: IO (Tree Int)
binaryDepths = unfoldTreeM_BF (mkFiltered (<= 2) mkDepths) 0
-- | Demo driver. It draws every pure example tree, then prints the node
-- reached by repeatedly following the last child in four larger unfolds.
-- Finally it runs the effectful breadth-first example and draws its tree.
main :: IO ()
main = do
  mapM_ (putStrLn . drawTree . fmap show)
    [ unary, unaryDF, unaryBF
    , binary, binaryDF, binaryBF
    , infinitary, infinitaryDF, infinitaryBF
    , trinaryBF, tenaryBF
    ]
  print . lastLeaf . runIdentity $
    flip unfoldTreeM_BF 0 (\x -> return (x, if x > 5 then [] else replicate 9 x ++ [x+1]))
  print . lastLeaf . runIdentity $
    flip unfoldTreeM_BF 0 (\x -> return (x, if x > 5 then [] else replicate 9 6 ++ [x+1]))
  print . lastLeaf . runIdentity $
    flip unfoldTreeM 0 (\x -> return (x, if x > 5 then [] else replicate 10 (x+1)))
  print . lastLeaf . runIdentity $
    flip unfoldTreeM_BF 0 (\x -> return (x, if x > 5 then [] else replicate 10 (x+1)))
  -- mkDepths prints each seed as it is expanded.
  b <- binaryDepths
  putStrLn . drawTree . fmap show $ b
  where
    -- Follow the last child at every level until reaching a node with no
    -- children. 'last' is safe here because 'until' stops on empty forests.
    lastLeaf :: Tree a -> Tree a
    lastLeaf = until (null . subForest) (last . subForest)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment