Skip to content

Instantly share code, notes, and snippets.

@sortega
Created August 18, 2015 21:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sortega/10f67893c5b0861b5d23 to your computer and use it in GitHub Desktop.
Another incarnation of the discrete probability monad
package distro
import scala.annotation.tailrec
import scala.util.Random
/** A discrete probability distribution over events of type `A`. */
trait Distro[A] {

  /** The events that may occur (those with non-zero probability). */
  def eventSet: Set[A]

  /** Probability assigned to `event`; 0 for events outside the event set. */
  def probOf(event: A): BigDecimal

  def map[B](f: A => B): Distro[B]

  def flatMap[B](f: A => Distro[B]): Distro[B]

  /** Conditions the distribution on the predicate `f`. */
  def filter(f: A => Boolean): Distro[A]

  /** Human-readable rendering with events listed in `org` order. */
  def format(implicit org: Ordering[A]): String

  /** Expected value of `f` under this distribution. */
  def expected(f: (A) => BigDecimal): BigDecimal

  /** Mean of the distribution, available when events are numeric.
    * Events are converted through `Double`, so very large values may lose precision. */
  def mean(implicit num: Numeric[A]): BigDecimal =
    expected(event => BigDecimal(num.toDouble(event)))

  /** Variance: the expected squared deviation from the mean. */
  def variance(implicit num: Numeric[A]): BigDecimal = {
    val mu = mean
    expected { event =>
      val deviation = BigDecimal(num.toDouble(event)) - mu
      deviation * deviation
    }
  }

  /** Draws one event at random according to this distribution. */
  def nextValue(random: Random): A
}
object Distro {

  /** Tolerance used when checking that a distribution's probabilities sum to 1. */
  val Delta = BigDecimal("0.00000001")

  /** The empty distribution: no event can ever occur. */
  def impossible[A]: Impossible[A] = Impossible()

  /** Degenerate distribution yielding `event` with probability 1. */
  def always[A](event: A) = uniform(event)

  /** Uniform distribution over the given events (duplicates are collapsed). */
  def uniform[A](events: A*): Distro[A] = uniform(events.toSet)

  /** Uniform distribution over a set of events; impossible when the set is empty. */
  def uniform[A](events: Set[A]): Distro[A] =
    normalize(events.map(event => event -> BigDecimal(1)).toMap)

  /** Uniform distribution over the faces 1..sides of a fair die. */
  def fairDice(sides: Int): Distro[Int] = uniform((1 to sides).toSet)

  /** Distribution with an empty event set. */
  case class Impossible[A]() extends Distro[A] {
    override def eventSet: Set[A] = Set.empty
    override def probOf(event: A): BigDecimal = 0
    override def flatMap[B](f: (A) => Distro[B]): Distro[B] = Impossible()
    override def filter(f: (A) => Boolean): Distro[A] = Impossible()
    override def format(implicit org: Ordering[A]): String = "Distro.Impossible"
    override def map[B](f: (A) => B): Distro[B] = Impossible()
    // Expectation and sampling are undefined for the empty distribution;
    // `???` throws scala.NotImplementedError, kept for interface compatibility.
    override def expected(f: (A) => BigDecimal): BigDecimal = ???
    override def nextValue(random: Random): A = ???
  }

  /** Distribution given by an explicit event-to-probability map.
    * Probabilities must sum to 1 within `Delta`. */
  case class Discrete[A](probs: Map[A, BigDecimal]) extends Distro[A] {
    require((probs.values.sum - BigDecimal(1)).abs < Delta, s"Not adding up to 1: $probs")

    override def eventSet: Set[A] = probs.keySet

    override def probOf(event: A): BigDecimal = probs.getOrElse(event, 0)

    /** Maps events, merging the probabilities of events that collide under `f`. */
    override def map[B](f: A => B): Distro[B] = aggregateEvents(for {
      (event, prob) <- probs.toSeq
    } yield f(event) -> prob)

    /** Chains a dependent distribution; impossible branches are skipped and the
      * remaining mass is renormalized by `aggregateEvents`. */
    override def flatMap[B](f: A => Distro[B]): Distro[B] = aggregateEvents(for {
      (initialEvent, priorProb) <- probs.toSeq
      nextDistro = f(initialEvent) if nextDistro.eventSet.nonEmpty
      nextEvent <- nextDistro.eventSet
    } yield nextEvent -> (priorProb * nextDistro.probOf(nextEvent)))

    /** Conditions on `f`, renormalizing the surviving probabilities.
      * (Strict filtering; the lazy `filterKeys` view would re-evaluate `f` per access.) */
    override def filter(f: A => Boolean): Distro[A] =
      aggregateEvents(probs.toSeq.filter { case (event, _) => f(event) })

    override def format(implicit org: Ordering[A]): String =
      probs.keys.toSeq.sorted.map { event =>
        s" $event,\t${probs(event)},"
      }.mkString("Distro.Discrete(\n", "\n", "\n)")

    // `map` rather than `collect`: the pattern is total, no filtering happens.
    override def expected(f: (A) => BigDecimal): BigDecimal = probs.map {
      case (event, prob) => f(event) * prob
    }.sum

    /** Draws a random event by walking the cumulative probabilities. */
    override def nextValue(random: Random): A = {
      @tailrec
      def choose(deviate: BigDecimal, entries: Seq[(A, BigDecimal)]): A =
        // The last entry also absorbs deviates beyond the accumulated mass:
        // probabilities only sum to 1 within Delta, so without this guard a
        // deviate landing in the rounding gap would run off the end of the list.
        if (entries.tail.isEmpty || deviate <= entries.head._2) entries.head._1
        else choose(deviate - entries.head._2, entries.tail)
      choose(BigDecimal(random.nextDouble()), probs.toSeq)
    }
  }

  /** A hypothesis assigns a likelihood to each possible observation. */
  trait Hypothesis[A] {
    def probOf(event: A): BigDecimal
  }

  /** Bayesian updating for distributions over hypotheses. */
  implicit class BayesianDistro[A](val distro: Distro[Hypothesis[A]]) extends AnyVal {

    /** Posterior after one observation: prior times likelihood, renormalized. */
    def posterior(observation: A): Distro[Hypothesis[A]] = aggregateEvents(for {
      hypothesis <- distro.eventSet.toSeq
      priorProb = distro.probOf(hypothesis)
      modelProb = hypothesis.probOf(observation)
    } yield hypothesis -> (priorProb * modelProb))

    /** Posterior after a sequence of independent observations. */
    def posterior(observations: A*): Distro[Hypothesis[A]] =
      observations.foldLeft(distro)(_.posterior(_))
  }

  /** Scales weights to sum to 1, dropping zero-weight events; `Impossible` when
    * the total weight is zero. Built strictly (no `mapValues` view) so the
    * divisions are performed once, not on every map access. */
  private def normalize[T](probs: Map[T, BigDecimal]): Distro[T] = {
    val totalWeight = probs.values.sum
    if (totalWeight > 0) Discrete(probs.collect {
      case (event, weight) if weight > 0 => event -> (weight / totalWeight)
    })
    else Impossible()
  }

  /** Sums the weights of duplicate events, then normalizes. */
  private def aggregateEvents[T](weightedEvents: Seq[(T, BigDecimal)]): Distro[T] =
    normalize(weightedEvents.groupBy(_._1).map {
      case (event, events) => event -> events.map(_._2).sum
    })
}
package distro
import scala.util.Random
import org.scalatest.{FlatSpec, ShouldMatchers}
/** Behavioral spec for the discrete probability monad. */
class DistroTest extends FlatSpec with ShouldMatchers {

  sealed trait Toss
  case object Heads extends Toss
  case object Tails extends Toss

  val coinToss = Distro.uniform(Tails, Heads)
  val diceRoll = Distro.fairDice(6)

  "A discrete prob distribution" should "be constructed uniformly" in {
    coinToss.probOf(Tails) shouldBe 0.5
    coinToss.probOf(Heads) shouldBe 0.5
  }

  it should "map events" in {
    coinToss.map(_.toString.head) shouldBe Distro.uniform('H', 'T')
    diceRoll.map(_ % 2 == 0) shouldBe Distro.uniform(true, false)
  }

  it should "flatmap events" in {
    // Sum of two independent rolls, written as an explicit flatMap/map chain.
    val sumOfTwoRolls = diceRoll.flatMap(first => diceRoll.map(second => first + second))
    sumOfTwoRolls.probOf(7) shouldBe BigDecimal(1) / 6

    // Ludo: a 6 grants an extra roll, but three 6s in a row score nothing.
    val ludoSteps: Distro[Int] =
      diceRoll.flatMap { first =>
        diceRoll.flatMap { second =>
          diceRoll.map { third =>
            if (first != 6) first
            else if (second != 6) 6 + second
            else if (third != 6) 12 + third
            else 0
          }
        }
      }
    ludoSteps.probOf(0) should be < BigDecimal(0.005)
  }

  it should "filter events" in {
    // Pairs of rolls where the second strictly exceeds the first.
    val increasingPairs = diceRoll.flatMap { first =>
      diceRoll.filter(second => second > first).map(second => (first, second))
    }
    increasingPairs.probOf(6 -> 1) shouldBe 0
  }

  it should "do bayesian reasoning" in {
    case object RiggedCoin extends Distro.Hypothesis[Toss] {
      override def probOf(event: Toss): BigDecimal = if (event == Heads) 1 else 0
    }
    case object FairCoin extends Distro.Hypothesis[Toss] {
      override def probOf(event: Toss): BigDecimal = 0.5
    }
    val prior = Distro.Discrete[Distro.Hypothesis[Toss]](Map(
      FairCoin -> BigDecimal(0.95),
      RiggedCoin -> BigDecimal(0.05)
    ))
    // A single tails rules out the rigged coin entirely.
    prior.posterior(Tails).probOf(RiggedCoin) shouldBe 0
    // A long run of heads makes the rigged coin the likelier hypothesis.
    prior.posterior(Heads, Heads, Heads, Heads, Heads, Heads, Heads)
      .probOf(RiggedCoin) should be > BigDecimal(0.8)
  }

  it should "generate random values" in {
    val random = new Random()
    val samples = 10000
    val empiricalFrequencies = Seq.fill(samples)(diceRoll.nextValue(random))
      .groupBy(identity)
      .mapValues(_.size.toDouble / samples)
    empiricalFrequencies.values.min shouldBe 0.16 +- 0.1
    empiricalFrequencies.values.max shouldBe 0.16 +- 0.1
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment