Skip to content

Instantly share code, notes, and snippets.

@bolt12
Last active April 10, 2020 10:31
Show Gist options
  • Save bolt12/07e7a397a2e1dd460a7b05ee242f5f33 to your computer and use it in GitHub Desktop.
Save bolt12/07e7a397a2e1dd460a7b05ee242f5f33 to your computer and use it in GitHub Desktop.
Probabilistic programming using free selective functors lets us statically analyse which effects are necessary.
{-# LANGUAGE GADTs #-}
module SelectiveProb where
import Control.Selective.Free
import Control.Selective
import System.Random.MWC
import System.Random.MWC.Distributions
import Data.Bool
import Data.Bifunctor (bimap)
import qualified Data.Vector as V
-- | The primitive distributions that a 'Dist' program can draw from.
data Primitives v where
  -- | Picks one element of the list, each index being equally likely.
  Uniform :: [v] -> Primitives v
  -- | Picks one of the listed outcomes according to its 'Double' weight.
  Categorical :: [(v, Double)] -> Primitives v
  -- | A normal distribution; the two parameters are handed to
  -- mwc-random's @normal@ (mean, then standard deviation).
  Normal :: Double -> Double -> Primitives Double
  -- | A choice between two outcomes governed by the probability @p@.
  -- 'weight' relies on the first outcome being drawn with probability @p@.
  Bernoulli :: v -> v -> Double -> Primitives v
  -- | A beta distribution; both parameters are handed to mwc-random's @beta@.
  Beta :: Double -> Double -> Primitives Double
  -- | A gamma distribution; both parameters are handed to mwc-random's
  -- @gamma@ (shape, then scale).
  Gamma :: Double -> Double -> Primitives Double
-- | A probabilistic program: the free selective functor over 'Primitives'.
type Dist = Select Primitives

-- | A probability, represented as a plain 'Double'.
type Prob = Double
-- | The two faces of a coin.
data Coin = Heads | Tails deriving (Bounded, Enum, Eq, Show)
-- | Guard used by McCarthy's conditional.
--
-- Runs the predicate effect, then the value effect, and tags the value with
-- the outcome of the test: @Left@ when the predicate holds and @Right@ when
-- it does not. The tested value is kept in both cases, so the result is
-- expressed through the coproduct injections without losing the input.
grdS :: Applicative f => f (a -> Bool) -> f a -> f (Either a a)
grdS f a = tag <$> f <*> a
  where
    tag test v
      | test v = Left v
      | otherwise = Right v
-- | McCarthy's conditional, written @p -> f, g@.
--
-- Tests @p@ on the value produced by the last argument. When the test holds,
-- the value is passed to @f@; otherwise it is passed to @g@. It is simply
-- 'branch' applied to the tagged result of 'grdS'. This mirrors the
-- coproduct-based algebra of conditionals, and is therefore very close in
-- spirit to the 'select' operator itself.
condS :: Selective f => f (b -> Bool) -> f (b -> c) -> f (b -> c) -> f b -> f c
condS p f g x = branch (grdS p x) f g
-- | Toss a fair coin repeatedly, stopping at the first 'Tails' or after at
-- most @n@ tosses, whichever comes first.
--
-- The result is a run of 'Heads', ending with 'Tails' if one was thrown.
-- Each further toss lives in the 'Heads' branch of 'condS', so it is only
-- needed when the previous toss came up 'Heads'.
--
-- Non-positive @n@ yields @[]@ without tossing. Previously a negative @n@
-- recursed without bound and built an infinite 'Dist'.
prog :: Int -> Dist [Coin]
prog n
  | n <= 0 = pure []
  | otherwise =
      condS (pure (== Heads))
            (flip (:) <$> prog (n - 1)) -- Heads: prepend it and keep tossing
            (pure (: []))               -- Tails: stop here
            toss
  where
    toss = liftSelect (Bernoulli Heads Tails 0.5)
-- | One turn of a simple board game.
--
-- Two dice are thrown. If they show the same face, the third die is taken
-- to show that face as well. Otherwise the third die is actually thrown.
-- In both cases the piece moves as many squares as the sum of all three dice.
diceThrow :: Dist Int
diceThrow = condS sameFace tripled addThird firstPair
  where
    sameFace = pure (uncurry (==))
    tripled = pure (\(x, _) -> 3 * x)
    -- Speculative throw: only needed when the first two dice differ.
    addThird = (\z (x, y) -> x + y + z) <$> die
    -- The first two dice are independent, so they are thrown applicatively.
    firstPair = (,) <$> die <*> die
-- | A fair six-sided die: uniform over the faces 1 to 6.
die :: Dist Int
die = liftSelect (Uniform [1 .. 6])
-- | Model for inferring the weight of a coin.
--
-- The coin is fair (weight 0.5) with probability 0.8. Otherwise it is
-- biased, and its weight is drawn from @Beta 5 1@.
--
-- This relies on @Bernoulli x y p@ yielding its first outcome @x@ with
-- probability @p@. The Bernoulli draw already produces a 'Bool', so it feeds
-- 'ifS' directly; no @(== True)@ comparison is needed.
weight :: Dist Prob
weight =
  ifS (liftSelect (Bernoulli True False 0.8))
      (pure 0.5)
      (liftSelect (Beta 5 1))
-- | Repeat a distribution @n@ times and collect the outcomes in order.
--
-- Each repetition carries its own copy of the distribution's effects, so
-- each one draws separately when interpreted. Non-positive @n@ gives
-- @pure []@.
sample :: Dist a -> Int -> Dist [a]
sample r n = traverse (const r) [1 .. n]
-- | Run a probabilistic program in 'IO', sampling every primitive it
-- executes.
--
-- One system-seeded generator is created per run and shared by all the
-- primitives. Previously the generator was reseeded from system entropy on
-- every draw.
--
-- @Bernoulli x y p@ yields @x@ with probability @p@ and @y@ otherwise. This
-- is the reading 'weight' depends on; the previous mapping was inverted.
--
-- Calls 'error' when asked to sample a 'Uniform' over an empty list.
runToIO :: Dist a -> IO a
runToIO d = do
  gen <- createSystemRandom
  runSelect (interpret gen) d
  where
    -- A signature is needed so the GADT match is checked against a rigid type.
    interpret :: GenIO -> Primitives x -> IO x
    interpret gen (Uniform l)
      | null l = error "runToIO: Uniform over an empty list"
      | otherwise = (l !!) <$> uniformR (0, length l - 1) gen
    interpret gen (Categorical l) = do
      i <- categorical (V.fromList (map snd l)) gen
      return (fst (l !! i))
    interpret gen (Normal mu sigma) = normal mu sigma gen
    -- 'bernoulli' returns True with probability p, so True must select the
    -- first outcome.
    interpret gen (Bernoulli h t p) = bool t h <$> bernoulli p gen
    interpret gen (Beta a b) = beta a b gen
    interpret gen (Gamma k theta) = gamma k theta gen
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment