Last active
August 29, 2015 14:03
-
-
Save timcharper/3277b10ad866676284e2 to your computer and use it in GitHub Desktop.
Functional Scala Implementation of Bayes Graph Solver (https://www.youtube.com/watch?v=pPTLK5hFGnQ)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/sh | |
exec scala "$0" "$@" | |
!# | |
/**
 * The random variables ("factors") a node of the probability tree can branch on.
 *
 * Each value carries an explicit id and a display name equal to its identifier.
 */
object Factor extends Enumeration {
  /** Factor named "Econ". */
  val Econ: Value = Value(0, "Econ")

  /** Factor named "Stock". */
  val Stock: Value = Value(1, "Stock")
}
import Factor._ | |
/**
 * A node of a probability tree.
 *
 * @param probability weight of this node: for a child, the probability of the outcome that
 *                    leads to it from its parent; for a leaf, the probability of the leaf itself
 * @param factor      the factor this node branches on, or `None` for a leaf
 * @param outcomes    child nodes keyed by the outcome of `factor` that selects them
 */
class Node(
  val probability: Double,
  val factor: Option[Factor.Value],
  val outcomes: Map[String, Node]
) {

  /**
   * Conditional probability that `queryFactor` takes `queryOutcome`, given the observations
   * in `given`, computed over the sub-tree rooted at this node.
   *
   * The result is P(query, given) / P(given), where both terms come from the same
   * joint-probability traversal. This keeps the answer correct wherever the query factor sits
   * relative to observed and unobserved factors in the tree.
   *
   * @param queryFactor  the factor being asked about
   * @param queryOutcome the outcome of `queryFactor` whose probability is wanted
   * @param given        observed outcomes, keyed by factor
   * @return the conditional probability; `NaN` when the observations have probability zero
   *         (the conditional is undefined). If `given` already fixes `queryFactor` to a
   *         different outcome, the result is `0.0`.
   */
  def calcProbability(queryFactor: Factor.Value, queryOutcome: String, given: Map[Factor.Value, String]): Double = {
    val evidence = jointProbability(given)
    val evidenceAndQuery = given.get(queryFactor) match {
      // The observations contradict the queried outcome, so the joint event is impossible.
      case Some(observed) if observed != queryOutcome => 0.0
      case _ => jointProbability(given + (queryFactor -> queryOutcome))
    }
    evidenceAndQuery / evidence
  }

  /**
   * Joint probability of reaching this node and seeing every outcome in `observed` below it.
   * Unobserved factors are marginalised by summing over all of their branches.
   */
  private def jointProbability(observed: Map[Factor.Value, String]): Double =
    factor match {
      // Leaf node: nothing left to branch on.
      case None =>
        probability
      case Some(f) =>
        val below = observed.get(f) match {
          // Observed factor: only the matching branch contributes; an unknown outcome is impossible.
          case Some(outcome) =>
            outcomes.get(outcome).map(_.jointProbability(observed)).getOrElse(0.0)
          // Unobserved factor: sum over every branch.
          case None =>
            outcomes.values.iterator.map(_.jointProbability(observed)).sum
        }
        // Always weight by this node's own probability, on observed and unobserved paths alike.
        probability * below
    }
}
/** Factory methods for constructing nodes, so a graph can be written as a concise nested expression. */
object Node {

  /** Builds a leaf: a node with a probability but no factor and no children. */
  def apply(probability: Double): Node =
    new Node(probability, factor = None, outcomes = Map.empty)

  /**
   * Builds a branching node.
   *
   * @param probability weight of the node itself
   * @param factor      the factor the node branches on
   * @param outcomes    `outcome -> child` pairs; a repeated outcome keeps the last pair given
   */
  def apply(probability: Double, factor: Factor.Value, outcomes: (String, Node)*): Node = {
    val children = Map(outcomes: _*)
    new Node(probability, Some(factor), children)
  }
}
// Example model: the economy either grows or slows, and the stock goes up or down
// with different probabilities in each case.
val graph = {
  // Stock outcomes under a growing economy (branch weight 0.7).
  val growing = Node(0.7, Stock, "up" -> Node(0.8), "down" -> Node(0.2))
  // Stock outcomes under a slowing economy (branch weight 0.3).
  val slowing = Node(0.3, Stock, "up" -> Node(0.3), "down" -> Node(0.7))
  // The root accounts for 100% of all child observations.
  Node(1.0, Econ, "grow" -> growing, "slow" -> slowing)
}

// Query the model: probability that the economy grows, given that the stock went up.
println(graph.calcProbability(queryFactor = Econ, queryOutcome = "grow", given = Map(Stock -> "up")))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment