Last active April 10, 2020 10:31
Probabilistic programming using free selective functors, which let us statically analyse which effects are necessary.
module SelectiveProb where
import Data.Bifunctor (Bifunctor, bimap)
import Data.Bool

import Control.Selective
import Control.Selective.Free
import qualified Data.Vector as V
import System.Random.MWC
import System.Random.MWC.Distributions
-- | The primitive distributions a probabilistic program can draw from.
--
-- NOTE(review): this GADT-syntax declaration needs the GADTs extension; no
-- LANGUAGE pragma is visible at the top of this file — confirm it is enabled
-- (e.g. through the build configuration).
data Primitives a where
  -- | One element of the given list, chosen by a uniformly drawn index.
  Uniform :: [a] -> Primitives a
  -- | One outcome from a list of outcomes, each paired with its weight.
  Categorical :: [(a, Double)] -> Primitives a
  -- | A normal distribution; the two parameters are presumably mean and
  --   standard deviation (as taken by mwc-random's @normal@) — confirm.
  Normal :: Double -> Double -> Primitives Double
  -- | A choice between two outcomes governed by a probability.
  --   NOTE(review): which outcome receives that probability is decided by
  --   the interpreter; 'weight' assumes it is the first one.
  Bernoulli :: a -> a -> Double -> Primitives a
  -- | A beta distribution with two shape parameters.
  Beta :: Double -> Double -> Primitives Double
  -- | A gamma distribution; presumably (shape, scale) as taken by
  --   mwc-random's @gamma@ — confirm.
  Gamma :: Double -> Double -> Primitives Double
-- | A probabilistic program: the free selective functor over 'Primitives'.
type Dist = Select Primitives
-- | A probability, represented as a plain 'Double'.
type Prob = Double
-- | The outcome of a single coin toss.
data Coin
  = Heads
  | Tails
  deriving (Show, Eq, Bounded, Enum)
-- | Guard function used in McCarthy's conditional.
--
-- Runs the effect producing the predicate, then the effect producing the
-- value, and tags the value with the outcome of the test: 'Left' when the
-- predicate holds and 'Right' when it does not. The input itself is kept in
-- both cases, so the result is a coproduct injection @Either a a@.
grdS :: Applicative f => f (a -> Bool) -> f a -> f (Either a a)
grdS f a = tag <$> f <*> a
  where
    tag test x
      | test x = Left x
      | otherwise = Right x
-- | Pair a value with itself.
dup :: a -> (a, a)
dup x = (x, x)
-- | Apply an effectful function to the first component of an effectful
-- bifunctor value, leaving the second component unchanged.
--
-- The effect of @fab@ is sequenced before the effect of @faa@.
applyF :: (Applicative f, Bifunctor p) => f (a -> b) -> f (p a c) -> f (p b c)
applyF fab faa = bimap <$> fab <*> pure id <*> faa
-- | Turn a test result and its input into a coproduct injection:
-- 'Left' when the test passed, 'Right' otherwise.
selector :: (Bool, a) -> Either a a
selector (b, x) = bool (Right x) (Left x) b
-- | McCarthy's conditional, written @p -> f, g@: a well-known functional
-- combinator suggesting that conditionals can be reasoned about with the
-- algebra of coproducts. It closely mirrors the 'select' operator and
-- inherits a number of its properties and laws.
--
-- @condS predicate onTrue onFalse input@ tags the value of @input@ using
-- 'grdS' and hands the result to 'branch': when @predicate@ holds the value
-- goes to @onTrue@, otherwise to @onFalse@. Whether the effects of the branch
-- not taken are skipped depends on the 'Selective' instance.
condS :: Selective f => f (b -> Bool) -> f (b -> c) -> f (b -> c) -> f b -> f c
condS predicate onTrue onFalse input =
  branch (grdS predicate input) onTrue onFalse
-- | Toss a fair coin up to @n@ times, stopping at the first 'Tails'.
--
-- Each 'Heads' is followed by another toss while tosses remain; the first
-- 'Tails' ends the sequence. The result therefore has at most @n@ elements:
-- either @n@ 'Heads', or a run of 'Heads' terminated by a single 'Tails'.
-- A non-positive @n@ yields the empty list (previously a negative @n@
-- recursed without end).
prog :: Int -> Dist [Coin]
prog n
  | n <= 0 = pure []
  | otherwise =
      condS
        (pure (== Heads))
        -- Heads: prepend it to the outcome of the remaining tosses.
        (flip (:) <$> prog (n - 1))
        -- Tails: stop here.
        (pure (: []))
        -- The toss being tested; it was previously never passed to 'condS',
        -- which left the result a function instead of a 'Dist [Coin]'.
        toss
  where
    toss = liftSelect (Bernoulli Heads Tails 0.5)
-- | One turn of a simple board game.
--
-- Two dice are thrown independently. If both show the same value, the third
-- die is taken to show that value as well and the piece moves three times
-- that value. Otherwise a third die is thrown and the piece moves the sum of
-- all three faces. The third throw lives only in the "dice differ" branch of
-- 'condS', so a 'Selective' instance that skips untaken branches need not
-- throw it.
diceThrow :: Dist Int
diceThrow = condS (pure isDouble) (pure tripled) (addThird <$> die) pair
  where
    isDouble (x, y) = x == y
    tripled (x, _) = 3 * x
    addThird third (x, y) = x + y + third
    pair = (,) <$> die <*> die
-- | A six-sided die: one of the faces 1 to 6, drawn with the 'Uniform'
-- primitive.
die :: Dist Int
die = liftSelect (Uniform [1 .. 6])
-- | Prior for inferring the weight of a coin.
--
-- The coin is fair (weight 0.5) with probability 0.8; otherwise it is biased
-- and its weight is drawn from @Beta 5 1@.
--
-- NOTE(review): this relies on @Bernoulli h t p@ yielding @h@ with
-- probability @p@ — confirm the interpreter honours that convention.
weight :: Dist Prob
weight =
  ifS
    -- The Bernoulli draw is already a 'Bool'; no comparison with 'True'.
    (liftSelect (Bernoulli True False 0.8))
    (pure 0.5)
    (liftSelect (Beta 5 1))
-- | Sequence @n@ copies of the program @r@, collecting their results in
-- order. A non-positive @n@ gives the empty list.
sample :: Dist a -> Int -> Dist [a]
sample r n = traverse (const r) [1 .. n]
-- | Execute a probabilistic program in 'IO', sampling its primitives with
-- 'interpret'.
runToIO :: Dist a -> IO a
runToIO dist = runSelect interpret dist
-- | Sample a single primitive distribution in 'IO' using mwc-random.
--
-- * 'Uniform' returns the list element at a uniformly drawn index; an empty
--   list raises an 'IOError' instead of failing on an out-of-range index.
-- * 'Categorical' draws an index from the weights with 'categorical' and
--   returns the matching outcome.
-- * @'Bernoulli' h t p@ yields @h@ when 'bernoulli' returns 'True' (with
--   probability @p@) and @t@ otherwise. The previous @bool h t@ had this
--   inverted, so @Bernoulli True False 0.8@ produced 'False' 80% of the time.
-- * 'Normal', 'Beta' and 'Gamma' pass their parameters unchanged to the
--   corresponding mwc-random samplers.
--
-- The explicit signature is required: matching on GADT constructors such as
-- 'Normal' (which fixes @a ~ Double@) cannot be type-inferred.
--
-- NOTE(review): a fresh generator is created with 'createSystemRandom' for
-- every primitive drawn; consider threading a single generator through if
-- this becomes a bottleneck.
interpret :: Primitives a -> IO a
interpret prim = do
  gen <- createSystemRandom
  case prim of
    Uniform xs
      | null xs -> ioError (userError "interpret: Uniform requires a non-empty list")
      | otherwise -> do
          i <- uniformR (0, length xs - 1) gen
          pure (xs !! i)
    Categorical weighted -> do
      i <- categorical (V.fromList (map snd weighted)) gen
      pure (fst (weighted !! i))
    Normal mu sigma -> normal mu sigma gen
    -- 'bool f t b' is 't' when 'b' is True, so 'bool t h' maps the
    -- probability-p outcome of 'bernoulli' to the first value 'h'.
    Bernoulli h t p -> bool t h <$> bernoulli p gen
    Beta alpha b -> beta alpha b gen
    Gamma k theta -> gamma k theta gen
