Created
July 16, 2014 16:16
-
-
Save cberzan/b89749f634a947ef9503 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.cra.figaro.language._ | |
import com.cra.figaro.library._ | |
import com.cra.figaro.library.compound._ | |
import com.cra.figaro.algorithm.sampling._ | |
import scala.math | |
/**
 * Grammar symbols used by the model.
 *
 * `S`, `A`, `B` and `C` are the symbols the model treats as non-terminals;
 * the `T_*` values are the terminal symbols (rendered as the letters a-e).
 * Declaration order determines the enumeration ids, so keep it stable.
 */
object Node extends Enumeration {
  /** Alias so that `import Node._` brings a type named `Node` into scope. */
  type Node = Value

  // Symbols that get rewritten by the grammar.
  val S, A, B, C = Value

  // Symbols that are never rewritten.
  val T_a, T_b, T_c, T_d, T_e = Value
}
import Node._ | |
/**
 * A small probabilistic context-free grammar built from Figaro elements.
 *
 * The deterministic helpers (`expansions`, `pickFirstNonTerminal`,
 * `expandFirstNonTerminal`) describe the grammar; `expandOne` and `generate`
 * turn it into random elements. `sentence` is the element obtained by
 * deriving from the start symbol `S`.
 */
class Model {
  /** Symbols that can be rewritten; every other node counts as a terminal. */
  val nonTerminals: Set[Node] = Set(S, A, B, C)

  // Derivations longer than this are cut off by `generate`.
  private val maxSentenceLength = 4

  /** True when `node` is not one of the `nonTerminals`. */
  def isTerminal(node: Node): Boolean = !nonTerminals.contains(node)

  /** Printable form of a single node: non-terminals keep their name, terminals become a-e. */
  def nodeToString(node: Node): String = {
    node match {
      case S => "S"
      case A => "A"
      case B => "B"
      case C => "C"
      case T_a => "a"
      case T_b => "b"
      case T_c => "c"
      case T_d => "d"
      case T_e => "e"
    }
  }

  /** Concatenate the printable forms of `nodes` without any separator. */
  def nodesToString(nodes: Seq[Node]): String =
    nodes.map(nodeToString).mkString("")

  /**
   * Return possible expansions for a node, and their probabilities.
   *
   * A terminal expands only to itself, with probability 1.
   *
   * Deterministic function.
   */
  def expansions(node: Node): Seq[(Seq[Node], Double)] = {
    node match {
      case S => Seq((Seq(A, B), 0.25),
                    (Seq(B, C), 0.2),
                    (Seq(A, C), 0.4),
                    (Seq(C, A), 0.15))
      case A => Seq((Seq(T_a), 0.05),
                    (Seq(T_b), 0.3),
                    (Seq(S), 0.65))
      case B => Seq((Seq(T_b), 0.5),
                    (Seq(T_c), 0.3),
                    (Seq(T_d), 0.2))
      case C => Seq((Seq(T_d), 0.35),
                    (Seq(T_e), 0.1),
                    (Seq(S), 0.55))
      case _ => Seq((Seq(node), 1.0))
    }
  }

  /**
   * Split list of nodes at the first non-terminal node.
   *
   * Examples:
   *   pickFirstNonTerminal(abSdAB) = Some((ab, S, dAB)).
   *   pickFirstNonTerminal(ab) = None
   *
   * Returns None if all symbols are terminal.
   *
   * Deterministic function.
   */
  def pickFirstNonTerminal(nodes: Seq[Node]):
      Option[(Seq[Node], Node, Seq[Node])] = {
    val (left, rest) = nodes.span(isTerminal)
    // headOption/tail work for every Seq implementation; a `::` pattern
    // would only match when `rest` happens to be a List.
    rest.headOption.map(nonTerm => (left, nonTerm, rest.tail))
  }

  /**
   * Return possible expansions of the given list of nodes, and their
   * probabilities. Expands only the first non-terminal. Returns None if all
   * symbols are terminal.
   *
   * Examples:
   *   expandFirstNonTerminal(aBcS) =
   *     Some([(abcS, 0.5), (accS, 0.3), (adcS, 0.2)])
   *   expandFirstNonTerminal(abc) = None
   *
   * Deterministic function.
   */
  def expandFirstNonTerminal(nodes: Seq[Node]):
      Option[Seq[(Seq[Node], Double)]] = {
    pickFirstNonTerminal(nodes).map { case (left, nonTerm, right) =>
      assert(!isTerminal(nonTerm))
      // Splice each expansion of the non-terminal back between its neighbours.
      expansions(nonTerm).map { case (exp, prob) =>
        (left ++ exp ++ right, prob)
      }
    }
  }

  /**
   * Expand the first non-terminal, or return None if all nodes are terminals.
   *
   * Random function (uses Figaro elements).
   */
  def expandOne(nodes: Seq[Node]): Option[Element[Seq[Node]]] = {
    expandFirstNonTerminal(nodes) match {
      case None => None
      case Some(exps) =>
        // Select takes (probability, outcome) pairs, hence the swap.
        Some(Select(exps.map(_.swap): _*))
    }
  }

  /**
   * Repeatedly expand the first non-terminal until there is nothing left to
   * expand.
   *
   * Expansion also stops once the sequence has more than `maxSentenceLength`
   * nodes, so the result may still contain non-terminals.
   *
   * Random function (uses Figaro elements). Prints a trace line and the
   * size of `Universe.universe.activeElements` on every call that is not
   * cut off by the length limit.
   */
  def generate(nodes: Seq[Node]): Element[Seq[Node]] = {
    // HACK: Stop expanding if sentence is too long.
    if (nodes.length > maxSentenceLength) {
      Constant(nodes)
    } else {
      // Trace the derivation step by step (comment out to silence):
      println(nodesToString(nodes))
      // DEBUG
      val elems = Universe.universe.activeElements
      println("have " + elems.size + " elements")
      // val sortedElems = elems.sortBy(System.identityHashCode)
      // sortedElems.take(10).foreach(elem =>
      //   println(" " + System.identityHashCode(elem) + ": " + elem))
      expandOne(nodes) match {
        case None =>
          Constant(nodes) // only terminals left
        case Some(newNodes) =>
          Chain(newNodes, generate) // keep expanding
      }
    }
  }

  /** Element for a sentence derived from the start symbol `S`. */
  val sentence = generate(Seq(S))
}
/**
 * Entry point and ad-hoc self-checks for the PCFG model.
 *
 * The two `test*` methods throw an `Exception` on the first failed check and
 * print "Test passed" otherwise.
 */
object PCFG {
  /**
   * Check `Model.expansions` for every node: probabilities sum to 1,
   * terminals expand only to themselves, and non-terminals do not.
   */
  def testExpansions: Unit = {
    val model = new Model
    Node.values.foreach { node =>
      val exps = model.expansions(node)
      val probSum = exps.map(_._2).sum
      if (math.abs(probSum - 1.0) > 1e-10) {
        throw new Exception("Prob does not sum to 1 for node " + node)
      }
      // A terminal's only expansion is itself with probability 1.
      val expandsToItself = exps == Seq((Seq(node), 1.0))
      if (model.isTerminal(node)) {
        if (!expandsToItself) {
          throw new Exception(
            "Terminal " + node + " behaves like a non-terminal.")
        }
      } else {
        if (expandsToItself) {
          throw new Exception(
            "Non-terminal " + node + " behaves like a terminal.")
        }
      }
    }
    println("Test passed")
  }

  // Compare one expandFirstNonTerminal result with its reference; on mismatch
  // print both under the given label and throw.
  private def checkExpansion(
      label: String,
      actual: Option[Seq[(Seq[Node], Double)]],
      expected: Option[Seq[(Seq[Node], Double)]]): Unit = {
    if (actual != expected) {
      println(label + " : " + actual)
      println(label + "ref: " + expected)
      throw new Exception("expandFirstNonTerminal bad")
    }
  }

  /**
   * Check `Model.expandFirstNonTerminal` on an all-terminal input, an input
   * with a non-terminal in the middle, and an input starting with one.
   */
  def testExpandFirstNonTerminal: Unit = {
    val model = new Model

    // All terminals: nothing to expand.
    checkExpansion(
      "out0",
      model.expandFirstNonTerminal(Seq(T_a, T_b, T_c)),
      None)

    // Only the first non-terminal (B) is expanded; the trailing S is kept.
    checkExpansion(
      "out1",
      model.expandFirstNonTerminal(Seq(T_a, B, T_c, S)),
      Some(Seq(
        (Seq(T_a, T_b, T_c, S), 0.5),
        (Seq(T_a, T_c, T_c, S), 0.3),
        (Seq(T_a, T_d, T_c, S), 0.2))))

    // Leading non-terminal (C) is expanded; A is left alone.
    checkExpansion(
      "out2",
      model.expandFirstNonTerminal(Seq(C, A)),
      Some(Seq(
        (Seq(T_d, A), 0.35),
        (Seq(T_e, A), 0.1),
        (Seq(S, A), 0.55))))

    println("Test passed")
  }

  /** Run both self-checks, then print three samples from an importance sampler. */
  def main(args: Array[String]): Unit = {
    testExpansions
    testExpandFirstNonTerminal
    // Generating a single sentence:
    // (Note: for loop + model.sentence.unset doesn't seem to forget,
    // so I don't know how to generate multiple samples...)
    // val model = new Model
    // model.sentence.generate
    // println(model.nodesToString(model.sentence.value))
    val model = new Model
    val alg = Importance(10, model.sentence)
    for (i <- 1 to 3) {
      println("------------------------ iteration " + i)
      println(alg.sample)
    }
    // alg.start()
    // alg.stop()
    // alg.distribution(model.sentence).foreach
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment