Skip to content

Instantly share code, notes, and snippets.

@jtobin
Created October 28, 2016 07:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jtobin/4312880f9fa7b63ebbe5b84b9aa60ff5 to your computer and use it in GitHub Desktop.
Save jtobin/4312880f9fa7b63ebbe5b84b9aa60ff5 to your computer and use it in GitHub Desktop.
Tweaking comonadic inference for performance
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
import Control.Comonad
import Control.Comonad.Cofree
import qualified Control.Foldl as L
import Control.Monad.Free
import Data.Bits
import Data.Random
import qualified Data.Random.Distribution.Bernoulli as RF
import qualified Data.Random.Distribution.Beta as RF
import qualified Data.Random.Distribution.Normal as RF
import Data.Void
import Data.Word
import System.Random.Mersenne.Pure64
-- language types -------------------------------------------------------------
-- | One primitive instruction of the probabilistic language.
--
-- Every sampling constructor carries its distribution parameters followed by
-- a continuation that receives the drawn value; 'DiracF' ends a program with
-- its payload.
data ModelF a r
  = BernoulliF {-# UNPACK #-} !Double (Bool -> r)
    -- ^ success probability, continuation
  | BetaF {-# UNPACK #-} !Double {-# UNPACK #-} !Double (Double -> r)
    -- ^ two shape parameters, continuation
  | NormalF {-# UNPACK #-} !Double {-# UNPACK #-} !Double (Double -> r)
    -- ^ mean, standard deviation, continuation
  | DiracF !a
    -- ^ terminal payload
  deriving Functor
-- | A probabilistic program: the free monad over 'ModelF' whose terminal
-- 'DiracF' payloads have type @a@.
type Program a = Free (ModelF a)

-- | A program that is polymorphic in its terminal payload type.
type Model b = forall a. Free (ModelF a) b

-- | A program that can only finish through 'DiracF', because 'Pure' would
-- need a value of type 'Void'.
type Terminating a = Free (ModelF a) Void

-- | An execution trace: the program's instruction tree with a 'Node'
-- annotation at every instruction.
type Execution a = Cofree (ModelF a) Node
-- boring mechanical types ----------------------------------------------------
-- | The value recorded at one trace node.
data Value
  = VBool !Bool                     -- ^ outcome of a Bernoulli draw
  | VDouble {-# UNPACK #-} !Double  -- ^ outcome of a Beta or Normal draw
  | VEmpty                          -- ^ no value; used for 'DiracF' nodes
  deriving (Show, Eq)
-- | Annotation attached to every instruction of an 'Execution'.
data Node = Node
  { nodeCost  :: {-# UNPACK #-} !Double
    -- ^ log-density of 'nodeValue' under the instruction's distribution
    --   (0 for 'DiracF')
  , nodeValue :: !Value
    -- ^ the value drawn for the instruction ('VEmpty' for 'DiracF')
  , nodePrng  :: !PureMT
    -- ^ generator that the next redraw of this instruction starts from
  }
  deriving Show
-- | Two-word state for the 'xorshift' generator.
data Seed = Seed
  {-# UNPACK #-} !Word64
  {-# UNPACK #-} !Word64
-- | Fixed starting state used by 'execute', so that traces are reproducible.
defaultSeed :: Seed
defaultSeed = Seed lo hi
  where
    lo = 42
    hi = 108512
-- primitive terms ------------------------------------------------------------
-- | Draw from a Beta distribution with shape parameters @a@ and @b@.
--
-- Both shapes must be strictly positive. A zero shape is not a valid Beta
-- distribution, so it is rejected along with negative shapes. The test is
-- phrased as @not (a > 0 && b > 0)@ so that NaN shapes are rejected too.
beta :: Double -> Double -> Program a Double
beta a b
  | not (a > 0 && b > 0) = error "out of bounds"
  | otherwise = liftF (BetaF a b id)
-- | Draw a Boolean that is 'True' with probability @p@.
--
-- A @p@ below 0 is replaced by 0 and one above 1 by 1; the value is not
-- rejected.
bernoulli :: Double -> Program a Bool
bernoulli p = liftF (BernoulliF (clamp p) id)
  where
    clamp q
      | q < 0 = 0
      | q > 1 = 1
      | otherwise = q
-- | Draw from a Normal distribution with mean @m@ and standard deviation @s@.
--
-- @s@ is a standard deviation, not a variance: 'logDensityNormal' scores it
-- as @-log s - (x - m)^2 / (2 s^2)@. The error message now names it
-- correctly. The check is written as @not (s >= 0)@ so that a NaN @s@ is
-- rejected as well.
normal :: Double -> Double -> Program a Double
normal m s
  | not (s >= 0) = error "normal: standard deviation must be non-negative"
  | otherwise = liftF (NormalF m s id)
-- | End a program with payload @x@.
dirac :: a -> Program a b
dirac = liftF . DiracF
-- densities ------------------------------------------------------------------
-- | Log-probability of outcome @x@ under a Bernoulli with success
-- probability @p@. Returns @log 0@ (negative infinity) when @p@ lies outside
-- @[0, 1]@.
--
-- Only the log-probability of the observed outcome is evaluated. The earlier
-- weighted form @b * log p + (1 - b) * log (1 - p)@ computed @0 * log 0 = NaN@
-- at the boundaries (@p = 0@ with 'False', @p = 1@ with 'True'). 'bernoulli'
-- clamps probabilities to exactly those boundary values.
logDensityBernoulli :: Double -> Bool -> Double
logDensityBernoulli p x
  | p < 0 || p > 1 = log 0
  | x = log p
  | otherwise = log (1 - p)
-- | Log-density of @x@ under a Beta distribution with shape parameters @a@
-- and @b@. Returns @log 0@ (negative infinity) when @x@ is outside the open
-- interval @(0, 1)@ or either shape parameter is not strictly positive.
--
-- The @log B(a, b)@ normalizer is included. It depends on the shape
-- parameters, so leaving it out would skew score comparisons between traces
-- whose Beta parameters differ. For fixed parameters it is a constant and
-- cancels in acceptance ratios.
logDensityBeta :: Double -> Double -> Double -> Double
logDensityBeta a b x
  | x <= 0 || x >= 1 = log 0
  | not (a > 0 && b > 0) = log 0
  | otherwise = (a - 1) * log x + (b - 1) * log (1 - x) - logBetaFn
  where
    logBetaFn = lnGamma a + lnGamma b - lnGamma (a + b)

    -- Lanczos approximation (g = 7, nine coefficients) of log Gamma(z), with
    -- the reflection formula for z < 1/2. Only reached with z > 0 here,
    -- because of the guards above.
    lnGamma :: Double -> Double
    lnGamma z
      | z < 0.5 = log pi - log (abs (sin (pi * z))) - lnGamma (1 - z)
      | otherwise =
          let z' = z - 1
              t = z' + 7.5
              series = c0 + sum (zipWith (\c i -> c / (z' + i)) cs [1 ..])
          in 0.5 * log (2 * pi) + (z' + 0.5) * log t - t + log series

    c0 = 0.99999999999980993
    cs =
      [ 676.5203681218851
      , -1259.1392167224028
      , 771.32342877765313
      , -176.61502916214059
      , 12.507343278686905
      , -0.13857109526572012
      , 9.9843695780195716e-6
      , 1.5056327351493116e-7
      ]
-- | Log-density of @x@ under a Normal with mean @m@ and standard deviation
-- @s@, without the additive constant @-0.5 * log (2 * pi)@. Returns @log 0@
-- when @s@ is zero or negative.
logDensityNormal :: Double -> Double -> Double -> Double
logDensityNormal m s x =
  if s <= 0
    then log 0
    else negate (log s) - dev * dev / (2 * (s * s))
  where
    dev = x - m
-- | Log-mass of @x@ under a point mass at @a@: 0 when they are equal,
-- @log 0@ otherwise.
logDensityDirac :: Eq a => a -> a -> Double
logDensityDirac a x = if a == x then 0 else log 0
-- sampling -------------------------------------------------------------------
-- | Interpret a program as a random-fu sampler. Each primitive is drawn from
-- the matching random-fu distribution and fed to its continuation; 'DiracF'
-- returns its payload.
toSampler :: Program a a -> RVar a
toSampler = iterM interpret
  where
    interpret :: ModelF r (RVar r) -> RVar r
    interpret (BernoulliF p k) = RF.bernoulli p >>= k
    interpret (BetaF a b k) = RF.beta a b >>= k
    interpret (NormalF m s k) = RF.normal m s >>= k
    interpret (DiracF x) = return x
-- execution: initializing ----------------------------------------------------
-- | Annotate every instruction of a terminating program with an initial
-- 'Node'.
--
-- Starting from 'defaultSeed', each instruction runs one 'xorshift' step.
-- The output word seeds this node's 'PureMT' through 'initialize', and the
-- advanced state is passed on to the continuations.
execute :: Terminating a -> Execution a
execute = go defaultSeed
  where
    go _ (Pure v) = absurd v
    go s (Free instr) = annotate instr :< fmap (go s') instr
      where
        (s', w) = xorshift s
        annotate = initialize (pureMT w)
-- | Build the first 'Node' for one instruction.
--
-- Sampling instructions draw a value with @prng@ and record its
-- log-density, the wrapped value and the advanced generator. 'DiracF'
-- records cost 0, 'VEmpty' and the unchanged generator.
initialize :: PureMT -> ModelF a b -> Node
initialize prng instr = case instr of
  BernoulliF p _ ->
    let (v, g) = sampleState (RF.bernoulli p) prng
    in Node (logDensityBernoulli p v) (VBool v) g
  BetaF a b _ ->
    let (v, g) = sampleState (RF.beta a b) prng
    in Node (logDensityBeta a b v) (VDouble v) g
  NormalF m s _ ->
    let (v, g) = sampleState (RF.normal m s) prng
    in Node (logDensityNormal m s v) (VDouble v) g
  DiracF _ -> Node 0 VEmpty prng
-- execution: scoring and running ---------------------------------------------
-- | Sum of 'nodeCost' along the path an execution actually takes.
--
-- The walk starts at the root and feeds each node's recorded value into its
-- instruction's continuation. It stops at the terminal 'DiracF', whose own
-- cost is not added.
score :: Execution a -> Double
score = go 0
  where
    go !total (Node cost value _ :< instr) =
      case instr of
        DiracF _ -> total
        BernoulliF _ k -> let VBool b = value in go (total + cost) (k b)
        BetaF _ _ k -> let VDouble d = value in go (total + cost) (k d)
        NormalF _ _ k -> let VDouble d = value in go (total + cost) (k d)
-- | Number of instructions on the path an execution takes, following each
-- node's recorded value. The terminal 'DiracF' is included in the count.
depth :: Execution a -> Int
depth = go 1
  where
    go !n (Node _ value _ :< instr) =
      case instr of
        DiracF _ -> n
        BernoulliF _ k -> let VBool b = value in go (succ n) (k b)
        BetaF _ _ k -> let VDouble d = value in go (succ n) (k d)
        NormalF _ _ k -> let VDouble d = value in go (succ n) (k d)
-- | Move one instruction forward, using the value recorded in the root node.
step :: Execution a -> Execution a
step prog = stepWithInput (nodeValue (extract prog)) prog
-- | Follow the root instruction's continuation with the supplied value
-- instead of the recorded one. A terminal 'DiracF' is returned unchanged.
-- The value's constructor must match the instruction ('VBool' for Bernoulli,
-- 'VDouble' for Beta and Normal); a mismatch fails when the value is forced.
stepWithInput :: Value -> Execution a -> Execution a
stepWithInput value prog =
  case unwrap prog of
    DiracF _ -> prog
    BernoulliF _ k -> let VBool b = value in k b
    BetaF _ _ k -> let VDouble d = value in k d
    NormalF _ _ k -> let VDouble d = value in k d
-- | Step along the recorded path until a 'DiracF' is reached, then return
-- its payload.
run :: Execution a -> a
run prog
  | DiracF a <- unwrap prog = a
  | otherwise = run (step prog)
-- | Take the first step with the given value, then 'run' to the end along
-- the recorded path.
runWithInput :: Value -> Execution a -> a
runWithInput value prog = run (stepWithInput value prog)
-- | Apply 'stepGenerator' at every position of the tree, advancing each
-- node's generator.
stepGenerators :: Functor f => Cofree f Node -> Cofree f Node
stepGenerators prog = prog =>> stepGenerator
-- | Copy the root node, replacing its generator with the one left after
-- drawing (and discarding) a single 'randomInt'.
stepGenerator :: Cofree f Node -> Node
stepGenerator (node :< _) =
  node { nodePrng = snd (randomInt (nodePrng node)) }
-- mcmc: perturb --------------------------------------------------------------
-- | Apply 'perturbNode' at every position of the execution tree, redrawing
-- each node from its own distribution.
perturb :: Execution a -> Execution a
perturb prog = prog =>> perturbNode
-- | Redraw the root instruction's value from its distribution using the
-- node's generator, recording the new log-density and the advanced
-- generator. 'DiracF' nodes are returned as they are.
perturbNode :: Execution a -> Node
perturbNode (node :< instr) =
  case instr of
    DiracF _ -> node
    BernoulliF p _ ->
      let (v, g) = sampleState (RF.bernoulli p) gen
      in Node (logDensityBernoulli p v) (VBool v) g
    BetaF a b _ ->
      let (v, g) = sampleState (RF.beta a b) gen
      in Node (logDensityBeta a b v) (VDouble v) g
    NormalF m s _ ->
      let (v, g) = sampleState (RF.normal m s) gen
      in Node (logDensityNormal m s v) (VDouble v) g
  where
    gen = nodePrng node
-- mcmc: markov chain ---------------------------------------------------------
-- | Metropolis–Hastings-style chain over executions of @prior@, conditioned
-- on the observations @obs@ through the log-likelihood @ll@. Returns the
-- execution reached after @epochs@ steps.
--
-- Each step proposes 'perturb' of the current execution. Both the current
-- and proposed executions are scored as prior 'score' plus the summed
-- log-likelihood of @obs@ at their 'run' results. 'depth' and the prior
-- scores supply the forward/backward terms for 'moveProbability'.
--
-- The accept decision is itself a 'bernoulli' draw, which is why the result
-- is a 'Model'. On rejection, 'stepGenerators' advances every node's
-- generator so the next proposal differs.
--
-- The loop stops once the remaining count is not positive. A negative
-- @epochs@ therefore returns the initial execution; an @n == 0@ test never
-- fires for negative counts and would recurse forever.
invert :: Int -> [a] -> Model b -> (b -> a -> Double) -> Model (Execution b)
invert epochs obs prior ll = loop epochs (execute (prior >>= dirac))
  where
    loop n current
      | n <= 0 = return current
      | otherwise = do
          let proposal = perturb current
              ccostPrior = score current
              pcostPrior = score proposal
              cvalue = run current
              pvalue = run proposal
              ccostObs = L.fold (L.premap (ll cvalue) L.sum) obs
              pcostObs = L.fold (L.premap (ll pvalue) L.sum) obs
              ccost = ccostPrior + ccostObs
              pcost = pcostPrior + pcostObs
              fwcost = negate (log (fromIntegral (depth current))) + pcostPrior
              bwcost = negate (log (fromIntegral (depth proposal))) + ccostPrior
              prob = moveProbability ccost pcost bwcost fwcost
          accept <- bernoulli prob
          let next = if accept then proposal else stepGenerators current
          loop (pred n) next
-- | Acceptance probability @exp (min 0 (proposal - current + bw - fw))@,
-- computed from log-scale inputs. A NaN result is replaced by 0; this
-- happens, for example, when both costs are @-Infinity@.
moveProbability :: Double -> Double -> Double -> Double -> Double
moveProbability current proposal bw fw
  | isNaN p = 0
  | otherwise = p
  where
    p = exp (min 0 (proposal - current + bw - fw))
-- Data.Bits.Extended ---------------------------------------------------------
-- | One pure xorshift128+ step. Returns the advanced state together with a
-- 64-bit output word.
--
-- See: https://en.wikipedia.org/wiki/Xorshift.
xorshift :: Seed -> (Seed, Word64)
xorshift (Seed s0 s1) = (Seed s1 t, t + s1)
  where
    mixed = s0 `xor` (s0 `shiftL` 23)
    t = mixed `xor` s1 `xor` (mixed `shiftR` 17) `xor` (s1 `shiftR` 26)
-- test -----------------------------------------------------------------------
-- | Example prior: draw @p@ from Beta(1, 2), then a Boolean that is 'True'
-- with probability @p@.
test :: Model Bool
test = beta 1 2 >>= bernoulli
-- | Example observations: four negative values followed by two positive ones.
xs :: [Double]
xs = [-1.7, -1.8, -2.01, -2.4] ++ [1.9, 1.8]
-- | Observation log-likelihood: a Normal with standard deviation 0.5, centred
-- at -2 when @left@ is 'True' and at 2 otherwise.
model :: Bool -> Double -> Double
model left = logDensityNormal centre 0.5
  where
    centre = if left then -2 else 2
-- | Run the chain for 1000 steps on 'test', conditioned on 'xs' through
-- 'model'.
posterior :: Model (Execution Bool)
posterior = invert epochs xs test model
  where
    epochs = 1000
-- | Sample one chain from 'posterior' and print the root 'Node' of the
-- resulting execution.
main :: IO ()
main = sample (toSampler posterior) >>= print . extract
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment