Created
October 18, 2016 08:52
-
-
Save jtobin/95573e26843cf5fa0295360d3b33d3f1 to your computer and use it in GitHub Desktop.
A simple embedded probabilistic programming language
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE LambdaCase #-} | |
import Control.Monad | |
import Control.Monad.Free | |
import qualified System.Random.MWC.Probability as MWC | |
-- | Base functor for the probabilistic DSL.
--
-- Each constructor holds the parameters of one primitive distribution,
-- plus a continuation that receives the value drawn from it.
data ModelF r
  = BernoulliF Double (Bool -> r)
    -- ^ Bernoulli draw with success probability given by the 'Double'.
  | BetaF Double Double (Double -> r)
    -- ^ Beta draw with the two given shape parameters.
  deriving Functor

-- | A probabilistic program: the free monad over 'ModelF'.
type Model = Free ModelF
-- terms | |
-- | Primitive term: a single Bernoulli trial with success probability @prob@.
bernoulli :: Double -> Model Bool
bernoulli prob = liftF $ BernoulliF prob id
-- | Primitive term: a Beta-distributed draw with shape parameters
-- @alpha@ and @betaParam@.
beta :: Double -> Double -> Model Double
beta alpha betaParam = liftF $ BetaF alpha betaParam id
-- | Number of successes in @n@ independent Bernoulli trials that share the
-- success probability @p@.
binomial :: Int -> Double -> Model Int
binomial n p = count <$> replicateM n (bernoulli p)
-- | Uniform draw on the unit interval, expressed as the Beta(1, 1) distribution.
uniform :: Model Double
uniform = beta 1.0 1.0
-- | Beta-binomial compound: draw a success probability from Beta(@a@, @b@),
-- then count successes in @n@ trials that use it.
betaBinomial :: Int -> Double -> Double -> Model Int
betaBinomial n a b = beta a b >>= binomial n
-- simulation | |
-- | Interpret a 'Model' as an mwc-probability sampler.
--
-- Each primitive is replaced by the matching MWC distribution, and the
-- sampled value is fed to the primitive's continuation.
toSampler :: Model a -> MWC.Prob IO a
toSampler = iterM interpret
  where
    interpret (BernoulliF p next) = MWC.bernoulli p >>= next
    interpret (BetaF a b next)    = MWC.beta a b >>= next
-- | Draw a single sample from a 'Model', using a generator seeded from
-- system randomness.
--
-- NOTE(review): 'MWC.withSystemRandom' is deprecated in newer mwc-random
-- releases (in favour of e.g. @createSystemRandom@). Confirm against the
-- pinned library version before changing it.
simulate :: Model a -> IO a
simulate = MWC.withSystemRandom . MWC.asGenIO . MWC.sample . toSampler
-- conditioning | |
-- | Condition a generative model on observed data by exact-match rejection.
--
-- Each attempt draws a candidate from @proposal@ and simulates as many
-- values from @model candidate@ as there are in @observed@. The candidate is
-- returned only if the simulated list equals @observed@ exactly; otherwise
-- a new attempt starts. If no candidate can ever reproduce the data, this
-- never terminates.
invert :: (Monad m, Eq b) => m a -> (a -> m b) -> [b] -> m a
invert proposal model observed = attempt
  where
    trials = length observed
    attempt = do
      candidate <- proposal
      synthetic <- replicateM trials (model candidate)
      accept candidate synthetic
    accept candidate synthetic
      | synthetic == observed = return candidate
      | otherwise             = attempt
-- | Condition a generative model by rejection, comparing summaries rather
-- than raw data.
--
-- Each attempt draws @parameters@ from @proposal@ and simulates
-- @length observed@ values from @model parameters@. The parameters are
-- accepted when @assister@ maps the simulated list and @observed@ to equal
-- summaries; otherwise a new attempt starts. If the summaries can never
-- match, this never terminates.
invertWithAssistance
  :: (Monad m, Eq c) => ([a] -> c) -> m b -> (b -> m a) -> [a] -> m b
invertWithAssistance assister proposal model observed = loop
  where
    len = length observed
    -- The observed summary is the same on every attempt. Bind it once here
    -- so it is shared, instead of recomputing it inside the loop.
    target = assister observed
    loop = do
      parameters <- proposal
      generated <- replicateM len (model parameters)
      if assister generated == target
        then return parameters
        else loop
-- example | |
-- | Posterior over a coin's success probability, given a uniform prior and
-- four observed flips. Obtained by exact-match rejection via 'invert'.
posterior :: Model Double
posterior = invert uniform bernoulli flips
  where
    flips = [True, True, False, True]
-- | Posterior over a coin's success probability, given a uniform prior and
-- twelve observed flips.
--
-- Rejection uses 'invertWithAssistance' with 'count'. A simulated dataset is
-- accepted when it has the same number of successes as the observations,
-- even if the order differs.
posterior0 :: Model Double
posterior0 = invertWithAssistance count uniform bernoulli flips
  where
    flips =
      [ True, True, True, False, True, True
      , False, True, True, True, True, False
      ]
-- | Number of 'True' values in a list.
count :: [Bool] -> Int
count flags = length [() | True <- flags]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment