Skip to content

Instantly share code, notes, and snippets.

@madsbuch
Last active February 23, 2021 17:43
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save madsbuch/5a8a1fc9b70621dd93dd70058754b126 to your computer and use it in GitHub Desktop.
Save madsbuch/5a8a1fc9b70621dd93dd70058754b126 to your computer and use it in GitHub Desktop.
{-# LANGUAGE GADTs #-}
{-
The following code is based on experimental code by Aslan Askerov,
which in turn is based on Ramsey and Pfeffer's paper "Stochastic Lambda
Calculus and Monads of Probability Distributions". The implementation of
random n is from Audebaud and Paulin-Mohring's paper, as is the random
walk example.
This gist accompanies the post at http://madsbuch.com/the-probability-monad/
The class hierarchy is as follows:

+-----------------------------------------+
|                                         |
|                  Monad                  |
|                                         |
+--------------------+--------------------+
                     |
                     V
+-----------------------------------------+
|                                         |
|            Probability Monad            |
|                                         |
+------+--------------+-------------+-----+
       |              |             |
       V              V             V
+-------------+  +---------+  +-----------+
|             |  |         |  |           |
| Expectation |  | Support |  | Sampling  |
|    Monad    |  |  Monad  |  |   Monad   |
|             |  |         |  |           |
+-------------+  +---------+  +-----------+
Make sure to have the `random` package installed:
`cabal install --lib random` (older cabal versions: `cabal install random`)
-}
module ProbMonad where
import Data.List
import qualified System.Random as R
import Control.Applicative -- Otherwise you can't do the Applicative instance.
import Control.Monad (liftM, ap)
-- | A probability: a number in the closed interval [0, 1].
type Probability = Double

-- Probability Monad
-- | Monads that can express a weighted choice between two computations.
class Monad d => ProbabilityMonad d where
  -- | @choose p left right@ behaves like @left@ with probability @p@ and
  -- like @right@ with probability @1 - p@.
  choose :: Probability -> d x -> d x -> d x
-- Support Monad
-- | Distributions whose possible outcomes can be enumerated as a list.
class ProbabilityMonad d => SupportMonad d where
  -- | The outcomes the distribution may produce; depending on the instance
  -- the list may contain duplicates or be infinite.
  support :: d x -> [x]
-- Expectation Monad
-- | Distributions under which the expected value of a real-valued
-- function of the outcome can be computed.
class ProbabilityMonad d => ExpMonad d where
  -- | @expectation h dist@ is the expected value of @h@ over @dist@.
  expectation :: (x -> Double) -> d x -> Double
-- Sampling Monad
-- | Distributions from which outcomes can be drawn using an explicit
-- random-number generator.
class ProbabilityMonad d => SamplingMonad d where
  -- | Draw one outcome, returning it together with a generator to use for
  -- subsequent draws.
  sample :: R.RandomGen g => d x -> g -> (x, g)
-- Probability Monad Type
-- | A distribution represented by its expectation functional: given a
-- real-valued function @h@ on outcomes, it returns the expected value of @h@.
newtype PExp a = PExp ((a -> Double) -> Double)

-- PExp needs to be a functor to be a monad.
instance Functor PExp where
  fmap = liftM

-- PExp needs to be an applicative to be a monad. 'pure' is defined here
-- directly (the canonical placement); the Monad instance below inherits
-- 'return = pure' from the class default, instead of the old non-canonical
-- @pure = return@ that -Wnoncanonical-monad-instances flags.
instance Applicative PExp where
  pure x = PExp (\h -> h x)
  (<*>) = ap

-- PExp is a monad: the expectation of @h@ over @d >>= k@ is the expectation
-- over @d@ of "the expectation of @h@ over @k x@".
instance Monad PExp where
  PExp d >>= k = PExp (\h -> d (\x -> apply (k x) h))
    where
      -- Unwrap a PExp to its underlying expectation functional.
      apply (PExp f) = f
-- PExp is a probability monad: a weighted choice mixes the two expectation
-- functionals linearly, weighting the first by p and the second by 1 - p.
instance ProbabilityMonad PExp where
  choose p (PExp left) (PExp right) = PExp mix
    where
      mix h = p * left h + (1 - p) * right h
-- Not easily implemented: a PExp only exposes its expectation functional,
-- so there is no direct way to enumerate its outcomes. Fail loudly with an
-- informative message instead of a bare 'undefined'.
instance SupportMonad PExp where
  support _ =
    error "SupportMonad PExp: support is not implemented for the expectation-based representation"
-- Easily implemented! A PExp already is its expectation functional, so
-- computing an expectation is just applying it to the given function.
instance ExpMonad PExp where
  expectation f (PExp run) = run f
-- Not easily implemented: the expectation functional gives no direct way to
-- draw a single outcome. Fail loudly with an informative message instead
-- of a bare 'undefined'.
instance SamplingMonad PExp where
  sample _ _ =
    error "SamplingMonad PExp: sampling is not implemented for the expectation-based representation"
{-- The general probability monad --}
-- | A probabilistic computation kept as a syntax tree, interpreted later:
-- 'R' returns a value, 'B' is a bind, and 'C' is a weighted choice. In 'B'
-- the intermediate type @a@ does not appear in the result type @P b@.
data P a where
  R :: a -> P a
  B :: P a -> (a -> P b) -> P b -- The reason for GADT
  C :: Probability -> P a -> P a -> P a

-- P needs to be a functor to be a monad.
instance Functor P where
  fmap = liftM

-- P needs to be an applicative to be a monad. 'pure' is defined directly
-- (canonical form); 'return' falls back to the class default 'return = pure'.
instance Applicative P where
  pure = R
  (<*>) = ap

-- P is a monad: bind is recorded as a 'B' node rather than evaluated.
instance Monad P where
  (>>=) = B
-- P is a probability monad: a weighted choice is simply recorded as a 'C'
-- node in the tree.
instance ProbabilityMonad P where
  choose = C
-- Enumerate outcomes by walking the tree. Both branches of a choice are
-- kept regardless of their weight, so the result may contain duplicates.
instance SupportMonad P where
  support dist = case dist of
    R x -> [x]
    B inner k -> concatMap (support . k) (support inner)
    C _ left right -> support left ++ support right
-- Compute expectations by structural recursion over the tree: a bind takes
-- the expectation over its first computation of the expectation of the
-- continuation; a choice weights its branches by p and 1 - p.
instance ExpMonad P where
  expectation h dist = case dist of
    R x -> h x
    B inner k -> expectation (\x -> expectation h (k x)) inner
    C p left right ->
      (p * expectation h left) + ((1 - p) * expectation h right)
-- Sample by running the tree, threading the generator: a choice draws a
-- uniform Double and takes the left branch when it is below p.
instance SamplingMonad P where
  sample dist gen = case dist of
    R x -> (x, gen)
    B inner k ->
      let (x, gen') = sample inner gen
       in sample (k x) gen'
    C p left right ->
      let (u, gen') = R.random gen
       in sample (if u < p then left else right) gen'
{-- Helper functions --}

-- | Indicator function: a 'True' event has probability 1, a 'False' one 0.
prob :: Bool -> Probability
prob True = 1
prob False = 0
-- | Uniform distribution over a non-empty list of outcomes.
--
-- Builds a right-nested chain of 'choose' nodes: with @n@ elements left,
-- the head is picked with probability @1/n@, otherwise we recurse on the
-- remaining @n - 1@ elements.
--
-- Calls 'error' on an empty list, since there is no uniform distribution
-- over zero outcomes.
uniform :: [a] -> P a
uniform [] = error "uniform: cannot build a uniform distribution over an empty list"
uniform xs = go (length xs) xs
  where
    -- n is the length of the remaining list, tracked explicitly so that
    -- 'length' is not recomputed at every level (which was O(n^2)).
    go _ [y] = return y
    go n (y : ys) =
      let p = 1.0 / fromIntegral n
       in choose p (return y) (go (n - 1) ys)
    go _ [] = error "uniform: impossible, non-empty list exhausted early"
-- taking samples
-- | Draw @n@ samples from a distribution.
--
-- The generator is 'R.split' before each draw so that every sample gets its
-- own generator. Each outcome is paired with the generator that 'sample'
-- returned for that draw. A non-positive @n@ yields no samples; previously
-- a negative @n@ never reached the base case and produced an infinite list.
nSamples :: (SamplingMonad m, R.RandomGen g) => Int -> m a -> g -> [(a, g)]
nSamples n dist rGen
  | n <= 0 = []
  | otherwise =
      let (g1, g2) = R.split rGen
       in sample dist g1 : nSamples (n - 1) dist g2
{-- Examples --}

-- | The six faces of a die; 'Enum' allows ranges such as @[One .. Six]@.
data Dice
  = One
  | Two
  | Three
  | Four
  | Five
  | Six
  deriving (Eq, Ord, Show, Read, Enum)
-- | Simple example of support: every face of a fair die.
example01a :: [Dice]
example01a = support fairDie
  where
    fairDie :: P Dice
    fairDie = uniform [One .. Six]

-- | Simple example of expectation: the probability of rolling a Six,
-- computed as the expected value of its indicator.
example01b :: Double
example01b = expectation isSix fairDie
  where
    fairDie :: P Dice
    fairDie = uniform [One .. Six]
    isSix face = prob (face == Six)

-- | Simple example of sampling: ten rolls drawn from a generator seeded
-- with 42 (only the outcomes are kept, the generators are dropped).
example01c :: [Dice]
example01c = map fst (nSamples 10 fairDie (R.mkStdGen 42))
  where
    fairDie :: P Dice
    fairDie = uniform [One .. Six]
{- Using prior distributions -}

-- | Support of a die whose Six face is relabelled One, derived from a
-- uniform prior over the six faces.
example02a :: [Dice]
example02a = support relabelled
  where
    relabelled :: P Dice
    relabelled = sixAsOne <$> uniform [One .. Six]
    sixAsOne face = if face == Six then One else face

-- | Probability of a Six on the relabelled die, as an expectation.
example02b :: Double
example02b = expectation isSix relabelled
  where
    relabelled :: P Dice
    relabelled = sixAsOne <$> uniform [One .. Six]
    sixAsOne face = if face == Six then One else face
    isSix face = prob (face == Six)

-- | Ten samples of the relabelled die, generator seeded with 42.
example02c :: [Dice]
example02c = map fst (nSamples 10 relabelled (R.mkStdGen 42))
  where
    relabelled :: P Dice
    relabelled = sixAsOne <$> uniform [One .. Six]
    sixAsOne face = if face == Six then One else face
-- | Random walk: at each step flip a fair coin; on True stop at the current
-- position @x@, otherwise continue from @x + 1@.
--
-- The set of reachable positions is unbounded, so exhaustive
-- interpretations (which need a completely enumerable world) cannot fully
-- evaluate it, whereas each sampled run stops with probability 1.
walk :: Num a => a -> P a
walk x = uniform [True, False] >>= step
  where
    step stop
      | stop = return x
      | otherwise = walk (x + 1)
-- Examples on the unbounded random walk started at 0. Not all of them
-- terminate. Guess why?
--
-- * example03a is an infinite, lazily produced list: printing it never
--   finishes, although a finite prefix such as @take 5 example03a@ does.
-- * example03b diverges: the expectation recursion follows the walk's
--   "continue" branch forever without ever producing a result.
-- * example03c terminates, because each sampled walk stops with
--   probability 1.
example03a :: [Integer]
example03a = support (walk 0)

example03b :: Double
example03b = expectation (\pos -> prob (pos < 5)) (walk 0)

example03c :: [Integer]
example03c = map fst (nSamples 10 (walk 0) (R.mkStdGen 42))
-- | Monte Carlo histogram of the random walk: draw 10000 samples (generator
-- seeded with 42) and count how often each final position occurs, in
-- ascending order of position. 'group' never yields an empty run, so
-- 'head' is safe here.
mc :: [(Integer, Int)]
mc = [(head run, length run) | run <- group (sort positions)]
  where
    positions = map fst (nSamples 10000 (walk 0) (R.mkStdGen 42))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment