Created
August 18, 2015 21:40
-
-
Save sortega/10f67893c5b0861b5d23 to your computer and use it in GitHub Desktop.
Another incarnation of the discrete probability monad
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package distro | |
import scala.annotation.tailrec | |
import scala.util.Random | |
/** A probability distribution over events of type `A`.
  *
  * Probabilities are expressed as `BigDecimal`. The combinators `map`,
  * `flatMap`, `filter` and `withFilter` let a distribution be used in a
  * for-comprehension, including `if` guards.
  *
  * @tparam A the type of the events
  */
trait Distro[A] {

  /** The events this distribution knows about. */
  def eventSet: Set[A]

  /** The probability assigned to `event`. */
  def probOf(event: A): BigDecimal

  /** Distribution of `f(event)`. */
  def map[B](f: A => B): Distro[B]

  /** Chains this distribution with the distribution `f` returns for each event. */
  def flatMap[B](f: A => Distro[B]): Distro[B]

  /** Distribution restricted to the events accepted by `f`. */
  def filter(f: A => Boolean): Distro[A]

  /** Same as [[filter]].
    *
    * For-comprehension guards (`if cond`) are translated to `withFilter`;
    * defining it here avoids depending on the compiler falling back to
    * `filter` when `withFilter` is missing.
    */
  def withFilter(f: A => Boolean): Distro[A] = filter(f)

  /** Human-readable rendering, with events listed in `org` order. */
  def format(implicit org: Ordering[A]): String

  /** Expected value of `f` under this distribution. */
  def expected(f: (A) => BigDecimal): BigDecimal

  /** Expected value of the events themselves.
    *
    * Events are converted through `num.toDouble`, so the result carries
    * `Double` precision.
    */
  def mean(implicit num: Numeric[A]): BigDecimal =
    expected(event => BigDecimal(num.toDouble(event)))

  /** Expected squared deviation from [[mean]]. */
  def variance(implicit num: Numeric[A]): BigDecimal = {
    val mu = mean
    expected { event =>
      // Convert explicitly rather than relying on an implicit view being
      // found for the left-hand operand of `Double - BigDecimal`.
      val diff = BigDecimal(num.toDouble(event)) - mu
      diff * diff
    }
  }

  /** Draws one event using `random` as the source of randomness. */
  def nextValue(random: Random): A
}
/** Constructors and concrete implementations of [[Distro]]. */
object Distro {

  /** Tolerance used when checking that the probabilities of a [[Discrete]] add up to 1. */
  val Delta: BigDecimal = BigDecimal("0.00000001")

  /** The distribution with no possible events. */
  def impossible[A]: Impossible[A] = Impossible()

  /** The distribution that yields `event` with probability 1. */
  def always[A](event: A): Distro[A] = uniform(event)

  /** Uniform distribution over the distinct values in `events`. */
  def uniform[A](events: A*): Distro[A] = uniform(events.toSet)

  /** Uniform distribution over `events`; an empty set yields [[Impossible]]. */
  def uniform[A](events: Set[A]): Distro[A] =
    normalize(events.map(event => event -> BigDecimal(1)).toMap)

  /** Uniform distribution over the integers `1 to sides`. */
  def fairDice(sides: Int): Distro[Int] = uniform((1 to sides).toSet)

  /** Distribution without events: every probability is 0 and every
    * transformation yields another `Impossible`.
    *
    * `expected` and `nextValue` are left unimplemented (`???`) and throw
    * `NotImplementedError` when called.
    */
  case class Impossible[A]() extends Distro[A] {
    override def eventSet: Set[A] = Set.empty
    override def probOf(event: A): BigDecimal = 0
    override def flatMap[B](f: (A) => Distro[B]): Distro[B] = Impossible()
    override def filter(f: (A) => Boolean): Distro[A] = Impossible()
    override def format(implicit org: Ordering[A]): String = "Distro.Impossible"
    override def map[B](f: (A) => B): Distro[B] = Impossible()
    override def expected(f: (A) => BigDecimal): BigDecimal = ???
    override def nextValue(random: Random): A = ???
  }

  /** Distribution backed by an explicit event-to-probability table.
    *
    * @param probs probability of each event; the values must add up to 1
    *              within [[Delta]], otherwise construction fails with
    *              `IllegalArgumentException`
    */
  case class Discrete[A](probs: Map[A, BigDecimal]) extends Distro[A] {
    require((probs.values.sum - BigDecimal(1)).abs < Delta, s"Not adding up to 1: $probs")

    override def eventSet: Set[A] = probs.keySet

    override def probOf(event: A): BigDecimal = probs.getOrElse(event, BigDecimal(0))

    /** Events mapped to the same value have their probabilities added up. */
    override def map[B](f: A => B): Distro[B] = aggregateEvents(for {
      (event, prob) <- probs.toSeq
    } yield f(event) -> prob)

    /** Branches for which `f` returns a distribution without events are
      * dropped, and the remaining weights are renormalized.
      */
    override def flatMap[B](f: A => Distro[B]): Distro[B] = aggregateEvents(for {
      (initialEvent, priorProb) <- probs.toSeq
      nextDistro = f(initialEvent)
      if nextDistro.eventSet.nonEmpty
      nextEvent <- nextDistro.eventSet
    } yield nextEvent -> (priorProb * nextDistro.probOf(nextEvent)))

    /** Keeps the events accepted by `f` and renormalizes their probabilities
      * so they add up to 1 again; yields [[Impossible]] when nothing survives.
      *
      * Because the result is renormalized immediately, a guard nested inside
      * a `flatMap` renormalizes each inner distribution on its own.
      */
    override def filter(f: A => Boolean): Distro[A] =
      // Strict filter: `filterKeys` would return a lazy view.
      aggregateEvents(probs.toSeq.filter { case (event, _) => f(event) })

    override def format(implicit org: Ordering[A]): String =
      probs.keys.toSeq.sorted.map { event =>
        s" $event,\t${probs(event)},"
      }.mkString("Distro.Discrete(\n", "\n", "\n)")

    override def expected(f: (A) => BigDecimal): BigDecimal =
      probs.iterator.map { case (event, prob) => f(event) * prob }.sum

    /** Samples an event by walking the table and subtracting each
      * probability from a uniform deviate taken from `random.nextDouble()`.
      */
    override def nextValue(random: Random): A = {
      @tailrec
      def choose(deviate: BigDecimal, entries: Seq[(A, BigDecimal)]): A = {
        val (event, prob) = entries.head
        val remaining = entries.tail
        // The probabilities may add up to slightly less than 1 (within Delta),
        // so a deviate close to 1 can outlast every entry: the last entry
        // absorbs that gap instead of running off the end of the sequence.
        // The strict comparison keeps zero-probability entries from being
        // picked when the deviate is exactly 0.
        if (remaining.isEmpty || deviate < prob) event
        else choose(deviate - prob, remaining)
      }
      // `probs` cannot be empty here: an empty table fails the `require` above.
      choose(BigDecimal(random.nextDouble()), probs.toSeq)
    }
  }

  /** A model that assigns a probability to each observable event. */
  trait Hypothesis[A] {
    def probOf(event: A): BigDecimal
  }

  /** Bayesian updating for distributions over hypotheses. */
  implicit class BayesianDistro[A](val distro: Distro[Hypothesis[A]]) extends AnyVal {

    /** Reweights every hypothesis by the probability it gives to
      * `observation` and renormalizes.
      */
    def posterior(observation: A): Distro[Hypothesis[A]] = aggregateEvents(for {
      hypothesis <- distro.eventSet.toSeq
      priorProb = distro.probOf(hypothesis)
      modelProb = hypothesis.probOf(observation)
    } yield hypothesis -> (priorProb * modelProb))

    /** Applies the single-observation update once per observation, in order. */
    def posterior(observations: A*): Distro[Hypothesis[A]] =
      observations.foldLeft(distro)(_.posterior(_))
  }

  /** Scales positive weights so they add up to 1, dropping non-positive
    * ones; a non-positive total yields [[Impossible]].
    */
  private def normalize[T](probs: Map[T, BigDecimal]): Distro[T] = {
    val totalWeight = probs.values.sum
    if (totalWeight > 0)
      // Built strictly: `mapValues` would hand Discrete a lazy view that
      // redoes the division on every lookup.
      Discrete(probs.collect {
        case (event, weight) if weight > 0 => event -> (weight / totalWeight)
      })
    else Impossible()
  }

  /** Adds up the weights of repeated events and normalizes the result. */
  private def aggregateEvents[T](weightedEvents: Seq[(T, BigDecimal)]): Distro[T] =
    normalize(weightedEvents.groupBy(_._1).map {
      case (event, occurrences) => event -> occurrences.map(_._2).sum
    })
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package distro | |
import scala.util.Random | |
import org.scalatest.{FlatSpec, ShouldMatchers} | |
/** Behavioral spec for [[Distro]]: construction, the monadic combinators,
  * Bayesian updating and random sampling.
  */
class DistroTest extends FlatSpec with ShouldMatchers {

  sealed trait Toss
  case object Heads extends Toss
  case object Tails extends Toss

  val coinToss = Distro.uniform(Tails, Heads)
  val diceRoll = Distro.fairDice(6)

  "A discrete prob distribution" should "be constructed uniformly" in {
    coinToss.probOf(Tails) shouldBe 0.5
    coinToss.probOf(Heads) shouldBe 0.5
  }

  it should "map events" in {
    coinToss.map(_.toString.head) shouldBe Distro.uniform('H', 'T')
    diceRoll.map(_ % 2 == 0) shouldBe Distro.uniform(true, false)
  }

  it should "flatmap events" in {
    val twoDicesRoll = for {
      roll1 <- diceRoll
      roll2 <- diceRoll
    } yield roll1 + roll2
    twoDicesRoll.probOf(7) shouldBe BigDecimal(1) / 6

    val ludoSteps: Distro[Int] = for {
      roll1 <- diceRoll
      roll2 <- diceRoll
      roll3 <- diceRoll
    } yield (roll1, roll2, roll3) match {
      case (6, 6, 6) => 0
      case (6, 6, other) => 12 + other
      case (6, other, _) => 6 + other
      case (other, _, _) => other
    }
    ludoSteps.probOf(0) should be < BigDecimal(0.005)
  }

  it should "filter events" in {
    val twoDicesRoll = for {
      roll1 <- diceRoll
      roll2 <- diceRoll
      if roll2 > roll1
    } yield (roll1, roll2)
    twoDicesRoll.probOf(6 -> 1) shouldBe 0
  }

  it should "do bayesian reasoning" in {
    case object RiggedCoin extends Distro.Hypothesis[Toss] {
      override def probOf(event: Toss): BigDecimal = if (event == Heads) 1 else 0
    }
    case object FairCoin extends Distro.Hypothesis[Toss] {
      override def probOf(event: Toss): BigDecimal = 0.5
    }
    val initialDistro = Distro.Discrete[Distro.Hypothesis[Toss]](Map(
      FairCoin -> BigDecimal(0.95),
      RiggedCoin -> BigDecimal(0.05)
    ))
    initialDistro.posterior(Tails).probOf(RiggedCoin) shouldBe 0
    initialDistro.posterior(Heads, Heads, Heads, Heads, Heads, Heads, Heads)
      .probOf(RiggedCoin) should be > BigDecimal(0.8)
  }

  it should "generate random values" in {
    // A fixed seed makes the sampled sequence, and so the test, reproducible.
    val random = new Random(42)
    val samples = 10000
    val expectedFrequency = 1.0 / 6
    val empiricalFrequencies = Seq.fill(samples)(diceRoll.nextValue(random))
      .groupBy(identity)
      .map { case (face, rolls) => face -> rolls.size.toDouble / samples }

    // min/max only range over faces that were drawn, so a sampler that never
    // produces some face has to be caught separately.
    empiricalFrequencies.keySet shouldBe Set(1, 2, 3, 4, 5, 6)
    // For 10000 draws the standard deviation of each frequency is about
    // 0.004, so 0.03 is generous while still rejecting a skewed sampler.
    empiricalFrequencies.values.min shouldBe expectedFrequency +- 0.03
    empiricalFrequencies.values.max shouldBe expectedFrequency +- 0.03
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment