Skip to content

Instantly share code, notes, and snippets.

@sjoerdvisscher
Last active June 6, 2020 11:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sjoerdvisscher/9e605825bb4489ff6c6ebe99af9e2e47 to your computer and use it in GitHub Desktop.
Save sjoerdvisscher/9e605825bb4489ff6c6ebe99af9e2e47 to your computer and use it in GitHub Desktop.
A distribution monad transformer
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RankNTypes #-}
module Control.Monad.Trans.DistT where
import Data.List.NonEmpty (NonEmpty(..))
import Control.Monad.Trans.Class
import Control.Monad.Trans.Free
import Control.Arrow (Arrow(first))
import Control.Applicative (liftA2, Alternative(..))
-- `cataMFreeT` and `cataFreeT` should really be in Control.Monad.Trans.Free
-- | Monadic catamorphism for 'FreeT'.
--
-- Runs the effect layer of each node and collapses the tree bottom-up:
-- @onPure@ handles a 'Pure' leaf, and @onFree@ receives a functor layer
-- whose children have already been mapped to @m r@ computations (it is up
-- to @onFree@ whether and in which order to run them).
cataMFreeT :: (Monad m, Functor f) => (a -> m r) -> (f (m r) -> m r) -> FreeT f m a -> m r
cataMFreeT onPure onFree = collapse
  where
    -- Unwrap one layer of effects, then dispatch on the node it produced.
    collapse t = runFreeT t >>= step

    step (Pure a) = onPure a
    step (Free layer) = onFree (fmap collapse layer)
-- | Catamorphism for 'FreeT' with a pure algebra.
--
-- Like 'cataMFreeT', but each layer's child computations are sequenced
-- (in the order given by @f@'s 'Traversable' instance) before the pure
-- algebra @onFree@ is applied to the collected results.
cataFreeT :: (Monad m, Traversable f) => (a -> r) -> (f r -> r) -> FreeT f m a -> m r
cataFreeT onPure onFree =
  cataMFreeT (return . onPure) (\layer -> onFree <$> sequenceA layer)
-- | A binary probabilistic choice: @Choose p l r@ takes @l@ with
-- probability @p@ and @r@ with probability @1 - p@.
--
-- The derived 'Functor', 'Foldable' and 'Traversable' instances only touch
-- the two alternatives, never the probability.
data Choose p a
  = Choose
      p -- ^ probability of taking the first alternative
      a -- ^ first alternative
      a -- ^ second alternative
  deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-- | The distribution monad transformer: the free monad transformer over
-- binary probabilistic choices ('Choose') whose probabilities have type
-- @prob@.
--
-- Deliberately has no @m@ / @a@ parameters, so that @DistT prob@ is itself a
-- saturated synonym of kind @(Type -> Type) -> Type -> Type@ and can be used
-- wherever a monad transformer is expected.
type DistT prob = FreeT (Choose prob)
-- | Build a distribution from a non-empty list of weighted actions.
--
-- Each weight is the /absolute/ probability of its action. The weight of
-- the final element is never read: that element receives whatever mass the
-- earlier elements leave over. Weights are not validated; they are expected
-- to be non-negative and to sum to at most 1 (excluding the last).
--
-- The list becomes a right-nested chain of binary 'Choose' nodes. Each node
-- stores the probability of taking its element /given/ that every earlier
-- element was rejected, i.e. the element's weight divided by the probability
-- mass still remaining at that point. Multiplying along a path (as 'decons'
-- does) therefore recovers the original absolute weights.
--
-- NOTE(review): if the earlier weights exhaust all the mass (remaining mass
-- 0) the conditional probability divides by zero. This is an error for exact
-- types such as 'Rational' and NaN/Infinity for floating types.
cons :: (Fractional p, Monad m) => NonEmpty (p, m a) -> DistT p m a
cons = go 1
  where
    -- @remaining@ is the probability mass not yet claimed by earlier
    -- elements. The previous version rescaled by @1 / (1 - p)@ using the
    -- absolute weight @p@, which is only correct for the first step and
    -- skewed the distribution for lists of four or more elements.
    go :: (Fractional p, Monad m) => p -> NonEmpty (p, m a) -> DistT p m a
    go _ ((_, ma) :| []) = lift ma
    go remaining ((p, ma) :| next : rest) =
      wrap (Choose (p / remaining) (lift ma) (go (remaining - p) (next :| rest)))
-- | Flatten a distribution into its outcomes, running the effects in @m@.
--
-- Each leaf value is paired with the product of the branch probabilities
-- along its path (left branch @p@, right branch @1 - p@), starting from 1.
-- Outcomes are combined with '<|>', left alternatives before right ones.
decons :: (Num p, Monad m, Alternative f) => DistT p m a -> m (f (p, a))
decons dist = startWithFullMass <$> cataFreeT leaf node dist
  where
    -- The fold yields a function of the mass reaching the root; feed it 1.
    startWithFullMass k = k 1

    leaf a mass = pure (mass, a)

    node (Choose p left right) mass =
      left (mass * p) <|> right (mass * (1 - p))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment