Created September 2, 2024 13:00
-
-
Save jtobin/8da5c8b46297e4868c25082d74bd1ebf to your computer and use it in GitHub Desktop.
Chinese Restaurant Process in an embedded language with recursion schemes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE LambdaCase #-} | |
import Control.Monad.Free | |
import qualified Control.Monad.Trans.Free as TF | |
import qualified Data.Foldable as F | |
import qualified Data.List as L | |
import Data.Functor.Foldable | |
import qualified System.Random.MWC.Probability as MWC | |
import qualified Data.IntMap.Strict as IMS | |
-- probabilistic instruction set, program definitions | |
-- | Instruction set of the embedded probabilistic language. Every
-- constructor is one primitive random draw paired with the continuation
-- that consumes the drawn value.
data ModelF r
  = BernoulliF Double (Bool -> r)     -- ^ coin flip; 'prob' samples it with @MWC.bernoulli@
  | UniformF (Double -> r)            -- ^ uniform 'Double'; 'prob' samples it with @MWC.uniform@
  | CategoricalF [Double] (Int -> r)  -- ^ index drawn from the given weights via @MWC.categorical@
  deriving (Functor)

-- | Programs in the language: the free monad over 'ModelF'.
type Model = Free ModelF
-- | Interpret a 'Model' program as a sampler in mwc-probability's
-- 'MWC.Prob' monad over 'IO'.
--
-- 'iterM' peels the free monad one instruction at a time. Each primitive
-- is sampled with the matching MWC distribution, and the draw is fed to
-- that instruction's continuation.
prob :: Model a -> MWC.Prob IO a
prob = iterM step
  where
    step (BernoulliF p k)         = k =<< MWC.bernoulli p
    step (UniformF k)             = k =<< MWC.uniform
    step (CategoricalF weights k) = k =<< MWC.categorical weights
-- | One uniform 'Double' draw, lifted into the free monad over
-- @'TF.FreeF' 'ModelF' a@ so it can be sequenced inside a 'futu'
-- coalgebra. This is the value @liftF (TF.Free (UniformF id))@ written
-- with the constructors directly.
uniform :: Free (TF.FreeF ModelF a) Double
uniform = Free (TF.Free (UniformF Pure))
-- | One categorical draw over @weights@, lifted into the free monad over
-- @'TF.FreeF' 'ModelF' a@ so it can be sequenced inside a 'futu'
-- coalgebra. This is the value @liftF (TF.Free (CategoricalF weights id))@
-- written with the constructors directly.
categorical :: [Double] -> Free (TF.FreeF ModelF a) Int
categorical weights = Free (TF.Free (CategoricalF weights Pure))
-- utilities | |
-- | Shorthand for 'fromIntegral'.
--
-- The explicit polymorphic signature keeps this binding out of the
-- monomorphism restriction. Without it, the binding is pinned to a single
-- type inferred (or defaulted) from its uses in the module.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
-- classic representation (set partitions of the customers 0 .. n-1), conditional draw
-- | Sample a Chinese Restaurant Process seating of @n@ customers with
-- concentration parameter @a@, using a conditional two-stage draw.
--
-- Customers are labelled @0 .. n-1@, and customer 0 always opens table 0.
-- For each later customer @l@ (with @l@ customers already seated):
--
-- * a Bernoulli draw with probability @a / (a + l)@ decides whether they
--   open a new table;
-- * otherwise a uniform draw picks an existing table with probability
--   proportional to its occupancy.
--
-- 'futu' is used so that the second stage can emit an extra instruction
-- within a single unfolding step.
--
-- Returns the tables in creation order, each listing its customers
-- most-recent first. Returns @[]@ when @n <= 0@.
crp :: Int -> Double -> Model [[Int]]
crp n a
  | n <= 0 = Pure mempty
  | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    coalg (l, ts)
      | l >= n = TF.Pure (F.toList ts)
      | otherwise =
          let p = a / (a + fi l)
              k = succ l
          in  TF.Free . BernoulliF p $ \accept ->
                if accept
                  then let s = IMS.size ts
                       in  pure (k, IMS.insert s [l] ts)
                  else do
                    res <- seat l ts
                    pure (k, res)

    -- Seat customer l at an existing table, chosen with probability
    -- proportional to table size.
    --
    -- The l seated customers give integer occupancies whose running totals
    -- are exact in Double, and the final total equals l. Comparing those
    -- totals against u * l therefore cannot fall short at the end.
    -- Normalising first and comparing against u could fall short: round-off
    -- may leave the last cumulative weight just below u, so no table matched
    -- and the program crashed. This assumes u lies in (0, 1], which is the
    -- range of mwc-random's Double uniform.
    seat l ts = do
      u <- uniform
      let sizes  = fmap (fi . length) (F.toList ts)
          cumul  = scanl1 (+) sizes
          target = u * fi l
          -- Table keys are 0 .. size-1, so a list position is the table's
          -- key. The fallback keeps this total rather than calling 'error'.
          idx = case L.findIndex (>= target) cumul of
            Just i  -> i
            Nothing -> IMS.size ts - 1
      pure (IMS.adjust ((:) l) idx ts)
-- using a single unconditional draw, ana | |
-- | Sample a Chinese Restaurant Process seating of @n@ customers with
-- concentration @a@, using one categorical draw per customer, unfolded
-- with 'ana'.
--
-- With @seated@ customers already placed, the categorical weights are:
--
-- * outcome 0, with weight @a / (a + seated)@: open a new table;
-- * outcome @j + 1@, with weight @size_j / (a + seated)@: join table @j@.
--
-- 'ana' emits exactly one layer per step, so each customer costs exactly
-- one 'CategoricalF' instruction. Returns @[]@ when @n <= 0@.
crp0 :: Int -> Double -> Model [[Int]]
crp0 n a
  | n <= 0    = Pure mempty
  | otherwise = ana step (1, IMS.singleton 0 [0])
  where
    step (seated, tables)
      | seated >= n = TF.Pure (F.toList tables)
      | otherwise   = TF.Free (CategoricalF weights choose)
      where
        total   = a + fi seated
        next    = succ seated
        weights = a / total : [fi (length t) / total | t <- F.toList tables]
        choose i
          | i == 0    = (next, IMS.insert (IMS.size tables) [seated] tables)
          | otherwise = (next, IMS.adjust (seated :) (i - 1) tables)
-- using a single unconditional draw, futu | |
-- | Sample a Chinese Restaurant Process seating of @n@ customers with
-- concentration @a@, using one categorical draw per customer, unfolded
-- with 'futu'.
--
-- Outcome 0 (weight @a / (a + l)@) opens a new table. Outcome @i > 0@
-- (weight proportional to the size of table @i - 1@) seats customer @l@
-- there. Each continuation returns a pure seed, so every step emits one
-- 'CategoricalF' instruction. Returns @[]@ when @n <= 0@.
crp1 :: Int -> Double -> Model [[Int]]
crp1 n a
  | n <= 0    = Pure mempty
  | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    coalg (l, ts)
      | l >= n    = TF.Pure (F.toList ts)
      | otherwise = TF.Free (CategoricalF ps (pure . seatAt))
      where
        denom = a + fi l
        ps    = a / denom : F.toList (fmap ((/ denom) . fi . length) ts)
        seatAt 0 = (succ l, IMS.insert (IMS.size ts) [l] ts)
        seatAt i = (succ l, IMS.adjust (l :) (i - 1) ts)
-- using two draws, conditional categorical | |
-- | Sample a Chinese Restaurant Process seating of @n@ customers with
-- concentration @a@, using two draws per customer:
--
-- 1. a Bernoulli with probability @a / (a + l)@ for "open a new table";
-- 2. on failure, a categorical over the existing tables with weights
--    @size_j / l@.
--
-- 'futu' lets the failure branch emit the second instruction within the
-- same unfolding step. Returns @[]@ when @n <= 0@.
crp2 :: Int -> Double -> Model [[Int]]
crp2 n a
  | n <= 0    = Pure mempty
  | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    coalg (l, ts)
      | l >= n    = TF.Pure (F.toList ts)
      | otherwise = TF.Free (BernoulliF (a / (a + fi l)) branch)
      where
        next = succ l
        branch True  = pure (next, IMS.insert (IMS.size ts) [l] ts)
        branch False = do
          j <- categorical [fi (length t) / fi l | t <- F.toList ts]
          pure (next, IMS.adjust (l :) j ts)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Here's a one-liner to get a Nix shell with GHCi and the required libraries: