Skip to content

Instantly share code, notes, and snippets.

@cberzan
Created July 16, 2014 16:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cberzan/b89749f634a947ef9503 to your computer and use it in GitHub Desktop.
Save cberzan/b89749f634a947ef9503 to your computer and use it in GitHub Desktop.
import com.cra.figaro.language._
import com.cra.figaro.library._
import com.cra.figaro.library.compound._
import com.cra.figaro.algorithm.sampling._
import scala.math
object Node extends Enumeration {
  /** Alias so code can say `Node` rather than `Node.Value`. */
  type Node = Value

  // S, A, B and C are the symbols Model.nonTerminals treats as non-terminals.
  val S = Value
  val A = Value
  val B = Value
  val C = Value

  // Terminal symbols, named with a `T_` prefix.
  val T_a = Value
  val T_b = Value
  val T_c = Value
  val T_d = Value
  val T_e = Value
}
import Node._
/**
 * A small probabilistic context-free grammar over the symbols of [[Node]],
 * plus a Figaro generative model that derives a sentence from `S` by
 * repeatedly expanding the leftmost non-terminal.
 *
 * @param debug     when true, `generate` prints every intermediate derivation
 *                  and the number of active elements in the current Figaro
 *                  universe. Off by default.
 * @param maxLength once a derivation is longer than this, `generate` stops
 *                  expanding it and returns it as-is (it may then still
 *                  contain non-terminals). Defaults to 4.
 */
class Model(debug: Boolean = false, maxLength: Int = 4) {
  /** The non-terminal symbols; every other [[Node]] is a terminal. */
  val nonTerminals: Set[Node] = Set(S, A, B, C)

  /** True when `node` is not in [[nonTerminals]]. */
  def isTerminal(node: Node): Boolean = !nonTerminals.contains(node)

  /** Render one symbol: non-terminals by name, terminals as their letter. */
  def nodeToString(node: Node): String = node match {
    case S   => "S"
    case A   => "A"
    case B   => "B"
    case C   => "C"
    case T_a => "a"
    case T_b => "b"
    case T_c => "c"
    case T_d => "d"
    case T_e => "e"
  }

  /** Concatenate the rendered symbols with no separator. */
  def nodesToString(nodes: Seq[Node]): String =
    nodes.map(nodeToString).mkString("")

  /**
   * Return possible expansions for a node, and their probabilities.
   * A terminal (any node not listed below) expands only to itself with
   * probability 1.0.
   *
   * Deterministic function.
   */
  def expansions(node: Node): Seq[(Seq[Node], Double)] = node match {
    case S => Seq((Seq(A, B), 0.25),
                  (Seq(B, C), 0.2),
                  (Seq(A, C), 0.4),
                  (Seq(C, A), 0.15))
    case A => Seq((Seq(T_a), 0.05),
                  (Seq(T_b), 0.3),
                  (Seq(S), 0.65))
    case B => Seq((Seq(T_b), 0.5),
                  (Seq(T_c), 0.3),
                  (Seq(T_d), 0.2))
    case C => Seq((Seq(T_d), 0.35),
                  (Seq(T_e), 0.1),
                  (Seq(S), 0.55))
    case _ => Seq((Seq(node), 1.0))
  }

  /**
   * Split list of nodes at the first non-terminal node.
   *
   * Examples:
   * pickFirstNonTerminal(abSdAB) = Some((ab, S, dAB)).
   * pickFirstNonTerminal(ab) = None
   *
   * Returns None if all symbols are terminal. Works for any `Seq`
   * implementation (uses `headOption`/`tail`, not a List-only `::` pattern).
   *
   * Deterministic function.
   */
  def pickFirstNonTerminal(nodes: Seq[Node]): Option[(Seq[Node], Node, Seq[Node])] = {
    val (left, rest) = nodes.span(isTerminal)
    rest.headOption.map(nonTerm => (left, nonTerm, rest.tail))
  }

  /**
   * Return possible expansions of the given list of nodes, and their
   * probabilities. Expands only the first non-terminal. Returns None if all
   * symbols are terminal.
   *
   * Examples:
   * expandFirstNonTerminal(aBcS) =
   * Some([(abcS, 0.5), (accS, 0.3), (adcS, 0.2)])
   * expandFirstNonTerminal(abc) = None
   *
   * Deterministic function.
   */
  def expandFirstNonTerminal(nodes: Seq[Node]): Option[Seq[(Seq[Node], Double)]] =
    pickFirstNonTerminal(nodes).map { case (left, nonTerm, right) =>
      assert(!isTerminal(nonTerm))
      expansions(nonTerm).map { case (exp, prob) => (left ++ exp ++ right, prob) }
    }

  /**
   * Expand the first non-terminal, or return None if all nodes are terminals.
   * The expansion is a Figaro `Select` weighted by the grammar probabilities.
   *
   * Random function (uses Figaro elements).
   */
  def expandOne(nodes: Seq[Node]): Option[Element[Seq[Node]]] =
    expandFirstNonTerminal(nodes).map(exps => Select(exps.map(_.swap): _*))

  /**
   * Repeatedly expand the first non-terminal until there is nothing left to
   * expand, or until the sequence is longer than `maxLength`.
   *
   * Random function (uses Figaro elements).
   */
  def generate(nodes: Seq[Node]): Element[Seq[Node]] =
    if (nodes.length > maxLength) {
      // HACK: stop expanding once the sentence is too long; the returned
      // sequence may still contain non-terminals.
      Constant(nodes)
    } else {
      if (debug) {
        // Show the derivation step by step, plus the size of the universe.
        println(nodesToString(nodes))
        println("have " + Universe.universe.activeElements.size + " elements")
      }
      expandOne(nodes) match {
        case None           => Constant(nodes)           // only terminals left
        case Some(newNodes) => Chain(newNodes, generate) // keep expanding
      }
    }

  /** A random sentence derived from the start symbol `S`. */
  val sentence: Element[Seq[Node]] = generate(Seq(S))
}
object PCFG {
  /**
   * Sanity-check the grammar table returned by `Model.expansions`:
   * every node's expansion probabilities sum to 1, terminals rewrite only to
   * themselves, and non-terminals do not.
   *
   * @throws Exception on the first violated check.
   */
  def testExpansions: Unit = {
    val model = new Model
    Node.values.foreach { (node: Node) =>
      val exps = model.expansions(node)
      val probSum = exps.map(_._2).sum
      if (math.abs(probSum - 1.0) > 1e-10) {
        throw new Exception("Prob does not sum to 1 for node " + node)
      }
      // "Behaves like a terminal" == the only expansion is itself with prob 1.
      // The explicit inner tuple avoids relying on argument auto-tupling.
      val rewritesToItself = exps == Seq((Seq(node), 1.0))
      if (model.isTerminal(node) && !rewritesToItself) {
        throw new Exception(
          "Terminal " + node + " behaves like a non-terminal.")
      }
      if (!model.isTerminal(node) && rewritesToItself) {
        throw new Exception(
          "Non-terminal " + node + " behaves like a terminal.")
      }
    }
    println("Test passed")
  }

  /**
   * Print `actual` and `expected` (prefixed by `label`) and throw if they
   * differ; used by [[testExpandFirstNonTerminal]].
   */
  private def checkExpansion[T](label: String, actual: T, expected: T): Unit = {
    if (actual != expected) {
      println(label + " : " + actual)
      println(label + "ref: " + expected)
      throw new Exception("expandFirstNonTerminal bad")
    }
  }

  /**
   * Check `Model.expandFirstNonTerminal` on an all-terminal input and on two
   * inputs whose first non-terminal is B and C respectively.
   *
   * @throws Exception if any output differs from its reference.
   */
  def testExpandFirstNonTerminal: Unit = {
    val model = new Model

    checkExpansion("out0",
      model.expandFirstNonTerminal(Seq(T_a, T_b, T_c)),
      None)

    checkExpansion("out1",
      model.expandFirstNonTerminal(Seq(T_a, B, T_c, S)),
      Some(Seq(
        (Seq(T_a, T_b, T_c, S), 0.5),
        (Seq(T_a, T_c, T_c, S), 0.3),
        (Seq(T_a, T_d, T_c, S), 0.2))))

    checkExpansion("out2",
      model.expandFirstNonTerminal(Seq(C, A)),
      Some(Seq(
        (Seq(T_d, A), 0.35),
        (Seq(T_e, A), 0.1),
        (Seq(S, A), 0.55))))

    println("Test passed")
  }

  /** Run the self-tests, then print a few importance samples of a sentence. */
  def main(args: Array[String]): Unit = {
    testExpansions
    testExpandFirstNonTerminal
    // Generating a single sentence:
    // (Note: for loop + model.sentence.unset doesn't seem to forget,
    // so I don't know how to generate multiple samples...)
    // val model = new Model
    // model.sentence.generate
    // println(model.nodesToString(model.sentence.value))
    val model = new Model
    val alg = Importance(10, model.sentence)
    for (i <- 1 to 3) {
      println("------------------------ iteration " + i)
      println(alg.sample)
    }
    // alg.start()
    // alg.stop()
    // alg.distribution(model.sentence).foreach
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment