Skip to content

Instantly share code, notes, and snippets.

@jtobin
Created October 27, 2016 00:57
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save jtobin/497e688359c17d1fdf9215868a300b55 to your computer and use it in GitHub Desktop.
Save jtobin/497e688359c17d1fdf9215868a300b55 to your computer and use it in GitHub Desktop.
Probabilistic programming using comonads.
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
import Control.Comonad
import Control.Comonad.Cofree
import Control.Monad
import Control.Monad.ST
import Control.Monad.Free
import Data.Bits
import Data.Dynamic
import Data.Maybe
import Data.Void
import qualified Data.Vector as V
import Data.Word
import qualified System.Random.MWC as MWC
import System.Random.MWC.Probability (Prob)
import qualified System.Random.MWC.Probability as Prob
-- | Base functor of the modelling language.
--
-- Every sampling instruction stores the parameters of its distribution and a
-- continuation that consumes the sampled value.  'DiracF' stores a value of
-- type @a@ and has no continuation, so it ends a program.
data ModelF a r
  = BernoulliF Double (Bool -> r)
  | BetaF Double Double (Double -> r)
  | NormalF Double Double (Double -> r)
  | DiracF a
  deriving (Functor)
-- | A program: the free monad over 'ModelF', whose 'DiracF' terms carry an @a@.
type Program a = Free (ModelF a)
-- | A program that is polymorphic in the type carried by 'DiracF'.
type Model b = forall a. Program a b
-- | A program returning 'Void'; 'executeGeneric' relies on this to treat the
-- 'Pure' case as impossible.
type Terminating a = Program a Void
-- | The instruction tree of a program with a 'Node' attached to every
-- instruction.
type Execution a = Cofree (ModelF a) Node
-- | Annotation attached to every instruction of an 'Execution'.
data Node = Node
  { nodeCost :: Double
    -- ^ Log-density of 'nodeValue' under the instruction's distribution.
  , nodeValue :: Dynamic
    -- ^ Value currently held by the node, stored type-erased.
  , nodeSeed :: MWC.Seed
    -- ^ Generator state from which the node's next draw is made.
  , nodeHistory :: [Dynamic]
    -- ^ Previously recorded values, most recent first.
  } deriving (Show)
-- primitive terms ------------------------------------------------------------
-- | Instruction that draws from a beta distribution with parameters @a@, @b@.
beta :: Double -> Double -> Program a Double
beta a b = Free (BetaF a b Pure)

-- | Instruction that draws a 'Bool' that is 'True' with probability @p@.
bernoulli :: Double -> Program a Bool
bernoulli p = Free (BernoulliF p Pure)

-- | Instruction that draws from a normal distribution with mean @m@ and
-- standard deviation @s@.
normal :: Double -> Double -> Program a Double
normal m s = Free (NormalF m s Pure)

-- | Terminal instruction carrying the value @x@; it has no continuation, so
-- the result type @b@ is unconstrained.
dirac :: a -> Program a b
dirac x = Free (DiracF x)
-- sampling -------------------------------------------------------------------
-- | Interpret a program as a sampler by replacing each instruction with the
-- corresponding "System.Random.MWC.Probability" distribution and feeding the
-- draw to the instruction's continuation.
toSampler :: Program a a -> Prob IO a
toSampler (Pure result) = return result
toSampler (Free instruction) = case instruction of
  BernoulliF p k -> Prob.bernoulli p >>= toSampler . k
  BetaF a b k -> Prob.beta a b >>= toSampler . k
  NormalF m s k -> Prob.normal m s >>= toSampler . k
  DiracF x -> return x
-- | Run a sampler once in 'IO' using a generator obtained from
-- 'MWC.withSystemRandom'.
simulate :: Prob IO a -> IO a
simulate model = MWC.withSystemRandom sampler
  where
    sampler = MWC.asGenIO (Prob.sample model)
