Instantly share code, notes, and snippets.

jtobin/CRP.hs

Created September 2, 2024 13:00
Show Gist options
• Save jtobin/8da5c8b46297e4868c25082d74bd1ebf to your computer and use it in GitHub Desktop.
Chinese Restaurant Process in an embedded language with recursion schemes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
 {-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}

-- Chinese Restaurant Process (CRP) samplers written as programs in a small
-- embedded probabilistic language ('ModelF' / 'Model'), unfolded with the
-- recursion-schemes combinators 'ana' and 'futu'.

import Control.Monad.Free
import qualified Control.Monad.Trans.Free as TF
import qualified Data.Foldable as F
import qualified Data.List as L
import Data.Maybe (fromMaybe)
import Data.Functor.Foldable
import qualified System.Random.MWC.Probability as MWC
import qualified Data.IntMap.Strict as IMS

-- probabilistic instruction set, program definitions

-- | Instruction set of the embedded language.  Each constructor is one random
-- primitive paired with a continuation that consumes the primitive's result.
data ModelF r =
    BernoulliF Double (Bool -> r)     -- ^ coin flip; interpreted by 'MWC.bernoulli'
  | UniformF (Double -> r)            -- ^ uniform draw; interpreted by 'MWC.uniform'
  | CategoricalF [Double] (Int -> r)  -- ^ index draw; interpreted by 'MWC.categorical'
  deriving Functor

-- | Programs in the embedded language: the free monad over 'ModelF'.
type Model = Free ModelF

-- | Sampling interpreter: run each instruction with mwc-probability and feed
-- the sampled value to the instruction's continuation.
prob :: Model a -> MWC.Prob IO a
prob = iterM $ \case
  BernoulliF p f    -> MWC.bernoulli p >>= f
  UniformF f        -> MWC.uniform >>= f
  CategoricalF ps f -> MWC.categorical ps >>= f

-- | A 'UniformF' instruction lifted into the layered 'TF.FreeF' functor, so it
-- can be sequenced inside the multi-step seed returned by a 'futu' coalgebra.
uniform :: Free (TF.FreeF ModelF a) Double
uniform = liftF (TF.Free (UniformF id))

-- | A 'CategoricalF' instruction over the given weights, lifted like 'uniform'.
categorical :: [Double] -> Free (TF.FreeF ModelF a) Int
categorical ps = liftF (TF.Free (CategoricalF ps id))

-- utilities

-- | Shorthand for 'fromIntegral'.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral

-- classic representation (integer partitions), conditional draw

-- | @crp n a@ seats customers @0 .. n-1@ with concentration @a@ and returns
-- the resulting tables, each a list of customer ids.  Customer @l@ opens a new
-- table with probability @a / (a + l)@; otherwise a 'uniform' draw picks an
-- existing table in proportion to its occupancy.  Returns no tables when
-- @n <= 0@.
crp :: Int -> Double -> Model [[Int]]
crp n a
  | n <= 0    = Pure mempty
  | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    -- seed: (next customer id, tables); customer 0 starts alone at table 0
    coalg (l, ts)
      | l >= n    = TF.Pure (F.toList ts)
      | otherwise =
          let p = a / (a + fi l)
              k = succ l
          in  TF.Free . BernoulliF p $ \accept ->
                if accept
                  then
                    let s = IMS.size ts
                    in  pure (k, IMS.insert s [l] ts)
                  else do
                    res <- seat l ts
                    pure (k, res)

    -- Add customer l to an existing table, chosen by inverse-CDF sampling
    -- against the tables' occupancy counts (customers 0 .. l-1, so they sum
    -- to l).
    seat l ts = do
      u <- uniform
      let -- Running occupancy totals in key order.  These are exact integer
          -- sums: adding up l floating-point fractions size/l can land just
          -- below 1.0, leaving a draw near 1 with no matching table.
          cumulative = scanl1 (+) (fmap length (F.toList ts))
          target     = u * fi l
          -- Tables are only ever inserted at key (IMS.size ts), so keys are
          -- 0 .. size-1 and a list position equals its key.  Falling back to
          -- the last table keeps this total instead of calling 'error'.
          lastIdx    = IMS.size ts - 1
          idx        = fromMaybe lastIdx (L.findIndex ((>= target) . fi) cumulative)
      pure (IMS.adjust ((:) l) idx ts)

-- using a single unconditional draw, ana

-- | Same process as 'crp', but each step is a single 'CategoricalF' draw.
-- Index 0 means "open a new table"; index @i + 1@ means "join table @i@".
-- Each step emits exactly one instruction, so plain 'ana' suffices.
crp0 :: Int -> Double -> Model [[Int]]
crp0 n a
  | n <= 0    = Pure mempty
  | otherwise = ana coalg (1, IMS.singleton 0 [0])
  where
    coalg (l, ts)
      | l >= n    = TF.Pure (F.toList ts)
      | otherwise =
          let k  = succ l
              z  = a + fi l
              ph = a / z
              pt = fmap ((/ z) . fi . length) ts
              ps = ph : F.toList pt
          in  TF.Free . CategoricalF ps $ \i ->
                if i == 0
                  then let s = IMS.size ts in (k, IMS.insert s [l] ts)
                  else (k, IMS.adjust ((:) l) (i - 1) ts)

-- using a single unconditional draw, futu

-- | Identical to 'crp0', but written as a 'futu' coalgebra whose
-- continuations return their next seed with 'pure'.
crp1 :: Int -> Double -> Model [[Int]]
crp1 n a
  | n <= 0    = Pure mempty
  | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    coalg (l, ts)
      | l >= n    = TF.Pure (F.toList ts)
      | otherwise =
          let k  = succ l
              z  = a + fi l
              ph = a / z
              pt = fmap ((/ z) . fi . length) ts
              ps = ph : F.toList pt
          in  TF.Free . CategoricalF ps $ \i ->
                if i == 0
                  then let s = IMS.size ts in pure (k, IMS.insert s [l] ts)
                  else pure (k, IMS.adjust ((:) l) (i - 1) ts)

-- using two draws, conditional categorical

-- | Same process as 'crp', but when no new table is opened, the existing
-- table is chosen with a second 'categorical' draw over occupancy / l.
crp2 :: Int -> Double -> Model [[Int]]
crp2 n a
  | n <= 0    = Pure mempty
  | otherwise = futu coalg (1, IMS.singleton 0 [0])
  where
    coalg (l, ts)
      | l >= n    = TF.Pure (F.toList ts)
      | otherwise =
          let k = succ l
              p = a / (a + fi l)
          in  TF.Free . BernoulliF p $ \accept ->
                if accept
                  then let s = IMS.size ts in pure (k, IMS.insert s [l] ts)
                  else do
                    let ps = fmap ((/ fi l) . fi . length) ts
                    i <- categorical (F.toList ps)
                    pure (k, IMS.adjust ((:) l) i ts)

jtobin commented Sep 3, 2024

Here's a one-liner to get a Nix shell with GHCi and the required libraries:

```
$ nix-shell -p "haskellPackages.ghcWithPackages (pkgs: [pkgs.recursion-schemes pkgs.mwc-probability])"
```