Last active
April 10, 2020 10:31
-
-
Save bolt12/07e7a397a2e1dd460a7b05ee242f5f33 to your computer and use it in GitHub Desktop.
Probabilistic programming using free selective functors lets us statically analyse which effects are necessary.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE GADTs #-} | |
module SelectiveProb where | |
import Control.Selective.Free | |
import Control.Selective | |
import System.Random.MWC | |
import System.Random.MWC.Distributions | |
import Data.Bool | |
import Data.Bifunctor (bimap) | |
import qualified Data.Vector as V | |
-- | The primitive random choices a 'Dist' program is built from.
data Primitives a where
  -- | Pick one element of the (presumably non-empty) list; 'runToIO'
  -- chooses each list position with equal weight.
  Uniform :: [a] -> Primitives a
  -- | Pick one value from weighted alternatives; 'runToIO' hands the
  -- weights to mwc-random's @categorical@.
  Categorical :: [(a, Double)] -> Primitives a
  -- | Normal distribution; 'runToIO' passes the two parameters to
  -- mwc-random's @normal@ (mean, then standard deviation).
  Normal :: Double -> Double -> Primitives Double
  -- | Choice between two outcomes governed by a probability. Which of the
  -- two outcomes the probability refers to is fixed by the interpreter.
  Bernoulli :: a -> a -> Double -> Primitives a
  -- | Beta distribution; 'runToIO' passes the two parameters to
  -- mwc-random's @beta@.
  Beta :: Double -> Double -> Primitives Double
  -- | Gamma distribution; 'runToIO' passes the two parameters to
  -- mwc-random's @gamma@ (shape, then scale).
  Gamma :: Double -> Double -> Primitives Double
-- | A probabilistic program: the free selective functor over the
-- 'Primitives' random choices.
type Dist = Select Primitives
-- | A probability, such as the weight of a coin, stored as a plain 'Double'.
type Prob = Double
-- | The outcome of a single coin toss.
data Coin
  = Heads
  | Tails
  deriving (Eq, Show, Enum, Bounded)
-- | Guard used to build McCarthy's conditional ('condS').
--
-- Runs the predicate effect and then the value effect, and reports the
-- result of testing the value as a coproduct injection: 'Left' when the
-- predicate holds, 'Right' when it does not. The value itself is kept in
-- both cases, so no information about the input is lost.
grdS :: Applicative f => f (a -> Bool) -> f a -> f (Either a a)
grdS predicate value = classify <$> predicate <*> value
  where
    classify holds x
      | holds x   = Left x
      | otherwise = Right x
-- | McCarthy's conditional, written @p -> f, g@: test the input with @p@,
-- then apply @f@ when the test holds and @g@ when it does not.
--
-- The input is first tagged with 'Left' or 'Right' by 'grdS', and the
-- choice between @f@ and @g@ is then made by the selective 'branch'
-- combinator rather than by a monadic bind. Reasoning about the
-- conditional therefore stays within the algebra of coproducts, and it
-- inherits the laws of 'select'.
condS :: Selective f => f (b -> Bool) -> f (b -> c) -> f (b -> c) -> f b -> f c
condS p f g x = branch (grdS p x) f g
-- | Toss a fair coin at most @n@ times, stopping at the first 'Tails'.
--
-- The outcomes are listed in the order they were tossed: zero or more
-- 'Heads', followed by a single 'Tails' unless all @n@ tosses came up
-- 'Heads'. A non-positive @n@ yields the empty list without tossing.
-- Previously only @0@ was a base case, so a negative @n@ produced an
-- infinitely deep program.
--
-- The continuation is chosen with 'condS', so the remaining tosses are
-- conditional effects of the program.
prog :: Int -> Dist [Coin]
prog n
  | n <= 0    = pure []
  | otherwise =
      condS (pure (== Heads))
            (flip (:) <$> prog (n - 1)) -- Heads: prepend it, keep tossing
            (pure (: []))               -- Tails: stop here
            toss
  where
    toss = liftSelect (Bernoulli Heads Tails 0.5)
-- | One turn of a simple board game.
--
-- Two dice are thrown together. If they show the same face, the third die
-- is taken to show that face as well; otherwise the third die is actually
-- thrown. The result is the sum of all three faces, which is the number of
-- squares a piece moves.
diceThrow :: Dist Int
diceThrow = condS isDouble tripled addThird pair
  where
    -- The first two dice are independent, so they are combined
    -- applicatively (a parallel throw).
    pair = (,) <$> die <*> die
    isDouble = pure (uncurry (==))
    tripled = pure (\(x, _) -> 3 * x)
    -- Only needed when the pair differs (a speculative throw).
    addThird = (\z (x, y) -> x + y + z) <$> die
-- | A six-sided die: one of the faces 1 to 6, each listed once in a
-- 'Uniform' choice.
die :: Dist Int
die = liftSelect (Uniform faces)
  where
    faces = [1 .. 6]
-- | Prior over the weight of a coin, for use when inferring that weight.
--
-- The coin is intended to be fair (weight 0.5) with probability 0.8 and
-- biased with probability 0.2. In the biased case, its weight is drawn from
-- @Beta 5 1@. This reading assumes that @Bernoulli h t p@ yields @h@ with
-- probability @p@.
weight :: Dist Prob
weight =
  ifS (liftSelect (Bernoulli True False 0.8)) -- already a Bool; no (== True)
      (pure 0.5)
      (liftSelect (Beta 5 1))
-- | Run the program @r@ @n@ times and collect the results in order.
-- A non-positive @n@ gives the empty list.
sample :: Dist a -> Int -> Dist [a]
sample r n = traverse (const r) [1 .. n]
-- | Sample a 'Dist' program in 'IO'.
--
-- One system-seeded generator is created per run and shared by every
-- primitive draw. This avoids seeding a fresh generator for each draw.
--
-- @Bernoulli h t p@ yields @h@ with probability @p@ and @t@ otherwise.
-- This is the reading 'weight' relies on (@Bernoulli True False 0.8@ for
-- "fair with probability 0.8"). The previous @bool h t@ returned @t@ with
-- probability @p@, inverting it.
--
-- Throws an 'IOError' when asked to sample 'Uniform' over an empty list.
runToIO :: Dist a -> IO a
runToIO d = do
  gen <- createSystemRandom
  runSelect (interpret gen) d
  where
    -- The signature is required: GADT pattern matching needs a rigid
    -- result type, and runSelect needs a polymorphic interpreter.
    interpret :: GenIO -> Primitives x -> IO x
    interpret gen (Uniform l)
      | null l    = ioError (userError "runToIO: Uniform over an empty list")
      | otherwise = do
          i <- uniformR (0, length l - 1) gen
          pure (l !! i)
    interpret gen (Categorical l) = do
      i <- categorical (V.fromList (map snd l)) gen
      pure (fst (l !! i))
    interpret gen (Normal x y) = normal x y gen
    -- bernoulli p is True with probability p, so True selects h.
    interpret gen (Bernoulli h t p) = bool t h <$> bernoulli p gen
    interpret gen (Beta x y) = beta x y gen
    interpret gen (Gamma x y) = gamma x y gen
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment