Chinese Restaurant Process in an embedded language with recursion schemes
 {-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}

import qualified Data.Foldable as F
import qualified Data.List as L

import Control.Monad.Free
import qualified Control.Monad.Trans.Free as TF
import Data.Functor.Foldable
import qualified Data.IntMap.Strict as IMS
import qualified System.Random.MWC.Probability as MWC

-- probabilistic instruction set, program definitions

-- | Instruction set of the embedded probabilistic language.  Each
-- constructor carries the parameters of a primitive draw together with a
-- continuation that consumes the sampled value.
data ModelF r
  = BernoulliF Double (Bool -> r)
  | UniformF (Double -> r)
  | CategoricalF [Double] (Int -> r)
  deriving Functor

-- | A probabilistic program: the free monad over 'ModelF'.
type Model = Free ModelF

-- | Interpret a 'Model' as a sampler by mapping every instruction onto the
-- mwc-probability primitive of the same name and feeding the drawn value to
-- the instruction's continuation.
prob :: Model a -> MWC.Prob IO a
prob = iterM $ \case
  BernoulliF p f -> MWC.bernoulli p >>= f
  UniformF f -> MWC.uniform >>= f
  CategoricalF ps f -> MWC.categorical ps >>= f

-- | A uniform draw, lifted into the free monad over the base functor of
-- 'Model' so that it can be sequenced inside a 'futu' coalgebra.
uniform :: Free (TF.FreeF ModelF a) Double
uniform = liftF (TF.Free (UniformF id))

-- | A categorical draw over the given weights, lifted the same way as
-- 'uniform'.
categorical :: [Double] -> Free (TF.FreeF ModelF a) Int
categorical ps = liftF (TF.Free (CategoricalF ps id))

-- utilities

-- | Shorthand for converting counts and customer indices to 'Double'.
fi :: Int -> Double
fi = fromIntegral

-- classic representation (integer partitions), conditional draw

-- | Chinese Restaurant Process over @n@ customers with concentration @a@.
--
-- Customer 0 starts alone at table 0.  Each later customer @l@ opens a new
-- table with probability @a / (a + l)@ (one Bernoulli draw); otherwise it
-- joins an existing table chosen by inverting the cumulative occupancy
-- weights against a uniform draw.
--
-- The result lists the tables in order of creation; each table holds the
-- indices of its customers, most recently seated first.  For @n <= 0@ the
-- result is the empty list.
crp :: Int -> Double -> Model [[Int]]
crp n a
  | n <= 0 = Pure mempty
  | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    -- Seed: (index of the next customer, tables keyed by creation order).
    coalg (l, ts)
      | l >= n = TF.Pure (F.toList ts)
      | otherwise =
          let p = a / (a + fi l)
              k = succ l
          in  TF.Free . BernoulliF p $ \accept ->
                if accept
                  then pure (k, IMS.insert (IMS.size ts) [l] ts)
                  else do
                    res <- seat l ts
                    pure (k, res)

    -- Seat customer @l@ at an existing table, chosen with probability
    -- proportional to the table's current occupancy.
    seat l ts = do
      u <- uniform
      let ps = fmap ((/ fi l) . fi . length) ts
          cps = scanl1 (+) (F.toList ps)
          idx =
            case L.findIndex (>= u) cps of
              Just i -> i
              -- The weights sum to one only up to floating-point rounding;
              -- if the final cumulative weight falls short of the draw,
              -- use the last table instead of crashing.  'ts' is never
              -- empty here, so this is always a valid key.
              Nothing -> IMS.size ts - 1
      pure (IMS.adjust ((:) l) idx ts)

-- using a single unconditional draw, ana

-- | Same process as 'crp', but each customer is placed with a single
-- categorical draw: outcome 0 opens a new table, outcome @i > 0@ joins the
-- existing table with key @i - 1@.  Built with 'ana', so each step issues
-- exactly one instruction.
crp0 :: Int -> Double -> Model [[Int]]
crp0 n a
  | n <= 0 = Pure mempty
  | otherwise = ana coalg (1, IMS.singleton 0 [0])
  where
    coalg (l, ts)
      | l >= n = TF.Pure (F.toList ts)
      | otherwise =
          let k = succ l
              ph = a / (a + fi l)
              pt = fmap ((/ (a + fi l)) . fi . length) ts
              ps = ph : F.toList pt
          in  TF.Free
                ( CategoricalF ps $ \i ->
                    if i == 0
                      then (k, IMS.insert (IMS.size ts) [l] ts)
                      else (k, IMS.adjust ((:) l) (i - 1) ts)
                )

-- using a single unconditional draw, futu

-- | Same single categorical draw as 'crp0', expressed with 'futu'; each
-- continuation returns its next seed via 'pure'.
crp1 :: Int -> Double -> Model [[Int]]
crp1 n a
  | n <= 0 = Pure mempty
  | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    coalg (l, ts)
      | l >= n = TF.Pure (F.toList ts)
      | otherwise =
          let k = succ l
              ph = a / (a + fi l)
              pt = fmap ((/ (a + fi l)) . fi . length) ts
              ps = ph : F.toList pt
          in  TF.Free
                ( CategoricalF ps $ \i ->
                    if i == 0
                      then pure (k, IMS.insert (IMS.size ts) [l] ts)
                      else pure (k, IMS.adjust ((:) l) (i - 1) ts)
                )

-- using two draws, conditional categorical

-- | Two-stage variant: a Bernoulli draw decides whether customer @l@ opens a
-- new table; if not, a categorical draw over the occupancy weights picks the
-- existing table to join.
crp2 :: Int -> Double -> Model [[Int]]
crp2 n a
  | n <= 0 = Pure mempty
  | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    coalg (l, ts)
      | l >= n = TF.Pure (F.toList ts)
      | otherwise =
          let k = succ l
              p = a / (a + fi l)
          in  TF.Free
                ( BernoulliF p $ \accept ->
                    if accept
                      then pure (k, IMS.insert (IMS.size ts) [l] ts)
                      else do
                        let ps = fmap ((/ fi l) . fi . length) ts
                        i <- categorical (F.toList ps)
                        pure (k, IMS.adjust ((:) l) i ts)
                )

jtobin commented Sep 3, 2024

Here's a one-liner to get a Nix shell with GHCi and the required libraries:

```
$ nix-shell -p "haskellPackages.ghcWithPackages (pkgs: [pkgs.recursion-schemes pkgs.mwc-probability])"
```
