Skip to content

Instantly share code, notes, and snippets.

@pierric
Last active February 14, 2016 07:19
Show Gist options
  • Save pierric/a543aad41cd25b34912f to your computer and use it in GitHub Desktop.
Save pierric/a543aad41cd25b34912f to your computer and use it in GitHub Desktop.
The Viterbi algorithm for hidden Markov models (HMMs), implemented in Rust
/// Source of the corpus counts that `viterbi` reads to score tag sequences.
///
/// Tags are identified by `u32` ids. `viterbi` casts these ids to `usize` and
/// uses them directly as array indices.
pub trait CorpusModel {
/// Range of tag ids to iterate over. The comment in `viterbi` states that this
/// range excludes TAG_START and TAG_STOP.
fn get_tags(&self) -> Range<u32>;
/// Count associated with the pair (`word`, `tag`). `viterbi` divides it by
/// `get_tag_count_u(tag)` to obtain what it calls the word probability.
fn get_word_tag_count(&self, word: &String, tag: u32) -> u32;
/// Count associated with a single tag (presumably the unigram count, going by
/// the `_u` suffix; confirm against the implementor). Used by `viterbi` as the
/// denominator of the word probability.
fn get_tag_count_u(&self, tag: u32) -> u32;
/// Count associated with the ordered tag pair (`tag1`, `tag2`). Used by
/// `viterbi` as the denominator of the transition probability.
fn get_tag_count_b(&self, tag1: u32, tag2: u32) -> u32;
/// Count associated with the ordered tag triple (`tag1`, `tag2`, `tag3`).
/// `viterbi` divides it by `get_tag_count_b(tag1, tag2)` to obtain what it
/// calls the transition probability.
fn get_tag_count_t(&self, tag1: u32, tag2: u32, tag3: u32) -> u32;
/// Name for the given tag id (presumably a human-readable label; not used by
/// `viterbi`, so its exact meaning should be confirmed against the implementor).
fn get_tag_name(&self, tag: u32) -> &String;
}
/// Finds the highest-scoring tag sequence for `sentence` using the Viterbi
/// algorithm over tag trigrams.
///
/// Scores are products of two ratios read from `model`:
///
/// * transition: `get_tag_count_t(t1, t2, t3) / get_tag_count_b(t1, t2)`
/// * word: `get_word_tag_count(w, t) / get_tag_count_u(t)`
///
/// A ratio whose denominator is zero is treated as probability `0.0` instead of
/// producing `NaN`/`inf`, so unseen contexts can never poison the score tables.
///
/// # Returns
///
/// One tag id per word, in sentence order. An empty `sentence` yields an empty
/// vector. When several candidates score equally (including the case where
/// every path scores zero) the candidate visited last wins, because all
/// comparisons use `>=`.
///
/// # Assumptions
///
/// The score tables are `tags.len() + 2` wide and are indexed directly by tag
/// id, so this assumes every id from `model.get_tags()` as well as `TAG_START`
/// and `TAG_STOP` is smaller than `tags.len() + 2` (TODO: confirm against the
/// definitions of those constants). Out-of-range ids would presumably panic on
/// indexing.
pub fn viterbi<M:CorpusModel>(model: &M, sentence: &Vec<String>) -> Vec<u32> {
    let n = sentence.len();
    // Nothing to tag. Returning here also keeps the code below from indexing
    // `sentence[0]` on an empty input.
    if n == 0 {
        return Vec::new();
    }

    // tags excludes the TAG_START and TAG_STOP
    let tags: Vec<usize> = model.get_tags().map(|v| v as usize).collect();
    let dim = tags.len() + 2;
    let zeros = OwnedArray::zeros((dim, dim));
    // pi_curr[(a, b)]: best score of a partial path whose last two tags are (a, b).
    let mut pi_curr = OwnedArray::zeros((dim, dim));
    let mut pi_prev = OwnedArray::zeros((dim, dim));
    // backtrace[(i, a, b)]: the tag preceding (a, b) on the best path, where b
    // is the tag of word i (or TAG_STOP when i == n).
    let mut backtrace = OwnedArray::zeros((n + 1, dim, dim));

    // Count ratio with an explicit zero-denominator guard.
    let ratio = |num: u32, den: u32| -> f64 {
        if den == 0 {
            0.0
        } else {
            num as f64 / den as f64
        }
    };
    // transition probability
    let p = |t1: usize, t2: usize, t3: usize| -> f64 {
        ratio(
            model.get_tag_count_t(t1 as u32, t2 as u32, t3 as u32),
            model.get_tag_count_b(t1 as u32, t2 as u32),
        )
    };
    // word probability
    let e = |w: &String, t: usize| -> f64 {
        ratio(model.get_word_tag_count(w, t as u32), model.get_tag_count_u(t as u32))
    };

    if n == 1 {
        // Single word: START, START -> t1 -> STOP.
        for &t1 in &tags {
            pi_curr[(t1, TAG_STOP)] = p(TAG_START, TAG_START, t1)
                * p(TAG_START, t1, TAG_STOP)
                * e(&sentence[0], t1);
        }
    } else {
        // Seed the table with every tag pair for the first two words.
        for &t1 in &tags {
            for &t2 in &tags {
                pi_curr[(t1, t2)] = p(TAG_START, TAG_START, t1)
                    * p(TAG_START, t1, t2)
                    * e(&sentence[0], t1)
                    * e(&sentence[1], t2);
            }
        }

        // Extend the best paths one word at a time.
        for i in 2..n {
            pi_prev.assign(&pi_curr);
            pi_curr.assign(&zeros);
            for &t1 in &tags {
                for &t2 in &tags {
                    let mut v: f64 = 0.0;
                    for &t3 in &tags {
                        let u: f64 = pi_prev[(t3, t1)] * p(t3, t1, t2) * e(&sentence[i], t2);
                        if u >= v {
                            v = u;
                            backtrace[(i, t1, t2)] = t3;
                        }
                    }
                    pi_curr[(t1, t2)] = v;
                }
            }
        }

        // Close every path with the transition into TAG_STOP.
        pi_prev.assign(&pi_curr);
        pi_curr.assign(&zeros);
        for &t1 in &tags {
            let mut v: f64 = 0.0;
            for &t2 in &tags {
                let u: f64 = pi_prev[(t2, t1)] * p(t2, t1, TAG_STOP);
                if u >= v {
                    v = u;
                    backtrace[(n, t1, TAG_STOP)] = t2;
                }
            }
            pi_curr[(t1, TAG_STOP)] = v;
        }
    }

    // find the maximum ending tag.
    let mut best = (0.0, (0, 0));
    for &t1 in &tags {
        let v = pi_curr[(t1, TAG_STOP)];
        if v >= best.0 {
            best = (v, (t1, TAG_STOP));
        }
    }

    // extract the path: walk the back-pointers from the last word to the first,
    // then reverse so the tags come out in sentence order.
    let mut path: Vec<u32> = Vec::with_capacity(n);
    let mut tp = best.1;
    path.push(tp.0 as u32);
    for i in 1..n {
        let t = backtrace[(n + 1 - i, tp.0, tp.1)];
        path.push(t as u32);
        tp = (t, tp.0);
    }
    path.reverse();
    path
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment