Last active
June 6, 2020 11:22
-
-
Save sjoerdvisscher/9e605825bb4489ff6c6ebe99af9e2e47 to your computer and use it in GitHub Desktop.
A distribution monad transformer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TupleSections #-} | |
{-# LANGUAGE DeriveTraversable #-} | |
{-# LANGUAGE RankNTypes #-} | |
module Control.Monad.Trans.DistT where | |
import Data.List.NonEmpty (NonEmpty(..)) | |
import Control.Monad.Trans.Class | |
import Control.Monad.Trans.Free | |
import Control.Arrow (Arrow(first)) | |
import Control.Applicative (liftA2, Alternative(..)) | |
-- `cataMFreeT` and `cataFreeT` should really be in Control.Monad.Trans.Free | |
-- | Monadic catamorphism over a 'FreeT' tree.
--
-- The first argument handles a 'Pure' leaf. The second handles a 'Free'
-- layer whose children have already been replaced by the (not yet run)
-- monadic results of folding them, so the handler decides whether and in
-- which order the sub-results are run.
cataMFreeT :: (Monad m, Functor f) => (a -> m r) -> (f (m r) -> m r) -> FreeT f m a -> m r
cataMFreeT leaf node = fold
  where
    -- Run one step of the underlying monad, then dispatch on what it produced.
    fold tree = runFreeT tree >>= step

    step (Pure a) = leaf a
    step (Free layer) = node (fold <$> layer)
-- | Pure-algebra catamorphism over a 'FreeT' tree.
--
-- Leaves are mapped with the first argument. For each 'Free' layer the
-- monadic results of all children are sequenced (in the traversal order of
-- @f@) and the collected layer is collapsed with the second argument.
cataFreeT :: (Monad m, Traversable f) => (a -> r) -> (f r -> r) -> FreeT f m a -> m r
cataFreeT f g =
  cataMFreeT
    (\a -> pure (f a))
    (\layer -> g <$> sequence layer)
-- | A single binary choice point: a weight followed by two alternatives.
--
-- 'decons' interprets the weight @p@ as the probability of the first
-- alternative and @1 - p@ as the probability of the second.
data Choose p a
  = Choose
      p -- ^ weight attached to the first alternative
      a -- ^ first alternative
      a -- ^ second alternative
  deriving (Eq, Ord, Show, Functor, Foldable, Traversable)
-- | The distribution monad transformer.
--
-- This is 'FreeT' over the @'Choose' p@ functor, so every free layer is a
-- binary choice carrying a weight of type @p@. The synonym is deliberately
-- left partially applied so that @DistT p@ and @DistT p m@ can be used on
-- their own.
type DistT p = FreeT (Choose p)
-- | Build a distribution from a non-empty list of weighted computations.
--
-- Each entry is @(p, ma)@, where @p@ is the absolute probability of running
-- @ma@. The weights are expected to sum to 1; this is not checked. The list
-- is encoded as a right-nested chain of 'Choose' nodes. Each node stores the
-- probability of its first branch /conditional on reaching that node/, that
-- is, @p@ divided by the mass not yet claimed by earlier entries. With this
-- encoding 'decons' yields the original weights back.
--
-- The weight of the last entry is ignored: it receives whatever mass remains.
cons :: (Fractional p, Monad m) => NonEmpty (p, m a) -> DistT p m a
cons = cons' 1
  where
    -- The first argument is the probability mass still unassigned when this
    -- suffix of the list is reached. Tracking the remaining mass directly
    -- (rather than a running factor updated as @q / (1 - p)@) keeps the
    -- conditional probabilities correct for lists of four or more entries.
    cons' :: (Fractional p, Monad m) => p -> NonEmpty (p, m a) -> DistT p m a
    cons' _ ((_, ma) :| []) = lift ma
    cons' remaining ((p, ma) :| pa : pas) =
      wrap (Choose (p / remaining) (lift ma) (cons' (remaining - p) (pa :| pas)))
-- | Flatten a distribution into its weighted outcomes.
--
-- The tree is folded into a function from "mass accumulated so far" to a
-- collection of @(weight, value)@ pairs, and that function is then applied to
-- an initial mass of 1. At a 'Choose' node the first branch receives the
-- incoming mass times @p@ and the second the incoming mass times @1 - p@; the
-- two results are combined with '<|>'.
decons :: (Num p, Monad m, Alternative f) => DistT p m a -> m (f (p, a))
decons dist = fmap (\weigh -> weigh 1) (cataFreeT leaf branch dist)
  where
    -- A leaf reports the mass with which it was reached.
    leaf a mass = pure (mass, a)

    -- A choice splits the incoming mass between its two alternatives.
    branch (Choose p onFirst onSecond) mass =
      onFirst (mass * p) <|> onSecond (mass * (1 - p))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment