Last active
February 14, 2016 07:19
-
-
Save pierric/a543aad41cd25b34912f to your computer and use it in GitHub Desktop.
The Viterbi algorithm for hidden Markov models (HMMs), in Rust.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Source of the tag and word statistics that `viterbi` turns into
/// maximum-likelihood probability estimates.
pub trait CorpusModel {
    /// Ids of the real tags, excluding the TAG_START and TAG_STOP markers.
    fn get_tags(&self) -> Range<u32>;

    /// Name associated with tag id `tag`.
    fn get_tag_name(&self, tag: u32) -> &String;

    /// Count for the single tag `tag`; `viterbi` uses it as the denominator
    /// of the emission estimate.
    fn get_tag_count_u(&self, tag: u32) -> u32;

    /// Count for the tag pair (`first`, `second`); `viterbi` uses it as the
    /// denominator of the transition estimate.
    fn get_tag_count_b(&self, first: u32, second: u32) -> u32;

    /// Count for the tag triple (`first`, `second`, `third`); `viterbi` uses it
    /// as the numerator of the transition estimate.
    fn get_tag_count_t(&self, first: u32, second: u32, third: u32) -> u32;

    /// Count for `word` paired with tag `tag`; `viterbi` uses it as the
    /// numerator of the emission estimate.
    fn get_word_tag_count(&self, word: &String, tag: u32) -> u32;
}
pub fn viterbi<M:CorpusModel>(model: &M, sentence: &Vec<String>) -> Vec<u32> { | |
// tags excludes the TAG_START and TAG_STOP | |
let tags : Vec<_> = model.get_tags().map(|v| v as usize).collect(); | |
let dim = tags.len()+2; | |
let zeros = OwnedArray::zeros((dim, dim)); | |
let mut pi_curr = OwnedArray::zeros((dim, dim)); | |
let mut pi_prev = OwnedArray::zeros((dim, dim)); | |
let mut backtrace = OwnedArray::zeros((sentence.len()+1, dim, dim)); | |
// transition probability | |
let p = |t1: usize, t2: usize, t3: usize| -> f64 { | |
model.get_tag_count_t(t1 as u32,t2 as u32,t3 as u32) as f64 / | |
model.get_tag_count_b(t1 as u32,t2 as u32) as f64 | |
}; | |
// word probability | |
let e = |w: &String, t: usize| -> f64 { | |
model.get_word_tag_count(w,t as u32) as f64 / | |
model.get_tag_count_u(t as u32) as f64 | |
}; | |
if sentence.len() == 1 { | |
for t1 in &tags { | |
pi_curr[(*t1, TAG_STOP)] = p(TAG_START, TAG_START, *t1) * p(TAG_START, *t1, TAG_STOP) * e(&sentence[0], *t1); | |
} | |
} | |
else { | |
for t1 in &tags { | |
for t2 in &tags { | |
pi_curr[(*t1, *t2)] = p(TAG_START, TAG_START, *t1) * p(TAG_START, *t1, *t2) * | |
e(&sentence[0], *t1) * e(&sentence[1], *t2); | |
} | |
} | |
for i in 2..sentence.len() { | |
pi_prev.assign(&pi_curr); | |
pi_curr.assign(&zeros); | |
for t1 in &tags { | |
for t2 in &tags { | |
let mut v : f64 = 0.0; | |
for t3 in &tags { | |
let u : f64 = pi_prev[(*t3,*t1)] * p (*t3, *t1, *t2) * e (&sentence[i], *t2); | |
if u >= v { | |
v = u; | |
backtrace[(i,*t1,*t2)] = *t3; | |
} | |
} | |
pi_curr[(*t1,*t2)] = v; | |
} | |
} | |
} | |
pi_prev.assign(&pi_curr); | |
pi_curr.assign(&zeros); | |
for t1 in &tags { | |
let mut v : f64 = 0.0; | |
for t2 in &tags { | |
let u : f64 = pi_prev[(*t2,*t1)] * p (*t2,*t1,TAG_STOP); | |
if u >= v { | |
v = u; | |
backtrace[(sentence.len(), *t1, TAG_STOP)] = *t2; | |
} | |
} | |
pi_curr[(*t1, TAG_STOP)] = v; | |
} | |
} | |
// find the maximum ending tag. | |
let mut max = (0.0, (0,0)); | |
for t1 in &tags { | |
let v = pi_curr[(*t1,TAG_STOP)]; | |
if v >= max.0 { | |
max = (v, (*t1, TAG_STOP)); | |
} | |
} | |
// extract the path | |
let mut path: Vec<u32> = Vec::with_capacity(sentence.len()); | |
let mut tp = max.1; | |
path.push(tp.0 as u32); | |
for i in 1..sentence.len() { | |
let t = backtrace[(sentence.len()+1-i, tp.0, tp.1)]; | |
path.push(t as u32); | |
tp = (t, tp.0); | |
} | |
path.reverse(); | |
path | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment