Created
November 27, 2010 21:35
-
-
Save eob/718298 to your computer and use it in GitHub Desktop.
Alignment Update
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Recomputes the alignment distribution q(Z) for every tweet.
 *
 * log q(Z_i = k) is less than or equal to
 *     log P(Z_i = k)                                             (the prior)
 *   + log E_{q(y_i), q(T_k)} Theta^T_REW f_REW(x_i, y_i, T_k)
 *   - log E_{q(T_k)} Z_{x_i}(Theta, T_k)
 *
 * The work is done in two passes over the tweets: the first accumulates the
 * unnormalized log scores via `updateForTweet`, the second log-normalizes each
 * tweet's row and logs it.
 *
 * @param td tagging distribution, forwarded unchanged to `updateForTweet`
 * @param rd record distribution, forwarded unchanged to `updateForTweet`
 * @return whatever `checkOK` evaluates to
 */
def update(td: TaggingDistribution, rd: RecordDistribution) = {
  // Reset: every entry starts at -Infinity.
  _Z = logEmptyZ()

  // Pass 1: fill in the unnormalized log scores, one tweet at a time.
  (0 until _Z.size).foreach { tweetIdx =>
    updateForTweet(tweetIdx, td, rd)
  }

  // Pass 2: normalize each tweet's row in log space and report it.
  (0 until _Z.size).foreach { tweetIdx =>
    val normalizedRow = logNormalize(_Z(tweetIdx))
    _Z(tweetIdx) = normalizedRow
    logger.info("Tweet " + tweetIdx + " after update: " + _Z(tweetIdx).mkString(" "))
  }

  // Ensure each _Z(tweet) is a probability distribution.
  checkOK
}
/**
 * Accumulates the unnormalized log q(Z_tweet = k) scores for a single tweet
 * into `_Z(tweet)`. Normalization across records is left to the caller.
 *
 * First term (prior): ignored, because a uniform prior is assumed for now.
 *
 * Second term:
 *   \sum_{l} \sum_{v} q(y_i = l) q(T_k^l = v) Theta^T_REW f_REW(x_i, y_i, T_k^l)
 * which can be thought of as:
 *   \sum_{l} \sum_{v} q(y_i = l) q(T_k^l = v) Similarity(x_i to v)
 * Note (from the original author): the rewrite potentials come back as a list of
 * (Record, Token, Label, Probability) where the summing across v has already
 * been done, so only q(y_i = l) still needs to be factored in here.
 *
 * Third term: for every record whose score is not -Infinity, the logZ reported
 * by `_forwardBackward` for that record's potentials is subtracted.
 *
 * @param tweet index of the tweet; used to index `_Z`, `td.Y` and `td.tweets`
 * @param td    tagging distribution supplying `Y`, `labelIndex`, `tweets` and
 *              `crfPotentialsCopy`
 * @param rd    record distribution whose records (`rd.T`) supply the rewrite boosts
 */
def updateForTweet(tweet: Int, td: TaggingDistribution, rd: RecordDistribution): Unit = {
  logger.info("Tweet " + tweet + " before update: " + _Z(tweet).mkString(" "))

  // Second term: add in the boosts.
  for ((record, recordIdx) <- rd.T.zipWithIndex) {
    for (boost <- record.thetaRewrite(tweet)) {
      // logBoostVal is
      //   log( p(alignment) p(record takes on this field-value) Similarity(field-value to this value) )
      // Note: the similarity is not a probability, it's just a number > 0.
      val labelProbability = td.Y(tweet)(boost.token)(td.labelIndex(boost.recordField))
      val logBoostVal = math.log(boost.weightedSimilarityScore * labelProbability)
      _Z(tweet)(boost.record.index) = logSum(Array(_Z(tweet)(boost.record.index), logBoostVal))
      // Fix: the original logged an undefined `boostVal`; the value computed above is logBoostVal.
      logger.info("R" + recordIdx + " T" + tweet + " Boost Alignment " + logBoostVal + " Total: " + _Z(tweet)(boost.record.index))
    }
  }
  logger.info("Tweet " + tweet + " before logZ: " + _Z(tweet).mkString(" "))

  // Third term: normalize.
  // recordPotentials is NumRecords x NumTransitions x PossibleTransitions.
  // Each entry starts as a fresh copy of the CRF potentials for this tweet and
  // should be (Tokens + 1) x (Tags); the REW potentials are added in below.
  // A Vector gives constant-time-ish indexed access and avoids repeated List appends.
  val recordPotentials: Vector[Array[Array[Double]]] = Vector.fill(_numRecords) {
    val dup = td.crfPotentialsCopy(tweet)
    require(dup.size == td.tweets(tweet).size + 1)
    dup
  }

  // Fold each boost into the potentials of the record it refers to.
  for (record <- rd.T) {
    for (boost <- record.thetaRewrite(tweet)) {
      val weightedSimilarityScore =
        boost.similarityScore * record.fields(boost.recordField).counter.count(boost.fieldValue)
      TwitterHelper.updatePotentials(
        potentials = recordPotentials(boost.record.index),
        transitionCache = _transitionCache,
        label = boost.recordField,
        token = boost.token,
        withValue = weightedSimilarityScore
      )
    }
  }

  // Subtract logZ per record; records still at -Infinity received no boost and are skipped.
  for ((pots, recordIdx) <- recordPotentials.zipWithIndex) {
    if (_Z(tweet)(recordIdx) != scala.Double.NegativeInfinity) {
      _forwardBackward.setInput(pots)
      val logZ = _forwardBackward.getLogZ()
      // logger.info("logz: " + logZ + " : " + math.exp(logZ))
      _Z(tweet)(recordIdx) -= logZ
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment