Created March 24, 2015 17:32
-
-
Save strubell/b82f4fa6bdc3c83340a3 to your computer and use it in GitHub Desktop.
ChainNERDemo that works with unseen labels at test time
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cc.factorie.tutorial | |
import cc.factorie._ | |
import java.io.File | |
import cc.factorie.variable._ | |
import cc.factorie.model.{Parameters, DotTemplateWithStatistics2, DotTemplateWithStatistics1, TemplateModel} | |
import cc.factorie.infer.{BPSummary, BP, IteratedConditionalModes, GibbsSampler} | |
/** A demonstration of training a linear-chain CRF for named entity recognition.
  * Prints various diagnostics suitable to a demo.
  *
  * Both input files are loaded before training, so the test file's label categories and
  * base token features are registered in `LabelDomain` / `TokenDomain` when training starts.
  * Neighbour ("offset") features for the test tokens, however, are only added after
  * `TokenDomain` has been frozen, so offset features never seen in training are skipped.
  *
  * Expected input (see [[load]]): one token per line as four space-separated fields
  * (`word POS <unused> label`), short/blank lines between sentences, `-DOCSTART-` lines ignored.
  *
  * @author Andrew McCallum */
object ChainNERDemo {
  // ---- The variable classes ----

  /** Domain of the binary token feature strings (e.g. "W=word", "POS=NN"). */
  object TokenDomain extends CategoricalVectorDomain[String]

  /** An observed token: a binary feature vector linked into its [[Sentence]], owning its [[Label]].
    * @param word        the token's surface string
    * @param features    initial feature strings added to the vector
    * @param labelString the gold (target) label category for this token */
  class Token(val word: String, features: Seq[String], labelString: String)
      extends BinaryFeatureVectorVariable[String] with ChainLink[Token, Sentence] {
    def domain = TokenDomain
    // Once TokenDomain is frozen, features it does not already contain are skipped rather than added.
    override def skipNonCategories = domain.dimensionDomain.frozen
    val label: Label = new Label(labelString, this)
    this ++= features
  }

  /** Domain of the NER label categories read from the data files. */
  object LabelDomain extends CategoricalDomain[String]

  /** The label variable of one token, constructed from its gold category `labelname`.
    * `prev` / `next` navigate to the neighbouring tokens' labels; check `hasPrev` / `hasNext` first. */
  class Label(labelname: String, val token: Token) extends LabeledCategoricalVariable(labelname) {
    def domain = LabelDomain
    // NOTE(review): the null checks presumably guard against a neighbouring Token whose `label`
    // val has not been assigned yet (it is set during Token construction).
    def hasNext = token.hasNext && token.next.label != null
    def hasPrev = token.hasPrev && token.prev.label != null
    def next = token.next.label
    def prev = token.prev.label
  }

  /** A sentence: a chain of [[Token]]s. */
  class Sentence extends Chain[Sentence, Token]

  // ---- The model ----
  // NOTE(review): the weight tensors are sized from LabelDomain.size and TokenDomain.dimensionSize,
  // yet `model` is initialized together with this object, i.e. before `main` loads any data. This
  // relies on FACTORIE's `Weights(...)` creating its tensor lazily on first use; confirm for the
  // FACTORIE version in use.
  val model = new TemplateModel with Parameters {
    // Bias term on each individual label
    object bias extends DotTemplateWithStatistics1[Label] {
      val weights = Weights(new la.DenseTensor1(LabelDomain.size))
    }
    // Transition factors between two successive labels.
    // (The member name is kept as spelled because it is visible on `model`.)
    object transtion extends DotTemplateWithStatistics2[Label, Label] {
      val weights = Weights(new la.DenseTensor2(LabelDomain.size, LabelDomain.size))
      def unroll1(label: Label) = if (label.hasPrev) Factor(label.prev, label) else Nil
      def unroll2(label: Label) = if (label.hasNext) Factor(label, label.next) else Nil
    }
    // Factor between a label and its observed token
    object evidence extends DotTemplateWithStatistics2[Label, Token] {
      val weights = Weights(new la.DenseTensor2(LabelDomain.size, TokenDomain.dimensionSize))
      def unroll1(label: Label) = Factor(label, label.token)
      // Tokens are observed and never change, so unrolling from a token is an error.
      def unroll2(token: Token) = throw new Error("Token values shouldn't change")
    }
    this += evidence
    this += bias
    this += transtion
  }

  // The training objective; also used below to report accuracy of labels against their targets.
  val objective = new HammingTemplate[Label, Label#TargetType]

  /** Trains on (at most) the first 50000 training labels with SampleRank for 3 rounds, printing
    * diagnostics after each, then predicts the test labels and prints the final test accuracy.
    * @param args exactly two paths: the training file and the test file */
  def main(args: Array[String]): Unit = {
    implicit val random = new scala.util.Random(0)
    if (args.length != 2) throw new Error("Usage: ChainNERDemo trainfile testfile")

    // Read in both files before training, so the test file's label categories are in LabelDomain too.
    val trainSentences = load(args(0))
    val testSentences = load(args(1))

    // The variables to be inferred during training (capped at 50000 labels).
    val trainLabels = trainSentences.flatMap(_.links.map(_.label)).take(50000)
    // Add features from next and previous tokens while TokenDomain can still grow.
    addOffsetFeatures(trainLabels.map(_.token))
    println("Using " + TokenDomain.dimensionSize + " observable features.")

    // Sample and Learn!
    val startTime = System.currentTimeMillis
    trainLabels.foreach(_.setRandomly)
    val learner = new optimize.SampleRankTrainer(
      new GibbsSampler(model, objective) { temperature = 0.1 },
      new cc.factorie.optimize.AdaGrad)
    val predictor = new IteratedConditionalModes(model, null)
    for (_ <- 1 to 3) {
      learner.processContexts(trainLabels)
      predictor.processAll(trainLabels)
      trainLabels.take(20).foreach(printLabel _)
      println(); println()
      printDiagnostic(trainLabels.take(400))
      println("Train accuracy = " + objective.accuracy(trainLabels))
    }

    // Freeze the features, then extract test-time offset features; any not seen so far are skipped.
    TokenDomain.freeze()
    val testLabels = testSentences.flatMap(_.links.map(_.label))
    addOffsetFeatures(testLabels.map(_.token))
    testLabels.foreach(_.setRandomly)
    predictor.processAll(testLabels, 2)
    println("Final Test accuracy = " + objective.accuracy(testLabels))
    println("Finished in " + ((System.currentTimeMillis - startTime) / 1000.0) + " seconds")

    // Alternative test-time inference procedures:
    //for (sentence <- testSentences) BP.inferChainMax(sentence.asSeq.map(_.label), model); println ("MaxBP Test accuracy = "+ objective.accuracy(testLabels))
    //for (sentence <- testSentences) BP.inferChainSum(sentence.asSeq.map(_.label), model).setToMaximize(null); println ("SumBP Test accuracy = "+ objective.accuracy(testLabels))
    //predictor.processAll(testLabels, 2); println ("Gibbs Test accuracy = "+ objective.accuracy(testLabels))
  }

  /** Adds to each token its neighbours' features, suffixed "@-1" (previous token) or "@+1" (next token).
    * Features containing '@' are excluded, so already-added offset features are not compounded.
    * Tokens are processed in order, so a token's predecessor may already carry offset features;
    * the '@' filter is what keeps them out. Note that this also excludes any base feature whose
    * text happens to contain '@' (e.g. "W=" of a word with an '@'). */
  private def addOffsetFeatures(tokens: Seq[Token]): Unit =
    tokens.foreach { t =>
      if (t.hasPrev) t ++= t.prev.activeCategories.filter(!_.contains('@')).map(_ + "@-1")
      if (t.hasNext) t ++= t.next.activeCategories.filter(!_.contains('@')).map(_ + "@+1")
    }

  /** Prints, for each token, its word followed by (label category, marginal) pairs taken from
    * `summary`, sorted by decreasing marginal, then a blank line. */
  def printTokenMarginals(tokens: Seq[Token], summary: BPSummary): Unit = {
    for (token <- tokens)
      println(token.word + " " + LabelDomain.categories.zip(summary.marginal(token.label).proportions.asSeq).sortBy(_._2).reverse.mkString(" "))
    println()
  }

  // Feature extraction

  /** Builds the base feature strings for a word: "W=<word>", then `initialFeatures`, then a
    * "PRE=" 3-character prefix for words longer than three characters, then shape flags
    * ("CAPITALIZED", "NUMERIC", "PUNCTUATION") when the corresponding regex matches. */
  def wordToFeatures(word: String, initialFeatures: String*): Seq[String] = {
    import scala.collection.mutable.ArrayBuffer
    val f = new ArrayBuffer[String]
    f += "W=" + word
    f ++= initialFeatures
    if (word.length > 3) f += "PRE=" + word.substring(0, 3)
    if (Capitalized.findFirstMatchIn(word).isDefined) f += "CAPITALIZED"
    if (Numeric.findFirstMatchIn(word).isDefined) f += "NUMERIC"
    // NOTE(review): Punctuation is unanchored (unlike Numeric), so this fires for any word that merely
    // *contains* one of these characters (e.g. "U.S."); confirm whether whole-token matching was intended.
    if (Punctuation.findFirstMatchIn(word).isDefined) f += "PUNCTUATION"
    f
  }

  val Capitalized = "^[A-Z].*".r
  val Numeric = "^[0-9]+$".r
  val Punctuation = "[-,\\.;:?!()]+".r

  /** Prints one line for a label: the word, its gold (TRUE) and predicted (PRED) categories, and the token. */
  def printLabel(label: Label): Unit = {
    println("%-16s TRUE=%-8s PRED=%-8s %s".format(label.token.word, label.target.categoryValue, label.value.category, label.token.toString))
  }

  /** Prints predicted non-"O" label runs, one run per line. At the start of each run it prints the
    * gold category (only when it differs from the prediction) and the predicted category, each with
    * its first two characters dropped (presumably a "B-"/"I-" prefix), followed by the run's words. */
  def printDiagnostic(labels: Seq[Label]): Unit = {
    // Loop-invariant: every Label's domain is LabelDomain.
    val outsideIndex = LabelDomain.index("O")
    for (label <- labels; if label.intValue != outsideIndex) {
      if (!label.hasPrev || label.value != label.prev.value)
        print("%-7s %-7s ".format(if (label.value != label.target.value) label.target.value.category.drop(2) else " ", label.value.category.drop(2)))
      print(label.token.word + " ")
      if (!label.hasNext || label.value != label.next.value) println()
    }
    println()
  }

  /** Reads sentences from a space-separated, CoNLL-style file.
    *
    * Every other line must have exactly four space-separated fields: field 0 is the word, field 1
    * its POS tag (added as a "POS=" feature), field 3 the gold label; field 2 is ignored.
    * Lines shorter than two characters end the current sentence, and lines starting with
    * "-DOCSTART-" are skipped. Empty sentences are dropped, and the final sentence is kept even
    * when the file does not end with a blank line. The file is always closed.
    *
    * @param filename path of the file to read
    * @return the non-empty sentences, in file order */
  def load(filename: String): Seq[Sentence] = {
    import scala.io.Source
    import scala.collection.mutable.ArrayBuffer
    val sentences = new ArrayBuffer[Sentence]
    var wordCount = 0
    var sentence = new Sentence
    var sentenceLength = 0
    // Keeps the sentence being built (only if it has tokens) and starts a fresh one.
    def endSentence(): Unit = {
      if (sentenceLength > 0) sentences += sentence
      sentence = new Sentence
      sentenceLength = 0
    }
    val source = Source.fromFile(new File(filename))
    try {
      for (line <- source.getLines()) {
        if (line.length < 2) { // Sentence boundary
          endSentence()
        } else if (line.startsWith("-DOCSTART-")) {
          // Skip document boundaries
        } else {
          val fields = line.split(' ')
          assert(fields.length == 4, "Expected 4 space-separated fields in " + filename + ", got: " + line)
          val word = fields(0)
          val pos = fields(1)
          val label = fields(3).stripLineEnd
          sentence += new Token(word, wordToFeatures(word, "POS=" + pos), label)
          sentenceLength += 1
          wordCount += 1
        }
      }
      endSentence() // the last sentence need not be followed by a blank line
    } finally {
      source.close()
    }
    println("Loaded " + sentences.length + " sentences with " + wordCount + " words total from file " + filename)
    sentences
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.