Skip to content

Instantly share code, notes, and snippets.

@eob
Created November 27, 2010 21:35
Show Gist options
  • Save eob/718298 to your computer and use it in GitHub Desktop.
Save eob/718298 to your computer and use it in GitHub Desktop.
Alignment Update
/**
 * Recomputes the alignment table `_Z` for every tweet.
 *
 * The update follows the bound
 *
 *   log q(Z_i = k) <=   log P(Z_i = k)                                        (the prior)
 *                     + log E_{q(y_i), q(T_k)} Theta^T_REW f_REW(x_i, y_i, T_k)
 *                     - log E_{q(T_k)} Z_{x_i}(Theta, T_k)
 *
 * The per-tweet terms are accumulated by `updateForTweet`; afterwards each row
 * is passed through `logNormalize` and the table is validated with `checkOK`.
 *
 * @param td tagging distribution, forwarded unchanged to `updateForTweet`
 * @param rd record distribution, forwarded unchanged to `updateForTweet`
 * @return whatever `checkOK` returns
 */
def update(td: TaggingDistribution, rd: RecordDistribution) = {
  // Start from a fresh table in which every entry is -Infinity.
  _Z = logEmptyZ()

  // Pass 1: accumulate the unnormalized log scores of every tweet.
  (0 until _Z.size).foreach { tweetIdx =>
    updateForTweet(tweetIdx, td, rd)
  }

  // Pass 2: normalize each row in log space and report the result.
  (0 until _Z.size).foreach { tweetIdx =>
    val normalizedRow = logNormalize(_Z(tweetIdx))
    _Z(tweetIdx) = normalizedRow
    logger.info("Tweet " + tweetIdx + " after update: " + _Z(tweetIdx).mkString(" "))
  }

  // Every _Z(tweet) should now be a probability distribution; verify it.
  checkOK
}
/**
 * Accumulates the unnormalized log alignment scores of one tweet into `_Z(tweet)`.
 *
 * Three terms contribute to each entry `_Z(tweet)(k)`:
 *  1. the prior, which is skipped because a uniform prior is assumed for now;
 *  2. the rewrite "boosts" produced by `record.thetaRewrite(tweet)`, merged into
 *     the entry with `logSum`;
 *  3. the log partition function of the per-record potentials, subtracted from
 *     every entry that is not -Infinity.
 *
 * Normalizing the row into a distribution is left to the caller.
 *
 * @param tweet index of the tweet whose row of `_Z` is updated
 * @param td    tagging distribution; supplies `Y`, `labelIndex`, `tweets` and the
 *              CRF potentials
 * @param rd    record distribution; its records (`rd.T`) supply the rewrite boosts
 */
def updateForTweet(tweet: Int, td: TaggingDistribution, rd: RecordDistribution) = {
  /*
   * First term: Ignored because we're assuming a uniform prior for now
   */

  /*
   * Second term:
   * \sum_{l} \sum_{v} q(y_i = l) q(T_k^l = v) Theta^T_REW f_REW(x_i,y_i,T_k^l)
   * Can think of as:
   * \sum_{l} \sum_{v} q(y_i = l) q(T_k^l = v) Similarity(x_i to v)
   * Note: RecordDistribution.rewritePotentials returns a list of
   * (Record, Token, Label, Probability) where the summing across v has already been done.
   * We still need to factor in q(y_i = l)
   */
  logger.info("Tweet " + tweet + " before update: " + _Z(tweet).mkString(" "))

  // Add in the boosts
  for ((record, recordIdx) <- rd.T.zipWithIndex) {
    for (boost <- record.thetaRewrite(tweet)) {
      val target = boost.record.index
      // logBoostVal is the log( p(alignment) p(record takes on this field-value) Similarity(field-value to this value) )
      // Note: the similarity is not a probability. it's just a number > 0
      val logBoostVal = math.log(
        boost.weightedSimilarityScore * td.Y(tweet)(boost.token)(td.labelIndex(boost.recordField)))
      _Z(tweet)(target) = logSum(Array(_Z(tweet)(target), logBoostVal))
      // The original logged an undefined `boostVal`; the value computed above is logBoostVal.
      logger.info("R" + recordIdx + " T" + tweet + " Boost Alignment " + logBoostVal + " Total: " + _Z(tweet)(target))
    }
  }
  logger.info("Tweet " + tweet + " before logZ: " + _Z(tweet).mkString(" "))

  // Third term: Normalize
  // NumRecords x NumTransitions x PossibleTransitions
  // Initialize this to the CRF potentials. We'll then add in the REW potentials for this update.
  // A Vector is used (instead of appending to a List) so that building the collection
  // and the indexed lookups below do not cost time linear in the number of records.
  val recordPotentials: Vector[Array[Array[Double]]] = Vector.fill(_numRecords) {
    val dup = td.crfPotentialsCopy(tweet)
    // Each recordPotentials entry should be (Tokens+1) x (Tags)
    require(dup.size == td.tweets(tweet).size + 1)
    dup
  }

  for (record <- rd.T) {
    for (boost <- record.thetaRewrite(tweet)) {
      // NOTE(review): this recomputes the weighted similarity locally, whereas the second
      // term above reads boost.weightedSimilarityScore - confirm the two are meant to agree.
      val weightedSimilarityScore =
        boost.similarityScore * record.fields(boost.recordField).counter.count(boost.fieldValue)
      TwitterHelper.updatePotentials(
        potentials = recordPotentials(boost.record.index),
        transitionCache = _transitionCache,
        label = boost.recordField,
        token = boost.token,
        withValue = weightedSimilarityScore
      )
    }
  }

  for ((pots, record) <- recordPotentials.zipWithIndex) {
    // Entries still at -Infinity received no boost; leave them untouched.
    if (_Z(tweet)(record) != scala.Double.NegativeInfinity) {
      _forwardBackward.setInput(pots)
      val logZ = _forwardBackward.getLogZ()
      _Z(tweet)(record) -= logZ
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment