Created September 2, 2024 13:00
Chinese Restaurant Process in an embedded language with recursion schemes
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
import Control.Monad.Free
import qualified Control.Monad.Trans.Free as TF
import qualified Data.Foldable as F
import qualified Data.List as L
import Data.Functor.Foldable
import qualified System.Random.MWC.Probability as MWC
import qualified Data.IntMap.Strict as IMS
-- probabilistic instruction set, program definitions
-- | Instruction set of the embedded probabilistic language.
--
-- Each constructor describes one random draw and carries a continuation
-- that receives the sampled value and yields the rest of the program.
data ModelF r
  = -- | Draw a 'Bool'; the 'Double' is handed to the Bernoulli sampler
    -- as its parameter when the program is interpreted.
    BernoulliF Double (Bool -> r)
  | -- | Draw a 'Double' from the interpreter's uniform sampler.
    UniformF (Double -> r)
  | -- | Draw an 'Int' index; the list holds the weights handed to the
    -- categorical sampler when the program is interpreted.
    CategoricalF [Double] (Int -> r)
  deriving Functor
-- | A probabilistic program: the free monad over the 'ModelF' instruction set.
type Model = Free ModelF
-- | Interpret a 'Model' program as an mwc-probability sampler.
--
-- Every instruction is mapped onto the matching @MWC@ primitive and the
-- sampled value is fed to the instruction's continuation.
prob :: Model a -> MWC.Prob IO a
prob = iterM interpret
  where
    interpret instruction = case instruction of
      BernoulliF p next    -> next =<< MWC.bernoulli p
      UniformF next        -> next =<< MWC.uniform
      CategoricalF ws next -> next =<< MWC.categorical ws
-- | A single 'UniformF' instruction, lifted into the free monad over the
-- base functor @'TF.FreeF' 'ModelF' a@ so that it can be sequenced inside
-- a @futu@ coalgebra.
uniform :: Free (TF.FreeF ModelF a) Double
uniform = liftF . TF.Free $ UniformF id
-- | A single 'CategoricalF' instruction over the given weights, lifted
-- into the free monad over the base functor @'TF.FreeF' 'ModelF' a@ so
-- that it can be sequenced inside a @futu@ coalgebra.
categorical :: [Double] -> Free (TF.FreeF ModelF a) Int
categorical weights = liftF . TF.Free $ CategoricalF weights id
-- utilities
-- | Shorthand for 'fromIntegral'.
--
-- The explicit signature keeps the binding polymorphic; without it the
-- monomorphism restriction would pin 'fi' to a single pair of types.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
-- | Chinese Restaurant Process, classic representation (integer
-- partitions), using a Bernoulli draw followed by a conditional draw.
--
-- @crp n a@ seats customers @0 .. n - 1@ one at a time and returns the
-- tables as lists of customer indices (most recently seated first).
-- A non-positive @n@ yields the empty partition.
--
-- With @l@ customers already seated, customer @l@ opens a new table with
-- probability @a / (a + l)@; otherwise an existing table is chosen with
-- probability proportional to its current size.
crp :: Int -> Double -> Model [[Int]]
crp n a
    | n <= 0    = Pure mempty
    | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    -- State: (number of customers seated so far, tables keyed 0 .. size - 1).
    coalg (l, ts)
      | l >= n    = TF.Pure (F.toList ts)
      | otherwise =
          let p = a / (a + fi l)
              k = succ l
          in  TF.Free . BernoulliF p $ \accept ->
                if accept
                  then let s = IMS.size ts
                       in  pure (k, IMS.insert s [l] ts)
                  else do
                    res <- seat l ts
                    pure (k, res)

    -- Choose an existing table by inverting the cumulative distribution
    -- of table sizes at a uniform draw, then prepend customer @l@ to it.
    seat l ts = do
      u <- uniform
      let ps  = fmap ((/ fi l) . fi . length) ts
          cps = scanl1 (+) (F.toList ps)
          idx = case L.findIndex (>= u) cps of
            Just i  -> i
            -- Floating-point rounding can leave the final cumulative
            -- sum marginally below the draw; that mass belongs to the
            -- last table, so fall back to it instead of crashing.
            Nothing -> IMS.size ts - 1
      pure (IMS.adjust ((:) l) idx ts)
-- | Chinese Restaurant Process built with @ana@: every step issues exactly
-- one categorical draw over "open a new table" (index 0) followed by the
-- existing tables (indices 1 and up).
--
-- @crp0 n a@ returns the tables as lists of customer indices; a
-- non-positive @n@ yields the empty partition.
crp0 :: Int -> Double -> Model [[Int]]
crp0 n a
    | n <= 0    = Pure mempty
    | otherwise = ana step (1, IMS.singleton 0 [0])
  where
    -- State: (number of customers seated so far, tables keyed 0 .. size - 1).
    step (seated, tables)
      | seated >= n = TF.Pure (F.toList tables)
      | otherwise   = TF.Free (CategoricalF weights advance)
      where
        denom   = a + fi seated
        -- Head weight: new table; remaining weights: one per existing table.
        weights = a / denom : [fi (length t) / denom | t <- F.toList tables]

        advance i = (succ seated, place i)

        -- Index 0 opens a fresh table; index i joins existing table i - 1.
        place i
          | i == 0    = IMS.insert (IMS.size tables) [seated] tables
          | otherwise = IMS.adjust (seated :) (i - 1) tables
-- | Chinese Restaurant Process built with @futu@: every step issues exactly
-- one categorical draw over "open a new table" (index 0) followed by the
-- existing tables (indices 1 and up), and returns the next seed via 'pure'.
--
-- @crp1 n a@ returns the tables as lists of customer indices; a
-- non-positive @n@ yields the empty partition.
crp1 :: Int -> Double -> Model [[Int]]
crp1 n a
    | n <= 0    = Pure mempty
    | otherwise = futu grow (1, IMS.singleton 0 [0])
  where
    -- State: (number of customers seated so far, tables keyed 0 .. size - 1).
    grow (seated, tables)
      | seated >= n = TF.Pure (F.toList tables)
      | otherwise   = TF.Free . CategoricalF weights $ pure . advance
      where
        total   = a + fi seated
        -- Head weight: new table; remaining weights: one per existing table.
        weights = a / total : [fi (length t) / total | t <- F.toList tables]

        -- Index 0 opens a fresh table; index i joins existing table i - 1.
        advance i
          | i == 0    = (succ seated, IMS.insert (IMS.size tables) [seated] tables)
          | otherwise = (succ seated, IMS.adjust (seated :) (i - 1) tables)
-- | Chinese Restaurant Process built with @futu@ and two draws per step:
-- a Bernoulli draw decides whether to open a new table and, if not, a
-- conditional categorical draw picks one of the existing tables.
--
-- @crp2 n a@ returns the tables as lists of customer indices (most
-- recently seated first); a non-positive @n@ yields the empty partition.
crp2 :: Int -> Double -> Model [[Int]]
crp2 n a
    | n <= 0    = Pure mempty
    | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    -- State: (number of customers seated so far, tables keyed 0 .. size - 1).
    coalg (l, ts)
      | l >= n    = TF.Pure (F.toList ts)
      | otherwise =
          let k = succ l
              p = a / (a + fi l)
          -- Written with '.' and '$' (as in 'crp') so the continuation
          -- needs no trailing parentheses to close.
          in  TF.Free . BernoulliF p $ \accept ->
                if accept
                  then let s = IMS.size ts
                       in  pure (k, IMS.insert s [l] ts)
                  else do
                    -- Weight of each existing table: its share of the
                    -- l customers seated so far.
                    let ps = fmap ((/ fi l) . fi . length) ts
                    i <- categorical (F.toList ps)
                    pure (k, IMS.adjust ((:) l) i ts)
jtobin commented Sep 3, 2024

Here's a one-liner to get a Nix shell with GHCi and the required libraries:

$ nix-shell -p "haskellPackages.ghcWithPackages (pkgs: [pkgs.recursion-schemes pkgs.mwc-probability])"

