Skip to content

Instantly share code, notes, and snippets.

@jtobin
Created October 18, 2016 08:52
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save jtobin/95573e26843cf5fa0295360d3b33d3f1 to your computer and use it in GitHub Desktop.
Save jtobin/95573e26843cf5fa0295360d3b33d3f1 to your computer and use it in GitHub Desktop.
A simple embedded probabilistic programming language
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
import Control.Monad
import Control.Monad.Free
import qualified System.Random.MWC.Probability as MWC
-- | One primitive step of a probabilistic model.  Each constructor carries
-- the parameters of a distribution together with a continuation that
-- receives the value drawn from it; 'Free' over this functor gives 'Model'.
data ModelF r
  = BernoulliF Double (Bool -> r)
    -- ^ Bernoulli draw; the 'Double' is passed to 'MWC.bernoulli' by 'toSampler'.
  | BetaF Double Double (Double -> r)
    -- ^ Beta draw; both 'Double's are passed to 'MWC.beta' by 'toSampler'.
  deriving (Functor)
-- | A probabilistic program: the free monad over the primitive steps in
-- 'ModelF'.  Kept eta-reduced so 'Model' can be used unsaturated.
type Model = Free ModelF
-- terms
-- | A single Bernoulli draw with parameter @p@, lifted into 'Model'.
bernoulli :: Double -> Model Bool
bernoulli p = liftF $ BernoulliF p id
-- | A single beta draw with parameters @a@ and @b@, lifted into 'Model'.
beta :: Double -> Double -> Model Double
beta a b = liftF $ BetaF a b id
-- | Number of 'True' outcomes among @n@ separate 'bernoulli' draws that all
-- use the same parameter @p@.
binomial :: Int -> Double -> Model Int
binomial n p = count <$> trials
  where
    trials = replicateM n (bernoulli p)
-- | A draw on the unit interval, expressed as a beta draw with both
-- parameters equal to one.
uniform :: Model Double
uniform = beta 1.0 1.0
-- | Draw a probability from @beta a b@, then feed it to @binomial n@.
betaBinomial :: Int -> Double -> Double -> Model Int
betaBinomial n a b = beta a b >>= binomial n
-- simulation
-- | Interpret a 'Model' as an @mwc-probability@ sampler: every primitive
-- step is replaced by the matching 'MWC.Prob' distribution, and its result
-- is fed to the step's continuation.
toSampler :: Model a -> MWC.Prob IO a
toSampler = iterM step
  where
    step (BernoulliF p k) = k =<< MWC.bernoulli p
    step (BetaF a b k)    = k =<< MWC.beta a b
-- | Run a 'Model' once in 'IO', using a generator obtained through
-- 'MWC.withSystemRandom'.
simulate :: Model a -> IO a
simulate model =
  MWC.withSystemRandom (MWC.asGenIO (MWC.sample sampler))
  where
    sampler = toSampler model
-- conditioning
-- | Rejection sampling by exact matching.  Each attempt draws a candidate
-- parameter from @proposal@, runs @model@ on it once per observation, and
-- returns the candidate if the simulated list equals @observed@
-- element-for-element; otherwise it starts a new attempt.  There is no cap
-- on the number of attempts.
invert :: (Monad m, Eq b) => m a -> (a -> m b) -> [b] -> m a
invert proposal model observed = attempt
  where
    n = length observed
    attempt =
      proposal >>= \candidate ->
        replicateM n (model candidate) >>= \simulated ->
          if simulated == observed
            then return candidate
            else attempt
-- | Rejection sampling on a summary of the data.  Like 'invert', but a
-- candidate parameter is accepted when @assister@ applied to the simulated
-- list equals @assister@ applied to @observed@, rather than requiring the
-- lists themselves to be equal.
--
-- Each attempt draws a candidate from @proposal@ and runs @model@ on it once
-- per observation.  There is no cap on the number of attempts.
invertWithAssistance
  :: (Monad m, Eq c) => ([a] -> c) -> m b -> (b -> m a) -> [a] -> m b
invertWithAssistance assister proposal model observed = loop
  where
    len = length observed
    -- The observed summary is the same on every attempt, so compute it once
    -- here rather than re-evaluating it inside 'loop' after each rejection.
    target = assister observed
    loop = do
      parameters <- proposal
      generated <- replicateM len (model parameters)
      if assister generated == target
        then return parameters
        else loop
-- example
-- | Example: exact-match rejection with a 'uniform' proposal and a
-- 'bernoulli' model, conditioned on four observed outcomes.
posterior :: Model Double
posterior = invert uniform bernoulli flips
  where
    flips = [True, True, False, True]
-- | Example: rejection with a 'uniform' proposal and a 'bernoulli' model,
-- conditioned on twelve observed outcomes.  Only the number of 'True' values
-- (via 'count') has to match, so every exact-sequence match is also
-- accepted, along with other sequences of the same count.
posterior0 :: Model Double
posterior0 = invertWithAssistance count uniform bernoulli flips
  where
    flips =
      [ True, True, True, False, True, True
      , False, True, True, True, True, False
      ]
-- | How many elements of the list are 'True'.
count :: [Bool] -> Int
count flags = length [() | True <- flags]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment