Skip to content

Instantly share code, notes, and snippets.

@sortega
Created April 30, 2013 13:38
Show Gist options
  • Save sortega/5488755 to your computer and use it in GitHub Desktop.
Save sortega/5488755 to your computer and use it in GitHub Desktop.
Discrete probability distribution monad in Scala.
import scala.collection.immutable
/** A discrete probability distribution over events of type `T`.
  *
  * The distribution is built from `(event, probability)` pairs. Pairs that
  * share the same event are merged by adding their probabilities, and the
  * merged result is exposed as [[probMap]].
  *
  * `map`, `flatMap`, `filter` and `withFilter` are provided so that a
  * `Distro` can be used in `for` comprehensions, including ones with guards.
  *
  * @param probs the `(event, probability)` pairs; every probability must be
  *              non-negative and together they must sum to 1 within
  *              `Distro.delta`
  * @tparam T the type of the events
  * @throws IllegalArgumentException if a probability is negative or the
  *                                  probabilities do not sum to 1
  */
class Distro[T](probs: Seq[(T, Double)]) {
  require(
    probs.forall { case (_, prob) => prob >= 0.0 },
    "probabilities must be non-negative")
  require(
    (probs.map(_._2).sum - 1.0d).abs < Distro.delta,
    "probabilities must sum to 1.0 within a tolerance of " + Distro.delta)

  /** Probability of each distinct event. */
  val probMap: Map[T, Double] = mergeProbs(probs)

  /** Applies `f` to every event; events that collapse onto the same image
    * have their probabilities added together.
    */
  def map[B](f: T => B): Distro[B] =
    new Distro(for ((event, prob) <- probMap.toSeq) yield (f(event), prob))

  /** Chains a dependent distribution: each outcome of `f(event)` is weighted
    * by the probability of `event` in this distribution.
    */
  def flatMap[B](f: T => Distro[B]): Distro[B] =
    new Distro(
      for {
        (event, prob) <- probMap.toSeq
        (newEvent, newProb) <- f(event).probMap.toSeq
      } yield (newEvent, newProb * prob))

  /** Conditions the distribution on `f`: keeps the events accepted by `f`
    * and rescales their probabilities so that they sum to 1 again.
    *
    * @throws IllegalArgumentException if the accepted events carry no
    *                                  probability, since the result cannot
    *                                  be renormalised
    */
  def filter(f: T => Boolean): Distro[T] = {
    val selectedEvents = probMap.toVector.filter { case (event, _) => f(event) }
    val selectedProbability = selectedEvents.map(_._2).sum
    // Without this guard the division below yields NaN (or the selection is
    // empty) and the constructor fails with an unhelpful message.
    require(
      selectedProbability > 0.0,
      "filter predicate selects no probability mass; cannot renormalise")
    new Distro(selectedEvents.map {
      case (event, prob) => (event, prob / selectedProbability)
    })
  }

  /** Strict alias of [[filter]], present so that guards and pattern
    * generators in `for` comprehensions work on a `Distro`.
    */
  def withFilter(f: T => Boolean): Distro[T] = filter(f)

  override def toString: String = probMap.mkString("{", ",", "}")

  /** Adds up the probabilities of repeated events. Uses the class type
    * parameter directly rather than declaring a shadowing one.
    */
  private def mergeProbs(pairs: Seq[(T, Double)]): Map[T, Double] =
    pairs.foldLeft(Map.empty[T, Double]) {
      case (merged, (event, prob)) =>
        merged.updated(event, merged.getOrElse(event, 0.0) + prob)
    }
}

/** Factory methods and shared settings for [[Distro]]. */
object Distro {
  /** Tolerance allowed when checking that probabilities sum to 1. */
  val delta: Double = 0.001

  /** Builds a distribution from `(event, probability)` pairs. */
  def apply[T](events: (T, Double)*): Distro[T] = new Distro(events)

  /** Builds a distribution from a map of event to probability. */
  def apply[T](events: Map[T, Double]): Distro[T] = new Distro(events.toSeq)

  /** Builds a distribution giving each argument the same probability;
    * an element passed several times accumulates the corresponding share.
    *
    * @throws IllegalArgumentException if no element is given
    */
  def uniform[T](elem: T*): Distro[T] = {
    require(elem.nonEmpty, "uniform distribution needs at least one element")
    val elems = elem.toList
    val prob = 1.0 / elems.length.toDouble
    new Distro(elems.map(e => (e, prob)))
  }
}
// Both gender labels are equally likely.
val d1 = Distro.uniform("male", "female")
// d1: Distro[String] = {male -> 0.5,female -> 0.5}

/** Distribution of "is a mother" given a gender label.
  *
  * "male" always yields `false`; "female" yields `true` with probability 0.1.
  * Any other label is not matched and raises a `MatchError`.
  */
def beingMother(gender: String): Distro[Boolean] =
  gender match {
    case "female" => Distro(true -> 0.1, false -> 0.9)
    case "male" => Distro.uniform(false)
  }

// Chain the gender distribution into the conditional one.
d1.flatMap(gender => beingMother(gender))
// res0: Distro[Boolean] = {false -> 0.95,true -> 0.05}
// Each face from 1 to 6 is equally likely.
val diceRoll = Distro.uniform(1 to 6: _*)
// diceRoll: Distro[Int] = {5 -> 0.16666666666666666,1 -> 0.16666666666666666,6 -> 0.16666666666666666,2 -> 0.16666666666666666,3 -> 0.16666666666666666,4 -> 0.16666666666666666}

// Distribution of getting an even double (both dice equal and even) when rolling two dice.
for {
  first <- diceRoll
  second <- diceRoll
} yield first == second && first % 2 == 0
// res1: Distro[Boolean] = {false -> 0.9166666666666664,true -> 0.08333333333333333}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment