Skip to content

Instantly share code, notes, and snippets.

@jtobin jtobin/optimizing.hs
Created Oct 28, 2016

Embed
What would you like to do?
Tweaking comonadic inference for performance
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
import Control.Comonad
import Control.Comonad.Cofree
import qualified Control.Foldl as L
import Control.Monad.Free
import Data.Bits
import Data.Random
import qualified Data.Random.Distribution.Bernoulli as RF
import qualified Data.Random.Distribution.Beta as RF
import qualified Data.Random.Distribution.Normal as RF
import Data.Void
import Data.Word
import System.Random.Mersenne.Pure64
-- language types -------------------------------------------------------------
-- | Instruction set for probabilistic programs: one constructor per
--   primitive distribution.  Each sampling constructor carries its
--   parameters plus a continuation receiving the sampled value;
--   'DiracF' terminates the program with a final value of type @a@.
data ModelF a r =
    BernoulliF {-# UNPACK #-} !Double (Bool -> r)
  | BetaF {-# UNPACK #-} !Double {-# UNPACK #-} !Double (Double -> r)
  | NormalF {-# UNPACK #-} !Double {-# UNPACK #-} !Double (Double -> r)
  | DiracF !a
  deriving Functor
-- | A probabilistic program over the 'ModelF' instruction set, built as
--   a free monad; it may terminate with an @a@ via 'DiracF'.
type Program a = Free (ModelF a)

-- | A program returning @b@, polymorphic in its termination type.
type Model b = forall a. Program a b

-- | A program whose only exit is 'DiracF' (its return type is 'Void',
--   so 'Pure' is unreachable).
type Terminating a = Program a Void

-- | A program trace: every instruction annotated with a 'Node' holding
--   its sampled value, log-density, and PRNG state (cofree comonad).
type Execution a = Cofree (ModelF a) Node
-- boring mechanical types ----------------------------------------------------
-- | Untyped runtime representation of a value sampled at a node.
data Value =
    VBool !Bool
  | VDouble {-# UNPACK #-} !Double
  | VEmpty  -- used at 'DiracF' nodes, which carry no sample
  deriving (Eq, Show)
-- | Per-node execution state: the node's log-density contribution, the
--   value sampled there, and the PRNG state used to (re)sample it.
data Node = Node {
    nodeCost :: {-# UNPACK #-} !Double  -- log-density of the stored value
  , nodeValue :: !Value                 -- value sampled at this node
  , nodePrng :: !PureMT                 -- generator state owned by this node
  } deriving Show
-- | Two-word state for the pure 'xorshift' generator used to derive
--   per-node PRNG seeds during 'execute'.
data Seed = Seed {-# UNPACK #-} !Word64 {-# UNPACK #-} !Word64

-- | Fixed initial seed, making 'execute' deterministic.
defaultSeed :: Seed
defaultSeed = Seed 42 108512
-- primitive terms ------------------------------------------------------------
-- | Embed a beta-distributed sample in a program.
--
-- Both shape parameters of a Beta distribution must be strictly
-- positive; the previous guard (@a < 0 || b < 0@) admitted zero, which
-- is outside the distribution's valid parameter range, so the guard is
-- tightened to @<=@.
beta :: Double -> Double -> Program a Double
beta a b
  | a <= 0 || b <= 0 = error "out of bounds"
  | otherwise = liftF (BetaF a b id)
-- | Embed a Bernoulli-distributed sample in a program.  Out-of-range
--   success probabilities are clamped into the unit interval rather
--   than rejected.
bernoulli :: Double -> Program a Bool
bernoulli p = liftF (BernoulliF clamped id) where
  clamped
    | p < 0 = 0
    | p > 1 = 1
    | otherwise = p
-- | Embed a normally-distributed sample in a program.  A negative
--   scale parameter is a caller error.
normal :: Double -> Double -> Program a Double
normal mu sigma
  | sigma < 0 = error "negative variance"
  | otherwise = liftF (NormalF mu sigma id)
-- | Terminate a program with the given value (a point mass).
dirac :: a -> Program a b
dirac = liftF . DiracF
-- densities ------------------------------------------------------------------
-- | Bernoulli log-mass: @log p@ on success, @log (1 - p)@ on failure;
--   @log 0@ (negative infinity) for probabilities outside [0, 1].
--
-- The branches are guarded explicitly because the previous formulation
-- @b * log p + (1 - b) * log (1 - p)@ evaluates @0 * (-Infinity) = NaN@
-- at the boundaries: e.g. @p = 1, x = True@ should score 0, not NaN.
logDensityBernoulli :: Double -> Bool -> Double
logDensityBernoulli p x
  | p < 0 || p > 1 = log 0
  | x = log p
  | otherwise = log (1 - p)
-- | Unnormalized beta log-density (the log Beta-function constant is
--   omitted; constants cancel in Metropolis ratios).  Points outside
--   the open unit interval, or negative parameters, score @log 0@.
logDensityBeta :: Double -> Double -> Double -> Double
logDensityBeta alpha bet x
  | outOfSupport = log 0
  | otherwise = (alpha - 1) * log x + (bet - 1) * log (1 - x)
  where
    outOfSupport = x <= 0 || x >= 1 || alpha < 0 || bet < 0
-- | Unnormalized normal log-density (the @-log sqrt (2 pi)@ constant is
--   omitted; constants cancel in Metropolis ratios).  A non-positive
--   scale scores @log 0@.
logDensityNormal :: Double -> Double -> Double -> Double
logDensityNormal mu sigma x
  | sigma <= 0 = log 0
  | otherwise = negate (log sigma) - dev * dev / (2 * (sigma * sigma))
  where
    dev = x - mu
-- | Log-mass of a point distribution: 0 when the observation hits the
--   point exactly, @log 0@ (negative infinity) otherwise.
logDensityDirac :: Eq a => a -> a -> Double
logDensityDirac point x = if point == x then 0 else log 0
-- sampling -------------------------------------------------------------------
-- | Interpret a program as a random-fu sampler by folding each
--   instruction into the corresponding primitive distribution and
--   feeding the draw to the continuation.
toSampler :: Program a a -> RVar a
toSampler = iterM interp where
  interp (BernoulliF p k) = RF.bernoulli p >>= k
  interp (BetaF a b k) = RF.beta a b >>= k
  interp (NormalF m s k) = RF.normal m s >>= k
  interp (DiracF x) = return x
-- execution: initializing ----------------------------------------------------
-- | Annotate every instruction of a terminating program with execution
--   state, deriving a fresh per-node PRNG seed from a deterministic
--   xorshift stream.  The 'Pure' case is impossible ('Void').
execute :: Terminating a -> Execution a
execute = go defaultSeed where
  go _ (Pure r) = absurd r
  go seed (Free op) =
    let (seed', genSeed) = xorshift seed
    in initialize (pureMT genSeed) op :< fmap (go seed') op
-- | Build the initial 'Node' for one instruction: draw a value from the
--   instruction's distribution with the supplied generator, score it,
--   and keep the advanced generator state.  'DiracF' nodes sample
--   nothing and cost nothing.
initialize :: PureMT -> ModelF a b -> Node
initialize gen = \case
  BernoulliF p _ ->
    let (x, gen') = sampleState (RF.bernoulli p) gen
    in Node (logDensityBernoulli p x) (VBool x) gen'
  BetaF a b _ ->
    let (x, gen') = sampleState (RF.beta a b) gen
    in Node (logDensityBeta a b x) (VDouble x) gen'
  NormalF m s _ ->
    let (x, gen') = sampleState (RF.normal m s) gen
    in Node (logDensityNormal m s x) (VDouble x) gen'
  DiracF _ -> Node 0 VEmpty gen
-- execution: scoring and running ---------------------------------------------
-- | Total log-density of an execution: walk the trace, feeding each
--   stored value to its continuation and accumulating node costs until
--   the terminating 'DiracF'.  NOTE(review): the @VBool@/@VDouble@
--   patterns are partial — they rely on 'initialize'/'perturbNode'
--   storing a value of the matching shape at every node.
score :: Execution a -> Double
score = go 0 where
  go !total (node :< instr) = case instr of
    BernoulliF _ k ->
      let VBool b = nodeValue node
      in go (total + nodeCost node) (k b)
    BetaF _ _ k ->
      let VDouble d = nodeValue node
      in go (total + nodeCost node) (k d)
    NormalF _ _ k ->
      let VDouble d = nodeValue node
      in go (total + nodeCost node) (k d)
    DiracF _ -> total
-- | Number of nodes along an execution's trace (the terminating
--   'DiracF' node included), following the stored values through each
--   continuation.
depth :: Execution a -> Int
depth = go 0 where
  go !n (node :< instr) = case instr of
    BernoulliF _ k ->
      let VBool b = nodeValue node
      in go (n + 1) (k b)
    BetaF _ _ k ->
      let VDouble d = nodeValue node
      in go (n + 1) (k d)
    NormalF _ _ k ->
      let VDouble d = nodeValue node
      in go (n + 1) (k d)
    DiracF _ -> n + 1
-- | Advance an execution by one instruction, feeding the current
--   node's stored value to its continuation.
step :: Execution a -> Execution a
step execution@(node :< _) = stepWithInput (nodeValue node) execution
-- | Advance an execution by one instruction using the supplied value
--   as the sample fed to the continuation.  A 'DiracF' node is a fixed
--   point: the execution is returned unchanged.
stepWithInput :: Value -> Execution a -> Execution a
stepWithInput value execution = case unwrap execution of
  BernoulliF _ k -> let VBool b = value in k b
  BetaF _ _ k -> let VDouble d = value in k d
  NormalF _ _ k -> let VDouble d = value in k d
  DiracF _ -> execution
-- | Step an execution to completion and return the value carried by
--   its terminating 'DiracF'.
run :: Execution a -> a
run execution = case unwrap execution of
  DiracF result -> result
  _ -> run (step execution)
-- | Feed one explicit value into the execution's first instruction,
--   then run the remainder to completion.
runWithInput :: Value -> Execution a -> a
runWithInput value execution = run (stepWithInput value execution)
-- | Advance the PRNG state of every node in the trace without touching
--   any stored value or cost (used to keep rejected chains moving).
stepGenerators :: Functor f => Cofree f Node -> Cofree f Node
stepGenerators = extend stepGenerator
-- | Advance a single node's generator one draw, discarding the drawn
--   number and keeping only the new generator state.
stepGenerator :: Cofree f Node -> Node
stepGenerator (node :< _) = node { nodePrng = advanced } where
  (_, advanced) = randomInt (nodePrng node)
-- mcmc: perturb --------------------------------------------------------------
-- | Propose a new execution by independently resampling every node
--   from its prior (comonadic extension of 'perturbNode').
perturb :: Execution a -> Execution a
perturb = extend perturbNode
-- | Resample one node from its own prior using the node's stored PRNG
--   state, recomputing the node's log-density for the fresh value.
--   'DiracF' nodes are returned untouched.
perturbNode :: Execution a -> Node
perturbNode (node :< instr) = case instr of
  BernoulliF p _ ->
    let (x, gen) = sampleState (RF.bernoulli p) (nodePrng node)
    in Node (logDensityBernoulli p x) (VBool x) gen
  BetaF a b _ ->
    let (x, gen) = sampleState (RF.beta a b) (nodePrng node)
    in Node (logDensityBeta a b x) (VDouble x) gen
  NormalF m s _ ->
    let (x, gen) = sampleState (RF.normal m s) (nodePrng node)
    in Node (logDensityNormal m s x) (VDouble x) gen
  DiracF _ -> node
-- mcmc: markov chain ---------------------------------------------------------
-- | Invert a model against observations via Metropolis sampling: run
--   @epochs@ proposal/accept steps starting from an execution of the
--   prior, scoring proposals with the prior cost plus the summed
--   observation log-likelihood @ll@.  The returned program yields the
--   final execution (from which a value can be extracted with 'run').
invert :: Int -> [a] -> Model b -> (b -> a -> Double) -> Model (Execution b)
invert epochs obs prior ll = loop epochs (execute (prior >>= dirac)) where
  loop n current
    | n == 0 = return current
    | otherwise = do
        let proposal = perturb current
            -- prior log-densities of the current and proposed traces
            ccostPrior = score current
            pcostPrior = score proposal
            -- values the two traces terminate with
            cvalue = run current
            pvalue = run proposal
            -- observation log-likelihood under each terminal value
            ccostObs = L.fold (L.premap (ll cvalue) L.sum) obs
            pcostObs = L.fold (L.premap (ll pvalue) L.sum) obs
            ccost = ccostPrior + ccostObs
            pcost = pcostPrior + pcostObs
            -- forward/backward proposal costs; the -log(depth) term
            -- accounts for trace length, prior terms for resampling
            fwcost = negate (log (fromIntegral (depth current))) + pcostPrior
            bwcost = negate (log (fromIntegral (depth proposal))) + ccostPrior
            prob = moveProbability ccost pcost bwcost fwcost
        accept <- bernoulli prob
        -- on rejection, still advance every node's generator so the
        -- chain does not repeat the identical proposal
        let next = if accept then proposal else stepGenerators current
        loop (pred n) next
-- | Metropolis acceptance probability from log-costs, capped at 1.
--   A NaN ratio (e.g. both costs infinitely bad) maps to rejection.
moveProbability :: Double -> Double -> Double -> Double -> Double
moveProbability current proposal bw fw = guardNaN ratio where
  ratio = exp (min 0 (proposal - current + bw - fw))
  guardNaN x
    | isNaN x = 0
    | otherwise = x
-- Data.Bits.Extended ---------------------------------------------------------
-- | A pure xorshift step: returns the advanced two-word state together
--   with a generated 64-bit output.
--
-- See: https://en.wikipedia.org/wiki/Xorshift.
xorshift :: Seed -> (Seed, Word64)
xorshift (Seed s0 s1) = (Seed s1 mixed, mixed + s1) where
  shifted = s0 `xor` shiftL s0 23
  mixed = shifted `xor` s1 `xor` shiftR shifted 17 `xor` shiftR s1 26
-- test -----------------------------------------------------------------------
-- | Beta-Bernoulli test model: draw a success probability from
--   Beta(1, 2), then flip a coin with it.
test :: Model Bool
test = beta 1 2 >>= bernoulli
-- | Observations: two apparent clusters, one near -2 and one near 2.
xs :: [Double]
xs = [ -1.7, -1.8, -2.01, -2.4
     , 1.9, 1.8
     ]
-- | Observation log-likelihood: a two-component mixture where the
--   latent flag selects a normal centered at -2 or at 2 (scale 0.5).
model :: Bool -> Double -> Double
model left = logDensityNormal center 0.5 where
  center = if left then negate 2 else 2
-- | 1000-step Metropolis chain inverting 'test' against 'xs' under
--   the mixture likelihood 'model'.
posterior :: Model (Execution Bool)
posterior = invert 1000 xs test model
-- | Sample one posterior execution and print its root 'Node'.
main :: IO ()
main = do
  result <- sample (toSampler posterior)
  print (extract result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.