Skip to content

Instantly share code, notes, and snippets.

@timcharper
Last active August 29, 2015 14:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save timcharper/3277b10ad866676284e2 to your computer and use it in GitHub Desktop.
Save timcharper/3277b10ad866676284e2 to your computer and use it in GitHub Desktop.
Functional Scala Implementation of Bayes Graph Solver (https://www.youtube.com/watch?v=pPTLK5hFGnQ)
#!/bin/sh
exec scala "$0" "$@"
!#
/** The factors (random variables) a [[Node]] can branch on.
  *
  * Each value is created with an explicit name so that `toString` and
  * `withName` use it; declaration order fixes the ids (Econ = 0, Stock = 1).
  */
object Factor extends Enumeration {
  /** State of the economy; the example graph uses outcomes "grow" / "slow". */
  val Econ: Value = Value("Econ")

  /** Direction of the stock market; the example graph uses outcomes "up" / "down". */
  val Stock: Value = Value("Stock")
}
import Factor._
/** A node in a discrete Bayes tree.
  *
  * A non-leaf node branches on one `factor`; `outcomes` maps each outcome name
  * of that factor to the sub-tree reached when the factor takes that value.
  * `probability` is the weight of the edge leading into this node, so the
  * product of `probability` along a root-to-leaf path is the joint probability
  * of the outcomes on that path (the root uses 1.0). A leaf has
  * `factor == None` and no outcomes.
  *
  * @param probability weight of the edge from the parent to this node
  * @param factor      the factor this node branches on, or None for a leaf
  * @param outcomes    child sub-trees keyed by outcome name
  */
class Node(
  val probability: Double,
  val factor: Option[Factor.Value],
  val outcomes: Map[String, Node]
) {

  /** Computes P(queryFactor = queryOutcome | given) over the tree rooted here.
    *
    * The result is P(query and given) / P(given), where each joint probability
    * is the summed path weight of every root-to-leaf path consistent with its
    * constraints. Normalizing once, at the top, keeps the answer correct no
    * matter where the queried and observed factors sit in the tree; factors
    * that are neither queried nor observed are marginalized out, weighted by
    * their branch probabilities.
    *
    * @param queryFactor  the factor whose outcome is asked about
    * @param queryOutcome the outcome of `queryFactor` whose probability is wanted
    * @param given        observed evidence: factor -> observed outcome name
    * @return the conditional probability; `NaN` when the evidence itself has
    *         probability 0 (the conditional is undefined). A factor that never
    *         appears in the tree imposes no constraint, so querying it yields 1.0.
    */
  def calcProbability(queryFactor: Factor.Value, queryOutcome: String, given: Map[Factor.Value, String]): Double = {
    val evidenceProbability = jointProbability(given)
    // If the query factor is itself observed with a different outcome, the query
    // contradicts the evidence; `updated` would otherwise silently overwrite it.
    val contradictsEvidence = given.get(queryFactor).exists(_ != queryOutcome)
    val queryAndEvidenceProbability =
      if (contradictsEvidence) 0.0
      else jointProbability(given.updated(queryFactor, queryOutcome))
    queryAndEvidenceProbability / evidenceProbability
  }

  /** Sum, over every path from this node down to a leaf that agrees with
    * `constraints`, of the product of node probabilities along that path.
    *
    * A constrained factor follows only its fixed branch (an outcome name with no
    * matching child contributes 0); an unconstrained factor sums over all of its
    * branches. This node's own `probability` always weights the result.
    */
  private def jointProbability(constraints: Map[Factor.Value, String]): Double = {
    val below = factor match {
      case None => 1.0 // leaf: the path ends here
      case Some(f) =>
        constraints.get(f) match {
          case Some(outcome) => outcomes.get(outcome).fold(0.0)(_.jointProbability(constraints))
          case None          => outcomes.valuesIterator.map(_.jointProbability(constraints)).sum
        }
    }
    probability * below
  }
}
// Factory methods for constructing nodes. Makes the specification of a graph nice and concise.
/** Factory methods that keep graph specifications short and readable. */
object Node {

  /** Builds a leaf: a terminal node with the given edge weight and no factor. */
  def apply(probability: Double): Node = {
    val noChildren = Map.empty[String, Node]
    new Node(probability, None, noChildren)
  }

  /** Builds an inner node branching on `factor`, with one child per
    * `(outcome, child)` pair. If an outcome name repeats, the last pair wins.
    */
  def apply(probability: Double, factor: Factor.Value, outcomes: (String, Node)*): Node = {
    val children: Map[String, Node] = Map(outcomes: _*)
    new Node(probability, Some(factor), children)
  }
}
// Example tree: the economy grows (weight 0.7) or slows (weight 0.3); under each
// branch the stock market goes up or down with branch-specific weights. The root
// weight is 1.0 because every observation passes through it.
val whenGrowing = Node(0.7, Stock, "up" -> Node(0.8), "down" -> Node(0.2))
val whenSlowing = Node(0.3, Stock, "up" -> Node(0.3), "down" -> Node(0.7))
val graph = Node(1.0, Econ, "grow" -> whenGrowing, "slow" -> whenSlowing)

// Probability that the economy is growing, given that stocks went up.
val observed = Map(Stock -> "up")
println(graph.calcProbability(Econ, "grow", given = observed))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment