Skip to content

Instantly share code, notes, and snippets.

@zjiayao
Created November 26, 2017 16:10
Show Gist options
  • Save zjiayao/306f674a25c2f6d41c49c1db087aaa5f to your computer and use it in GitHub Desktop.
Save zjiayao/306f674a25c2f6d41c49c1db087aaa5f to your computer and use it in GitHub Desktop.
Viterbi Algorithm for Decoding HMM
"""
An implementation of Viterbi Algorithm
for decoding Hidden Markov Models.
(C) Jiayao Zhang 2017
"""
from __future__ import (print_function, division)
import numpy as np
class Viterbi:
    """
    Viterbi decoder for a discrete Hidden Markov Model.

    Stores a transition matrix and an emission matrix and, for a given
    observation sequence, finds the single most likely path of hidden
    states.
    """

    def __init__(self, T, E):
        """
        __init__(T, E)

        :param T
            Transition matrix. T[k, i] is used as the probability of
            moving from state k to state i.
        :param E
            Emission matrix. E[i, o] is used as the probability of
            state i emitting the (0-based) symbol o.
        """
        # asarray is a no-op for ndarrays and lets nested lists work too.
        self.T = np.asarray(T)
        self.E = np.asarray(E)

    def decode(self, seq, init=None):
        """
        decode(seq, init=None)

        Find the most likely hidden-state path for `seq` under the
        transition/emission matrices given at construction, and print
        the result.

        :param seq
            Non-empty sequence of observed symbols as 1-based integers
            (symbol 1 selects the first column of the emission matrix).
        :param init
            Initial state probabilities, if not specified,
            each state is assigned equal probability.
        :returns prob
            The probability of the most likely state path.
        :returns code
            Decoded state path as a string: the 0-based state indices
            in time order, concatenated without a separator (so the
            string is only unambiguous for models with at most 10 states).
        :raises AssertionError
            If `seq` is empty.
        :raises IndexError
            If a symbol in `seq` is outside 1..E.shape[1].
        """
        states = self.T.shape[0]
        length = len(seq)
        assert length > 0, "cannot decode an empty sequence"

        # Shift the 1-based symbols to 0-based emission columns.
        obs = np.asarray(seq) - 1
        # A negative column would silently wrap around to the end of the
        # emission matrix and produce a wrong decoding, so reject it with
        # the same exception type that an over-large symbol produces.
        if obs.min() < 0 or obs.max() >= self.E.shape[1]:
            raise IndexError(
                "symbols in seq must be 1-based integers in [1, %d]"
                % self.E.shape[1]
            )

        if init is None:
            init = np.ones(states) / states
        else:
            init = np.asarray(init)

        # s[i, j]:  best path probability ending in state i at step j.
        # tr[i, j]: predecessor state (at step j-1) on that best path.
        s = np.zeros([states, length])
        tr = np.zeros(s.shape, dtype=np.int32)

        # fill initial
        s[:, 0] = self.E[:, obs[0]] * init

        for j in range(1, length):
            # trans[k, i]: best score at step j-1 in state k, then k -> i.
            trans = s[:, j - 1][:, np.newaxis] * self.T
            tr[:, j] = np.argmax(trans, axis=0)
            s[:, j] = np.max(trans, axis=0) * self.E[:, obs[j]]

        m, im = np.max(s[:, -1]), np.argmax(s[:, -1])

        # backtrace: collect states from last to first, then reverse the
        # list (not the joined string) so multi-digit indices stay intact.
        path = [im]
        for j in range(length - 1, 0, -1):
            im = tr[im, j]
            path.append(im)
        path.reverse()
        code = ''.join(str(int(state)) for state in path)

        print("Decoded: ", code, " Probability: ", m)
        return m, code
if __name__ == '__main__':
    # Two-state demo, decoded with the default (uniform) initial
    # probabilities; the 0/1 observations are shifted to 1-based symbols.
    demo_transitions = np.array([
        [.9, .1],
        [.1, .9],
    ])
    demo_emissions = np.array([
        [.5, .5],
        [.75, .25],
    ])
    demo_symbols = 1 + np.array([1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0])
    Viterbi(demo_transitions, demo_emissions).decode(demo_symbols)

    # Wikipedia example
    # https://en.wikipedia.org/wiki/Viterbi_algorithm#Example
    wiki_transitions = np.array([
        [.7, .3],
        [.6, .4],
    ])
    wiki_emissions = np.array([
        [.5, .4, .1],
        [.1, .3, .6],
    ])
    wiki_symbols = np.array([1, 2, 3])
    wiki_start = np.array([.6, .4])
    Viterbi(wiki_transitions, wiki_emissions).decode(wiki_symbols, wiki_start)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment