@marc-hanheide
Last active August 3, 2016 15:25

A didactic HMM implementation in Python

This code is a simple implementation of an HMM, including Baum-Welch training, the forward-backward algorithm, and Viterbi decoding, for short, discrete observation sequences. The example implemented here is a robot localising itself in a lift that stops at three floors (STATES = ['F1', 'F2', 'F3']), observing a light on each floor (OBS = ['L1', 'L2', 'L3']). The observation is noisy, i.e. the robot might see the wrong light.

An example pre-defined HMM looks like this:

def example_hmm():
    hmm = {
        'emission_tbl': {('F1', 'L1'): 0.6,
                         ('F1', 'L2'): 0.2,
                         ('F1', 'L3'): 0.2,
                         ('F2', 'L1'): 0.2,
                         ('F2', 'L2'): 0.6,
                         ('F2', 'L3'): 0.2,
                         ('F3', 'L1'): 0.3,
                         ('F3', 'L2'): 0.3,
                         ('F3', 'L3'): 0.4},
        'prior_tbl': {'F1': 0.5,
                      'F2': 0.25,
                      'F3': 0.25},
        'transition_tbl': {('F1', 'F1'): 0.5,
                           ('F1', 'F2'): 0.5,
                           ('F1', 'F3'): 0.0,
                           ('F2', 'F1'): 0.333333,
                           ('F2', 'F2'): 0.333333,
                           ('F2', 'F3'): 0.333334,
                           ('F3', 'F1'): 0.0,
                           ('F3', 'F2'): 0.5,
                           ('F3', 'F3'): 0.5}
                        }
    return hmm

This disallows direct transitions from F1 to F3 and vice versa.
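
Both zero entries can be checked directly on the table; as a small sanity check (using example_hmm from above):

    hmm = example_hmm()
    assert hmm['transition_tbl'][('F1', 'F3')] == 0.0
    assert hmm['transition_tbl'][('F3', 'F1')] == 0.0
    # any state sequence with a direct F1 -> F3 (or F3 -> F1) step picks up
    # a factor of 0, so neither the forward pass nor Viterbi decoding will
    # ever assign it positive probability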

There are some slides with further details.

from __future__ import division
from collections import Counter, defaultdict
import operator
from pprint import pprint

VERBOSE = False


class HMM:

    def __init__(self, STATES, OBS):
        self.STATES = STATES
        self.OBS = OBS
        self.hmm = {}

    def forward(self, obs, hmm):
        transition_tbl = hmm['transition_tbl']  # a
        emission_tbl = hmm['emission_tbl']      # b
        prior_tbl = hmm['prior_tbl']            # Pi
        alpha_tbl = defaultdict(dict)
        for t in xrange(len(obs)):
            ob = obs[t]
            for s in self.STATES:
                if t == 0:
                    alpha_tbl[t][s] = prior_tbl[s] * emission_tbl[(s, ob)]
                else:
                    alpha_tbl[t][s] = emission_tbl[(s, ob)] * \
                        sum(alpha_tbl[t-1][ps] *
                            transition_tbl[(ps, s)]
                            for ps in self.STATES)
        return alpha_tbl
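
    # the forward pass above implements the standard alpha recursion:
    #   alpha_0(s) = Pi(s) * b(s, o_0)
    #   alpha_t(s) = b(s, o_t) * sum_{s'} alpha_{t-1}(s') * a(s', s)
    # so that alpha_t(s) = P(o_0, ..., o_t, S_t = s)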

    def backward(self, obs, hmm):
        transition_tbl = hmm['transition_tbl']  # a
        emission_tbl = hmm['emission_tbl']      # b
        beta_tbl = defaultdict(dict)
        for t in xrange(len(obs)-1, -1, -1):
            for s in self.STATES:
                if t == len(obs) - 1:
                    beta_tbl[t][s] = 1
                else:
                    ob = obs[t+1]
                    beta_tbl[t][s] = sum(emission_tbl[(ns, ob)] *
                                         beta_tbl[t+1][ns] *
                                         transition_tbl[(s, ns)]
                                         for ns in self.STATES)
        return beta_tbl
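
    # the backward pass above implements the corresponding beta recursion:
    #   beta_{T-1}(s) = 1                       (T = length of obs)
    #   beta_t(s)     = sum_{s'} a(s, s') * b(s', o_{t+1}) * beta_{t+1}(s')
    # so that beta_t(s) = P(o_{t+1}, ..., o_{T-1} | S_t = s)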

    def forward_backward(self, obs, hmm, VERBOSE=False):
        """
        the forward-backward algorithm
        Input: observations and the HMM model parameters
        Output: alpha and beta tables, and the production probability
        """
        b = hmm['emission_tbl']
        pi = hmm['prior_tbl']
        # forward
        alpha_tbl = self.forward(obs, hmm)
        prod_prob_f = sum(alpha_tbl[len(obs) - 1][ns] for ns in self.STATES)
        # backward
        beta_tbl = self.backward(obs, hmm)
        prod_prob_b = sum(pi[s] * b[(s, obs[0])] *
                          beta_tbl[0][s]
                          for s in self.STATES)
        if VERBOSE:
            print "obs: %s" % ','.join(obs)
            print "alpha_tbl"
            self.print_table(alpha_tbl, xrange(len(obs)), self.STATES)
            print
            print "beta_tbl"
            self.print_table(beta_tbl, xrange(len(obs)), self.STATES)
            print
            print 'prod_prob (forward) = %f' % prod_prob_f
            print 'prod_prob (backward) = %f' % prod_prob_b
        # the prod_prob must be the same for both forward and backward
        assert(abs(prod_prob_f - prod_prob_b) < 1e-6)
        return alpha_tbl, beta_tbl, prod_prob_f
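
    # alpha and beta together yield the posterior over the state at time t,
    #   P(S_t = s | O) = alpha_t(s) * beta_t(s) / P(O),
    # which is exactly the gamma_f computed by compute_gammas() below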

    def compute_gammas(self, obs, hmm, alpha_tbl, beta_tbl, prod_prob):
        """
        compute the gammas required for Baum-Welch training
        Input: observations and the HMM model parameters, as well
               as the output from forward-backward computations
        Output: gammas for transitions and state probabilities
        """
        a = hmm['transition_tbl']
        b = hmm['emission_tbl']
        gamma = dict()
        gamma_f = dict()
        for t in xrange(len(obs)):
            for s in self.STATES:
                if t < len(obs)-1:
                    ob = obs[t+1]
                    for ns in self.STATES:
                        gamma[(t, s, ns)] = (alpha_tbl[t][s] *
                                             a[(s, ns)] *
                                             b[(ns, ob)] *
                                             beta_tbl[t+1][ns]) / prod_prob
                gamma_f[(t, s)] = (alpha_tbl[t][s] * beta_tbl[t][s]) / \
                    prod_prob
        return gamma, gamma_f
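
    # in the usual Baum-Welch notation, gamma above is the pairwise
    #   xi_t(s, s') = P(S_t = s, S_{t+1} = s' | O)
    #               = alpha_t(s) * a(s, s') * b(s', o_{t+1}) * beta_{t+1}(s') / P(O)
    # and gamma_f is the per-state
    #   gamma_t(s) = P(S_t = s | O) = alpha_t(s) * beta_t(s) / P(O)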

    def baum_welch_step(self, obss, hmm, VERBOSE=False):
        """
        do one step of Baum-Welch training with a set of
        observation sequences
        Input: set of observation sequences and the HMM model parameters
        Output: updated HMM data structure
        """
        sum_gamma = defaultdict(float)
        sum_gamma_f = defaultdict(float)
        sum_gamma_f_t1 = defaultdict(float)
        sum_gamma_f_obs = defaultdict(float)
        sum_prod_prob = 0.0
        pin = defaultdict(float)
        for obs in obss:
            (alpha_tbl, beta_tbl,
             prod_prob) = self.forward_backward(obs, hmm, VERBOSE=VERBOSE)
            sum_prod_prob += prod_prob / len(obss)
            gamma, gamma_f = self.compute_gammas(obs, hmm, alpha_tbl,
                                                 beta_tbl, prod_prob)
            if VERBOSE:
                print "gamma_f:"
                self.print_dict_table(gamma_f,
                                      xrange(len(obs) - 1), self.STATES)
            for s in self.STATES:
                pin[s] += gamma_f[(0, s)] / len(obss)
                sum_gamma_f_t1[s] += sum(gamma_f[(t, s)]
                                         for t in xrange(len(obs)-1)) / \
                    len(obss)
                sum_gamma_f[s] += sum(gamma_f[(t, s)]
                                      for t in xrange(len(obs))) / len(obss)
                for ns in self.STATES:
                    sum_gamma[(s, ns)] += sum(gamma[(t, s, ns)]
                                              for t in xrange(len(obs)-1)) / \
                        len(obss)
                for k in self.OBS:
                    # select the time indices at which observation k was seen
                    Ts = [t for t, o in enumerate(obs) if o == k]
                    sum_gamma_f_obs[(s, k)] += sum(gamma_f[(t, s)]
                                                   for t in Ts) / len(obss)
        an = dict()
        bn = dict()
        for s in self.STATES:
            for ns in self.STATES:
                an[(s, ns)] = sum_gamma[(s, ns)] / sum_gamma_f_t1[s]
            for k in self.OBS:
                bn[(s, k)] = sum_gamma_f_obs[(s, k)] / sum_gamma_f[s]
        hmm['prod_prob'] = sum_prod_prob
        hmm_hat = {
            'transition_tbl': an,
            'emission_tbl': bn,
            'prior_tbl': pin
        }
        # evaluate the production probability under the updated model
        sum_prod_prob = 0.0
        for obs in obss:
            _, _, prod_prob = self.forward_backward(obs, hmm_hat)
            sum_prod_prob += prod_prob / len(obss)
        hmm_hat['prod_prob'] = sum_prod_prob
        self.print_hmm(hmm, "old HMM")
        self.print_hmm(hmm_hat, "new HMM")
        return hmm_hat
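
    # baum_welch_step() above implements the standard re-estimation formulas,
    # with the sufficient statistics averaged over all training sequences:
    #   Pi_new(s)    = gamma_0(s)
    #   a_new(s, s') = sum_{t < T-1} xi_t(s, s') / sum_{t < T-1} gamma_t(s)
    #   b_new(s, k)  = sum_{t: o_t = k} gamma_t(s) / sum_t gamma_t(s)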

    def viterbi(self, obs, hmm):
        """
        The Viterbi Algorithm
        """
        a, b, pi = hmm['transition_tbl'], hmm['emission_tbl'], hmm['prior_tbl']
        delta = defaultdict(dict)
        bp = defaultdict(dict)
        for t in xrange(len(obs)):
            ob = obs[t]
            for s in self.STATES:
                if t == 0:
                    delta[t][s] = pi[s] * b[(s, ob)]
                else:
                    bp[t][s], delta[t][s] = max([(ps,
                                                  delta[t-1][ps] *
                                                  a[(ps, s)] *
                                                  b[(s, ob)])
                                                 for ps in self.STATES],
                                                key=operator.itemgetter(1))
        # get the most probable last state
        state, _ = max(delta[len(obs) - 1].items(),
                       key=operator.itemgetter(1))
        t = len(obs) - 1
        state_sequence = [state]
        # back tracing
        while t in bp and state in bp[t]:
            state = bp[t][state]
            t -= 1
            state_sequence.append(state)
        return delta, state_sequence[::-1]
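
    # the Viterbi recursion is the forward recursion with max in place of sum:
    #   delta_0(s) = Pi(s) * b(s, o_0)
    #   delta_t(s) = max_{s'} delta_{t-1}(s') * a(s', s) * b(s, o_t)
    # bp stores the maximising predecessor, so the most probable state
    # sequence is recovered by back tracing from the best final state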

    def init(self, initial_data, add_const=0.01):
        # the transition probs, estimated from observed state pairs
        state_sequences = [row[0] for row in initial_data]
        state_pairs = [(states[j], states[j+1])
                       for states in state_sequences
                       for j in xrange(len(states) - 1)]
        state_pairs_freq = Counter(state_pairs)
        flat_states_sequences = [state
                                 for states in state_sequences
                                 for state in states[:-1]]
        states_freq = Counter(flat_states_sequences)
        transition_tbl = dict([((from_state, to_state),
                                (state_pairs_freq[(from_state, to_state)] +
                                 add_const) /
                                (states_freq[from_state] + add_const *
                                 len(self.STATES)))
                               for from_state in self.STATES
                               for to_state in self.STATES])
        # the emission probs
        state_obs_pairs = [pair for row in initial_data for pair in zip(*row)]
        state_obs_freq = Counter(state_obs_pairs)
        flat_states_sequences = [state
                                 for states in state_sequences
                                 for state in states]
        states_freq = Counter(flat_states_sequences)
        emission_tbl = dict([((state, obs),
                              (state_obs_freq[(state, obs)] + add_const) /
                              (states_freq[state] + add_const *
                               len(self.OBS)))
                             for state in self.STATES
                             for obs in self.OBS])
        print 'emission probability'
        self.print_dict_table(emission_tbl, self.STATES, self.OBS)
        print
        # the prior probs
        starting_states = [row[0][0] for row in initial_data]
        starting_states_freq = Counter(starting_states)
        prior_tbl = dict([(s, (starting_states_freq[s] + 1) /
                           (len(starting_states) + len(self.STATES)))
                          for s in self.STATES])
        print 'prior probability'
        pprint(prior_tbl)
        print
        hmm = {
            'transition_tbl': transition_tbl,
            'emission_tbl': emission_tbl,
            'prior_tbl': prior_tbl
        }
        return hmm
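
    # note: init() applies simple additive smoothing via add_const, so that
    # transitions or emissions never seen in the (fully observed) initial
    # data still receive a small non-zero probability rather than a hard 0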

    def print_table(self, tbl, rows, cols):
        """
        Util function
        Print out the prob dist table
        """
        print '\t'.join([''] + cols) + '\n'
        print '\n'.join(['\t'.join(['%d' % r] +
                                   map(lambda val: '%.5f' % val,
                                       [tbl[r][c] for c in cols]))
                         for r in rows])

    def print_dict_table(self, tbl, rows, cols):
        """
        Util function
        Print out the dict tables
        """
        print '\t'.join([''] + cols) + '\n'
        print '\n'.join(['\t'.join(['%s' % r] +
                                   map(lambda val: '%.5f' % val,
                                       [tbl[(r, c)] for c in cols]))
                         for r in rows])

    def print_hmm(self, hmm, title="HMM"):
        print '============================================================'
        print title
        print '------------------------------------------------------------'
        print 'initial state probability (Pi):'
        pprint(hmm['prior_tbl'])
        print '------------------------------------------------------------'
        print 'state transition probability (A):'
        self.print_dict_table(hmm['transition_tbl'], self.STATES, self.STATES)
        print '------------------------------------------------------------'
        print 'emission probability (B):'
        self.print_dict_table(hmm['emission_tbl'], self.STATES, self.OBS)
        if 'prod_prob' in hmm:
            print '--------------------------------------------------------'
            print "production probability = %.6f" % hmm['prod_prob']
        print '============================================================'
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
def example_hmm():
    hmm = {
        'emission_tbl': {('F1', 'L1'): 0.6,
                         ('F1', 'L2'): 0.2,
                         ('F1', 'L3'): 0.2,
                         ('F2', 'L1'): 0.2,
                         ('F2', 'L2'): 0.6,
                         ('F2', 'L3'): 0.2,
                         ('F3', 'L1'): 0.3,
                         ('F3', 'L2'): 0.3,
                         ('F3', 'L3'): 0.4},
        'prior_tbl': {'F1': 0.5,
                      'F2': 0.25,
                      'F3': 0.25},
        'transition_tbl': {('F1', 'F1'): 0.5,
                           ('F1', 'F2'): 0.5,
                           ('F1', 'F3'): 0.0,
                           ('F2', 'F1'): 0.333333,
                           ('F2', 'F2'): 0.333333,
                           ('F2', 'F3'): 0.333334,
                           ('F3', 'F1'): 0.0,
                           ('F3', 'F2'): 0.5,
                           ('F3', 'F3'): 0.5}
    }
    return hmm
# the three lights, represented as strings
L1 = 'L1'
L2 = 'L2'
L3 = 'L3'
# and the three floors
F1 = 'F1'
F2 = 'F2'
F3 = 'F3'
# set of all possible states
STATES = [F1, F2, F3]
# set of all possible observations
OBS = [L1, L2, L3]
# create the HMM object
h = HMM(STATES, OBS)
# given a list of training data, list of (sts, obs) pairs,
# to initialise the HMM model parameters (fully observed!)
initial_data = [
    ((F3, F2, F1, F1, F2, F3, F2),
     (L2, L2, L1, L1, L2, L2, L2)),
    ((F1, F1, F1, F2, F3, F3, F2),
     (L1, L1, L1, L2, L2, L2, L2)),
    ((F1, F2, F3, F2, F1, F1),
     (L1, L2, L2, L2, L2, L1)),
    ((F1, F2, F3, F3, F2, F1, F2, F1),
     (L1, L2, L2, L2, L2, L1, L2, L1))]
# create our example HMM from given parameters
hmm = example_hmm()
# initialise from the initial data
# hmm = h.init(initial_data, add_const=0.01)
h.print_hmm(hmm, 'initial HMM')
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# define two test observations sequences
test_set = [[L1, L3, L3, L2],
            [L1, L1, L2, L2, L2, L2, L1]]
# find the production probability for a sequence
# h.forward_backward(test_set[1], hmm, VERBOSE=True)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# define a training set
training_set = [[L1, L2, L3, L3, L1, L1],
                [L2, L2, L2, L3, L2, L1],
                [L1, L2, L1, L2, L1, L2],
                [L2, L1, L1],
                [L1, L1, L2, L2]]
# iterate Baum-Welch Training 10 times
for i in xrange(10):
    hmm = h.baum_welch_step(training_set, hmm, VERBOSE=False)
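# Baum-Welch is an instance of EM, so the likelihood of the training set
# never decreases from one step to the next; the production probability
# reported by print_hmm serves as a rough convergence check here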
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
VERBOSE = True
# viterbi decoding of all test sequences
# for obs in test_set:
#     delta, seq1 = h.viterbi(obs, hmm)
#     print ('Viterbi: for observations "%s", the most '
#            'probable state sequence is: ' %
#            (','.join(obs))) + ','.join(seq1)
#     if VERBOSE:
#         print 'delta:'
#         h.print_table(delta, xrange(len(obs)), h.STATES)
#     print '----------------'