-- densities ------------------------------------------------------------------
-- | Log-probability of the outcome @x@ under a Bernoulli distribution with
-- success probability @p@.
--
-- Returns negative infinity when @p@ lies outside @[0, 1]@.  Only the log
-- term belonging to the observed outcome is evaluated: the weighted form
-- @b * log p + (1 - b) * log (1 - p)@ multiplies zero by @log 0@ at the
-- boundaries and would yield NaN for the certain outcomes (@p = 1@ with
-- 'True', @p = 0@ with 'False').
logDensityBernoulli :: Double -> Bool -> Double
logDensityBernoulli p x
  | p < 0 || p > 1 = negate (1 / 0)
  | x = log p
  | otherwise = log (1 - p)
-- | Unnormalised log-density of @x@ under a beta distribution with shape
-- parameters @a@ and @b@ (the log of the beta function is not subtracted).
--
-- Returns negative infinity when @x@ is outside the open interval @(0, 1)@
-- or when either shape parameter is not strictly positive; a zero shape
-- parameter does not define a beta distribution and must not receive a
-- finite score.
logDensityBeta :: Double -> Double -> Double -> Double
logDensityBeta a b x
  | x <= 0 || x >= 1 = negativeInfinity
  | a <= 0 || b <= 0 = negativeInfinity
  | otherwise = (a - 1) * log x + (b - 1) * log (1 - x)
  where
    negativeInfinity = negate (1 / 0)
-- | Log-density of @x@ under a normal distribution with mean @m@ and standard
-- deviation @s@, omitting the constant @-log (sqrt (2 * pi))@ term.
--
-- Returns negative infinity when @s@ is not positive.
logDensityNormal :: Double -> Double -> Double -> Double
logDensityNormal m s x
  | s <= 0 = negativeInfinity
  | otherwise = negate (log s) - deviation ^ two / (2 * s ^ two)
  where
    negativeInfinity = log 0
    deviation = x - m
    two = 2 :: Int
-- | Log-density of @x@ under a point mass at @a@: zero when the two are
-- equal, negative infinity otherwise.
logDensityDirac :: Eq a => a -> a -> Double
logDensityDirac a x = if a == x then 0 else negativeInfinity
  where
    negativeInfinity = negate (1 / 0)
-- execution: initializing ----------------------------------------------------
-- | Annotate a terminating program using a fixed pair of seed words, so the
-- resulting execution is the same on every call.
execute :: Typeable a => Terminating a -> Execution a
execute = executeGeneric defaultSeeds
  where
    defaultSeeds = (42, 108512)
-- | Annotate every instruction of a terminating program with an initial
-- 'Node'.
--
-- The seed pair is advanced once per level with 'xorshift'; the word it
-- emits seeds the generator of the node at that level, and the advanced
-- pair is handed to everything below the instruction's continuation.
executeGeneric
  :: Typeable a => (Word32, Word32) -> Terminating a -> Execution a
executeGeneric = go
  where
    -- A 'Terminating' program returns 'Void', so 'Pure' cannot occur.
    go _ (Pure impossible) = absurd impossible
    go seeds (Free instruction) = node :< fmap (go seeds') instruction
      where
        (seeds', word) = xorshift seeds
        node = initialize (MWC.toSeed (V.singleton word)) instruction
-- | Draw once from a distribution using the generator state stored in a seed.
--
-- Returns the draw wrapped in a 'Dynamic' together with the generator state
-- saved after the draw.  Calls 'error' if that state equals the seed that
-- was passed in, i.e. if sampling did not advance the generator.
samplePurely
  :: Typeable a => Prob (ST s) a -> Prob.Seed -> ST s (Dynamic, Prob.Seed)
samplePurely dist seed = do
  gen <- MWC.restore seed
  drawn <- MWC.asGenST (Prob.sample dist) gen
  seed' <- MWC.save gen
  when (seed == seed') (error "a generator failed to step!")
  return (toDyn drawn, seed')
-- | Build the initial 'Node' for one instruction.
--
-- A sampling instruction gets a value drawn from its distribution with the
-- given seed, that value's log-density as its cost, the generator state left
-- after the draw, and an empty history.  A 'DiracF' instruction keeps the
-- seed untouched, holds its carried value and has cost zero.
initialize :: Typeable a => MWC.Seed -> ModelF a b -> Node
initialize seed instruction = case instruction of
  BernoulliF p _ -> drawFrom (Prob.bernoulli p) (logDensityBernoulli p)
  BetaF a b _ -> drawFrom (Prob.beta a b) (logDensityBeta a b)
  NormalF m s _ -> drawFrom (Prob.normal m s) (logDensityNormal m s)
  DiracF a ->
    Node
      { nodeCost = 0
      , nodeValue = toDyn a
      , nodeSeed = seed
      , nodeHistory = mempty
      }
  where
    -- Sample from the distribution and score the draw with its log-density.
    drawFrom :: Typeable v => (forall s. Prob (ST s) v) -> (v -> Double) -> Node
    drawFrom dist density = runST $ do
      (value, seed') <- samplePurely dist seed
      return
        Node
          { nodeCost = density (unsafeFromDyn value)
          , nodeValue = value
          , nodeSeed = seed'
          , nodeHistory = mempty
          }
-- execution: scoring and running ---------------------------------------------
-- | Sum of the node costs along the path selected by the values currently
-- stored in the nodes.  The walk stops at the first 'DiracF' instruction,
-- whose own cost is not added.
score :: Execution a -> Double
score = go 0
  where
    go !total (node@Node {} :< instruction) =
      case next of
        Nothing -> total
        Just rest -> go (total + nodeCost node) rest
      where
        -- The node's stored value, at whichever type the continuation needs.
        input :: Typeable v => v
        input = unsafeFromDyn (nodeValue node)
        next = case instruction of
          BernoulliF _ k -> Just (k input)
          BetaF _ _ k -> Just (k input)
          NormalF _ _ k -> Just (k input)
          DiracF _ -> Nothing
-- | Number of nodes on the path selected by the values currently stored in
-- the nodes, counting the terminating 'DiracF' node.
depth :: Execution a -> Int
depth = go 0
  where
    go !seen (node@Node {} :< instruction) = maybe visited (go visited) next
      where
        visited = succ seen
        -- The node's stored value, at whichever type the continuation needs.
        input :: Typeable v => v
        input = unsafeFromDyn (nodeValue node)
        next = case instruction of
          BernoulliF _ k -> Just (k input)
          BetaF _ _ k -> Just (k input)
          NormalF _ _ k -> Just (k input)
          DiracF _ -> Nothing
-- | Advance one instruction, feeding the root node's stored value to the
-- root instruction's continuation.
step :: Typeable a => Execution a -> Execution a
step prog@(Node {nodeValue = current} :< _) = stepWithInput current prog
-- | Advance one instruction, feeding the supplied value to the root
-- instruction's continuation.  An execution sitting at a 'DiracF'
-- instruction is returned unchanged.
stepWithInput :: Typeable a => Dynamic -> Execution a -> Execution a
stepWithInput value prog@(_ :< instruction) = case instruction of
  DiracF _ -> prog
  BernoulliF _ k -> k input
  BetaF _ _ k -> k input
  NormalF _ _ k -> k input
  where
    -- The supplied value, at whichever type the continuation needs.
    input :: Typeable v => v
    input = unsafeFromDyn value
-- | Follow the stored node values with 'step' until a 'DiracF' instruction is
-- reached and return the value it carries.
run :: Typeable a => Execution a -> a
run prog
  | DiracF result <- unwrap prog = result
  | otherwise = run (step prog)
-- | Like 'run', but the first step uses the supplied value in place of the
-- one stored in the root node.
runWithInput :: Typeable a => Dynamic -> Execution a -> a
runWithInput value prog = run (stepWithInput value prog)
-- | Apply 'stepGenerator' at every position of the tree, advancing the
-- generator state stored in each node.
stepGenerators :: Functor f => Cofree f Node -> Cofree f Node
stepGenerators = (=>> stepGenerator)
-- | Advance the root node's generator state by taking, and discarding, one
-- draw from a @beta 1 1@ distribution.  All other fields are kept as is.
stepGenerator :: Cofree f Node -> Node
stepGenerator (node@Node {nodeSeed = seed} :< _) = runST $ do
  (_, advanced) <- samplePurely (Prob.beta 1 1) seed
  return node {nodeSeed = advanced}
-- mcmc: perturb --------------------------------------------------------------
-- | Propose a new execution by applying 'perturbNode' at every position of
-- the tree.
perturb :: Execution a -> Execution a
perturb = (=>> perturbNode)
-- | Redraw the value held by the root node.
--
-- For a sampling instruction a fresh value is drawn from the instruction's
-- distribution using the node's stored generator state; the node's cost,
-- value and generator state are replaced and its history is kept.  A
-- 'DiracF' node is returned unchanged.
perturbNode :: Execution a -> Node
perturbNode (node@Node {nodeSeed = seed, nodeHistory = history} :< instruction) =
  case instruction of
    BernoulliF p _ -> redraw (Prob.bernoulli p) (logDensityBernoulli p)
    BetaF a b _ -> redraw (Prob.beta a b) (logDensityBeta a b)
    NormalF m s _ -> redraw (Prob.normal m s) (logDensityNormal m s)
    DiracF _ -> node
  where
    -- Sample from the distribution and score the draw with its log-density.
    redraw :: Typeable v => (forall s. Prob (ST s) v) -> (v -> Double) -> Node
    redraw dist density = runST $ do
      (value, seed') <- samplePurely dist seed
      return $! Node (density (unsafeFromDyn value)) value seed' history
-- mcmc: markov chain ---------------------------------------------------------
-- | Run a Markov chain over executions of @prior@ and return the final
-- execution; every node's history holds the values recorded along the way.
--
-- Arguments: the number of iterations, the observations, the prior program,
-- and a log-likelihood taking the prior's result and one observation.
--
-- Each iteration perturbs every node of the current execution, scores both
-- executions as the sum of the node costs and the log-likelihood of all
-- observations, and accepts the proposal with the probability given by
-- 'moveProbability'.  On rejection the generators of the current execution
-- are advanced with 'stepGenerators'.  Whichever execution is kept is passed
-- through 'snapshot' before the next iteration.
--
-- A non-positive iteration count returns the initial execution unchanged.
-- The stopping test is @n <= 0@ rather than @n == 0@ so that a negative
-- count cannot count down forever.
invert
  :: (Eq a, Typeable a, Typeable b)
  => Int -> [a] -> Model b -> (b -> a -> Double)
  -> Model (Execution b)
invert epochs obs prior ll = loop epochs (execute (prior >>= dirac))
  where
    loop n current
      | n <= 0 = return current
      | otherwise = do
          let proposal = perturb current
              valueAtCurrent = run current
              valueAtProposal = run proposal
              currentContribution = sum (map (ll valueAtCurrent) obs)
              proposalContribution = sum (map (ll valueAtProposal) obs)
              currentScore = score current + currentContribution
              proposalScore = score proposal + proposalContribution
              -- Forward and backward log proposal terms.
              fw = negate (log (fromIntegral (depth current))) + score proposal
              bw = negate (log (fromIntegral (depth proposal))) + score current
              prob = moveProbability currentScore proposalScore bw fw
          accept <- bernoulli prob
          let next = if accept then proposal else stepGenerators current
          loop (pred n) (snapshot next)
-- | Acceptance probability for a proposed move, computed from log-scale
-- scores as @exp (min 0 (proposal - current + bw - fw))@.  A NaN result is
-- replaced by zero.
moveProbability :: Double -> Double -> Double -> Double -> Double
moveProbability current proposal bw fw
  | isNaN probability = 0
  | otherwise = probability
  where
    logRatio = proposal - current + bw - fw
    probability = exp (min 0 logRatio)
-- | Record the present value of every node in its history by applying
-- 'snapshotValue' at every position of the tree.
snapshot :: Functor f => Cofree f Node -> Cofree f Node
snapshot = (=>> snapshotValue)
-- | Prepend the root node's current value to its history, leaving the other
-- fields untouched.
snapshotValue :: Cofree f Node -> Node
snapshotValue (node@Node {nodeValue = current, nodeHistory = past} :< _) =
  node {nodeHistory = current : past}
-- Data.Bits.Extended ---------------------------------------------------------
-- | A pure xorshift step.
--
-- Takes the two-word generator state and returns the next state paired with
-- the output word, which is the sum of the two words of the new state.
--
-- See: https://en.wikipedia.org/wiki/Xorshift.
xorshift :: (Bits t, Num t) => (t, t) -> ((t, t), t)
xorshift (s0, s1) = ((s1, next), next + s1)
  where
    scrambled = s0 `xor` (s0 `shiftL` 23)
    fromFirst = scrambled `xor` (scrambled `shiftR` 17)
    fromSecond = s1 `xor` (s1 `shiftR` 26)
    next = fromFirst `xor` fromSecond
-- Data.Dynamic.Extended ------------------------------------------------------
-- | Extract a value from a 'Dynamic' at the type demanded by the caller.
--
-- Partial: calls 'error' when the stored value has a different type.  The
-- message names the type that is actually stored, so a mismatch can be
-- traced instead of surfacing as an anonymous @fromJust@ failure.
unsafeFromDyn :: Typeable a => Dynamic -> a
unsafeFromDyn dyn = fromMaybe mismatch (fromDynamic dyn)
  where
    mismatch =
      error
        ( "unsafeFromDyn: stored value has unexpected type "
            ++ show (dynTypeRep dyn)
        )
-- test / illustration --------------------------------------------------------
-- | Example inference problem: 1000 iterations of 'invert' over a prior that
-- draws @p@ from @beta 3 2@ and then a Bernoulli flag with probability @p@.
-- Each observation is scored with a normal density of standard deviation 0.5
-- centred at -2 when the flag is 'True' and at 2 otherwise.
posterior1 :: Model (Execution Bool)
posterior1 = invert 1000 observations prior likelihood
  where
    observations =
      [ -1.7, -1.8, -2.01, -2.4
      , 1.9, 1.8
      ]
    prior = beta 3 2 >>= bernoulli
    likelihood left = logDensityNormal (if left then negate 2 else 2) 0.5
-- | Two-component normal mixture.  A weight is drawn from @beta a b@, a
-- Bernoulli flag is drawn with that weight, and the result is drawn from a
-- normal with standard deviation 0.5 centred at -2 when the flag is 'True'
-- and at 2 otherwise.
mixture :: Double -> Double -> Model Double
mixture a b = do
  left <- beta a b >>= bernoulli
  normal (if left then negate 2 else 2) 0.5
-- | A geometric-looking model that never produces a value.
--
-- The argument @p@ is ignored and the @accept@ flag is always 'False', so
-- the branch returning @1@ is never taken and the recursion never ends.
-- NOTE(review): the name suggests this is deliberate; confirm before
-- replacing the constant flag with @bernoulli p@.
trollGeometric :: Double -> Model Int
trollGeometric p = loop
  where
    loop =
      return False >>= \accept ->
        if accept
          then return 1
          else succ <$> loop
-- | Sample an execution from 'posterior1' and dump two node histories.
--
-- The root node's history is written to @post_p1_raw.dat@ as a list of
-- 'Double's, and the history of the node reached by one 'step' is written to
-- @post_b1_raw.dat@ as a list of 'Bool's.
analysis1 :: IO ()
analysis1 = do
  level0 <- simulate (toSampler posterior1)
  let rootHistory = nodeHistory (extract level0)
  writeFile "post_p1_raw.dat" (show (map unsafeFromDyn rootHistory :: [Double]))
  let childHistory = nodeHistory (extract (step level0))
  writeFile "post_b1_raw.dat" (show (map unsafeFromDyn childHistory :: [Bool]))
-- | Program entry point: runs 'analysis1'.
main :: IO ()
main = analysis1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment