A didactic HMM implementation in Python

This code is a simple implementation of an HMM, including Baum-Welch training, the forward-backward algorithm, and Viterbi decoding for short, discrete observation sequences. The example implemented here is a robot localising itself in a lift that stops at three floors, STATES = ['F1', 'F2', 'F3'], and observes a light on each floor, OBS = ['L1', 'L2', 'L3']. The observations are noisy, i.e. the robot might see the wrong light.
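
For reference, the recursions implemented below are the standard forward-backward equations, sketched here in the usual notation (a is the transition table, b the emission table, \pi the prior, and o_1 ... o_T the observation sequence):

    \alpha_1(s) = \pi(s)\, b(s, o_1)
    \alpha_t(s) = b(s, o_t) \sum_{s'} \alpha_{t-1}(s')\, a(s', s)

    \beta_T(s) = 1
    \beta_t(s) = \sum_{s'} a(s, s')\, b(s', o_{t+1})\, \beta_{t+1}(s')

    P(o_{1:T}) = \sum_{s} \alpha_T(s) = \sum_{s} \pi(s)\, b(s, o_1)\, \beta_1(s)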

An example pre-defined HMM looks like this:

def example_hmm():
    hmm = {
        'emission_tbl': {('F1', 'L1'): 0.6,
                         ('F1', 'L2'): 0.2,
                         ('F1', 'L3'): 0.2,
                         ('F2', 'L1'): 0.2,
                         ('F2', 'L2'): 0.6,
                         ('F2', 'L3'): 0.2,
                         ('F3', 'L1'): 0.3,
                         ('F3', 'L2'): 0.3,
                         ('F3', 'L3'): 0.4},
        'prior_tbl': {'F1': 0.5,
                      'F2': 0.25,
                      'F3': 0.25},
        'transition_tbl': {('F1', 'F1'): 0.5,
                           ('F1', 'F2'): 0.5,
                           ('F1', 'F3'): 0.0,
                           ('F2', 'F1'): 0.333333,
                           ('F2', 'F2'): 0.333333,
                           ('F2', 'F3'): 0.333334,
                           ('F3', 'F1'): 0.0,
                           ('F3', 'F2'): 0.5,
                           ('F3', 'F3'): 0.5}
                        }
    return hmm

This disallows direct transitions from F1 to F3 and vice versa.
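
As a minimal usage sketch (all names are defined in the full listing below), the production probability of an observation sequence and its most probable state sequence can be computed like this:

    h = HMM(['F1', 'F2', 'F3'], ['L1', 'L2', 'L3'])
    hmm = example_hmm()
    # production probability P(O) of an observation sequence
    alpha, beta, prod_prob = h.forward_backward(['L1', 'L2', 'L2'], hmm)
    print 'P(O) = %f' % prod_prob
    # most probable state sequence for the same observations
    delta, seq = h.viterbi(['L1', 'L2', 'L2'], hmm)
    print ','.join(seq)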

There are some slides with further details.

from __future__ import division
from collections import Counter, defaultdict
import operator
from pprint import pprint

VERBOSE = False


class HMM:

    def __init__(self, STATES, OBS):
        self.STATES = STATES
        self.OBS = OBS
        self.hmm = {}
    def forward(self, obs, hmm):
        """compute the forward variables alpha for an observation sequence"""
        transition_tbl = hmm['transition_tbl']  # a
        emission_tbl = hmm['emission_tbl']  # b
        prior_tbl = hmm['prior_tbl']  # Pi
        alpha_tbl = defaultdict(dict)
        for t in xrange(len(obs)):
            ob = obs[t]
            for s in self.STATES:
                if t == 0:
                    alpha_tbl[t][s] = prior_tbl[s] * emission_tbl[(s, ob)]
                else:
                    alpha_tbl[t][s] = emission_tbl[(s, ob)] * \
                        sum(alpha_tbl[t-1][ps] *
                            transition_tbl[(ps, s)]
                            for ps in self.STATES)
        return alpha_tbl
    def backward(self, obs, hmm):
        """compute the backward variables beta for an observation sequence"""
        transition_tbl = hmm['transition_tbl']  # a
        emission_tbl = hmm['emission_tbl']  # b
        beta_tbl = defaultdict(dict)
        for t in xrange(len(obs)-1, -1, -1):
            for s in self.STATES:
                if t == len(obs) - 1:
                    beta_tbl[t][s] = 1
                else:
                    ob = obs[t+1]
                    beta_tbl[t][s] = sum(emission_tbl[(ns, ob)] *
                                         beta_tbl[t+1][ns] *
                                         transition_tbl[(s, ns)]
                                         for ns in self.STATES)
        return beta_tbl
    def forward_backward(self, obs, hmm, VERBOSE=False):
        """
        the forward-backward algorithm
        Input: observations and the HMM model parameters
        Output: alpha and beta tables, and the production probability
        """
        b = hmm['emission_tbl']
        pi = hmm['prior_tbl']
        # forward
        alpha_tbl = self.forward(obs, hmm)
        prod_prob_f = sum(alpha_tbl[len(obs) - 1][ns] for ns in self.STATES)
        # backward
        beta_tbl = self.backward(obs, hmm)
        prod_prob_b = sum(pi[s] * b[(s, obs[0])] *
                          beta_tbl[0][s]
                          for s in self.STATES)
        if VERBOSE:
            print "obs: %s" % ','.join(obs)
            print "alpha_tbl"
            self.print_table(alpha_tbl, xrange(len(obs)), self.STATES)
            print
            print "beta_tbl"
            self.print_table(beta_tbl, xrange(len(obs)), self.STATES)
            print
            print 'prod_prob (forward) = %f' % prod_prob_f
            print 'prod_prob (backward) = %f' % prod_prob_b
        # the prod_prob must be the same for both forward and backward
        assert(abs(prod_prob_f - prod_prob_b) < 1e-6)
        return alpha_tbl, beta_tbl, prod_prob_f
    def compute_gammas(self, obs, hmm, alpha_tbl, beta_tbl, prod_prob):
        """
        compute the gammas required for Baum-Welch training
        Input: observations and the HMM model parameters, as well
               as the output of the forward-backward computation
        Output: gammas for transitions and for state occupancy
        """
        a = hmm['transition_tbl']
        b = hmm['emission_tbl']
        gamma = dict()
        gamma_f = dict()
        for t in xrange(len(obs)):
            for s in self.STATES:
                if t < len(obs)-1:
                    ob = obs[t+1]
                    for ns in self.STATES:
                        gamma[(t, s, ns)] = (alpha_tbl[t][s] *
                                             a[(s, ns)] *
                                             b[(ns, ob)] *
                                             beta_tbl[t+1][ns]) / prod_prob
                gamma_f[(t, s)] = (alpha_tbl[t][s] * beta_tbl[t][s]) / \
                    prod_prob
        return gamma, gamma_f
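
    # The update rules implemented below are the standard Baum-Welch
    # re-estimation formulas, stated here in the code's notation
    # (gamma[(t, s, ns)] is the expected transition s -> ns at time t,
    # gamma_f[(t, s)] the expected occupancy of s at time t, both
    # averaged over all observation sequences):
    #
    #   pi'(s)    = gamma_f[(0, s)]
    #   a'(s, ns) = sum_{t<T-1} gamma[(t, s, ns)] / sum_{t<T-1} gamma_f[(t, s)]
    #   b'(s, k)  = sum_{t: o_t = k} gamma_f[(t, s)] / sum_t gamma_f[(t, s)]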
    def baum_welch_step(self, obss, hmm, VERBOSE=False):
        """
        do one step of Baum-Welch training with a set of
        observation sequences
        Input: set of observation sequences and the HMM model parameters
        Output: updated HMM data structure
        """
        sum_gamma = defaultdict(float)
        sum_gamma_f = defaultdict(float)
        sum_gamma_f_t1 = defaultdict(float)
        sum_gamma_f_obs = defaultdict(float)
        sum_prod_prob = 0.0
        pin = defaultdict(float)
        for obs in obss:
            (alpha_tbl, beta_tbl,
             prod_prob) = self.forward_backward(obs, hmm, VERBOSE=VERBOSE)
            sum_prod_prob += prod_prob / len(obss)
            gamma, gamma_f = self.compute_gammas(obs, hmm, alpha_tbl,
                                                 beta_tbl, prod_prob)
            if VERBOSE:
                print "gamma_f:"
                self.print_dict_table(gamma_f,
                                      xrange(len(obs) - 1), self.STATES)
            for s in self.STATES:
                pin[s] += gamma_f[(0, s)] / len(obss)
                sum_gamma_f_t1[s] += sum(gamma_f[(t, s)]
                                         for t in xrange(len(obs)-1)) / \
                    len(obss)
                sum_gamma_f[s] += sum(gamma_f[(t, s)]
                                      for t in xrange(len(obs))) / len(obss)
                for ns in self.STATES:
                    sum_gamma[(s, ns)] += sum(gamma[(t, s, ns)]
                                              for t in xrange(len(obs)-1)) / \
                        len(obss)
                for k in self.OBS:
                    # select the time indices where observation k was seen
                    Ts = [t for t, o in zip(xrange(len(obs)), obs) if o == k]
                    sum_gamma_f_obs[(s, k)] += sum(gamma_f[(t, s)]
                                                   for t in Ts) / len(obss)
        an = dict()
        bn = dict()
        for s in self.STATES:
            for ns in self.STATES:
                an[(s, ns)] = sum_gamma[(s, ns)] / sum_gamma_f_t1[s]
            for k in self.OBS:
                bn[(s, k)] = sum_gamma_f_obs[(s, k)] / sum_gamma_f[s]
        hmm['prod_prob'] = sum_prod_prob
        hmm_hat = {
            'transition_tbl': an,
            'emission_tbl': bn,
            'prior_tbl': pin
        }
        # production probability of the data under the updated model
        sum_prod_prob = 0.0
        for obs in obss:
            _, _, prod_prob = self.forward_backward(obs, hmm_hat)
            sum_prod_prob += prod_prob / len(obss)
        hmm_hat['prod_prob'] = sum_prod_prob
        self.print_hmm(hmm, "old HMM")
        self.print_hmm(hmm_hat, "new HMM")
        return hmm_hat
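
    # The Viterbi recursion implemented below, in the same notation
    # (bp stores the backpointers used for backtracing):
    #
    #   delta_1(s) = pi(s) * b(s, o_1)
    #   delta_t(s) = max_{s'} delta_{t-1}(s') * a(s', s) * b(s, o_t)
    #   bp_t(s)    = argmax_{s'} delta_{t-1}(s') * a(s', s) * b(s, o_t)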
    def viterbi(self, obs, hmm):
        """
        the Viterbi algorithm
        """
        a, b, pi = hmm['transition_tbl'], hmm['emission_tbl'], hmm['prior_tbl']
        delta = defaultdict(dict)
        bp = defaultdict(dict)
        for t in xrange(len(obs)):
            ob = obs[t]
            for s in self.STATES:
                if t == 0:
                    delta[t][s] = pi[s] * b[(s, ob)]
                else:
                    bp[t][s], delta[t][s] = max([(ps,
                                                  delta[t-1][ps] *
                                                  a[(ps, s)] *
                                                  b[(s, ob)])
                                                 for ps in self.STATES],
                                                key=operator.itemgetter(1))
        # get the most probable last state
        state, _ = max(delta[len(obs) - 1].items(),
                       key=operator.itemgetter(1))
        t = len(obs) - 1
        state_sequence = [state]
        # back tracing
        while t in bp and state in bp[t]:
            state = bp[t][state]
            t -= 1
            state_sequence.append(state)
        return delta, state_sequence[::-1]
    def init(self, initial_data, add_const=0.01):
        """
        initialise the HMM model parameters from fully observed
        training data, a list of (states, observations) pairs,
        smoothing all counts by add_const
        """
        state_sequences = [row[0] for row in initial_data]
        # the transition probs
        state_pairs = [(states[j], states[j+1])
                       for states in state_sequences
                       for j in xrange(len(states) - 1)]
        state_pairs_freq = Counter(state_pairs)
        flat_states_sequences = [state
                                 for states in state_sequences
                                 for state in states[:-1]]
        states_freq = Counter(flat_states_sequences)
        transition_tbl = dict([((from_state, to_state),
                                (state_pairs_freq[(from_state, to_state)] +
                                 add_const) /
                                (states_freq[from_state] + add_const *
                                 len(self.STATES)))
                               for from_state in self.STATES
                               for to_state in self.STATES])
        # the emission probs
        state_obs_pairs = [pair for row in initial_data for pair in zip(*row)]
        state_obs_freq = Counter(state_obs_pairs)
        flat_states_sequences = [state
                                 for states in state_sequences
                                 for state in states]
        states_freq = Counter(flat_states_sequences)
        emission_tbl = dict([((state, obs),
                              (state_obs_freq[(state, obs)] + add_const) /
                              (states_freq[state] + add_const *
                               len(self.OBS)))
                             for state in self.STATES
                             for obs in self.OBS])
        print 'emission probability'
        self.print_dict_table(emission_tbl, self.STATES, self.OBS)
        print
        # the prior probs
        starting_states = [row[0][0] for row in initial_data]
        starting_states_freq = Counter(starting_states)
        prior_tbl = dict([(s, (starting_states_freq[s] + 1) /
                           (len(starting_states) + len(self.STATES)))
                          for s in self.STATES])
        print 'prior probability'
        pprint(prior_tbl)
        print
        hmm = {
            'transition_tbl': transition_tbl,
            'emission_tbl': emission_tbl,
            'prior_tbl': prior_tbl
        }
        return hmm
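
    # Note: add_const acts as additive (Laplace-style) smoothing on the
    # transition and emission counts, so that events never seen in the
    # fully observed training data still get a small non-zero
    # probability; the priors use add-one smoothing instead.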
    def print_table(self, tbl, rows, cols):
        """
        util function to print out a nested prob dist table
        """
        print '\t'.join([''] + cols) + '\n'
        print '\n'.join(['\t'.join(['%d' % r] +
                                   ['%.5f' % tbl[r][c] for c in cols])
                         for r in rows])

    def print_dict_table(self, tbl, rows, cols):
        """
        util function to print out a tuple-keyed dict table
        """
        print '\t'.join([''] + cols) + '\n'
        print '\n'.join(['\t'.join(['%s' % r] +
                                   ['%.5f' % tbl[(r, c)] for c in cols])
                         for r in rows])
    def print_hmm(self, hmm, title="HMM"):
        print '============================================================'
        print title
        print '------------------------------------------------------------'
        print 'initial state probability (Pi):'
        pprint(hmm['prior_tbl'])
        print '------------------------------------------------------------'
        print 'state transition probability (A):'
        self.print_dict_table(hmm['transition_tbl'], self.STATES, self.STATES)
        print '------------------------------------------------------------'
        print 'emission probability (B):'
        self.print_dict_table(hmm['emission_tbl'], self.STATES, self.OBS)
        if 'prod_prob' in hmm:
            print '--------------------------------------------------------'
            print "production probability = %.6f" % hmm['prod_prob']
        print '============================================================'

# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
def example_hmm():
    hmm = {
        'emission_tbl': {('F1', 'L1'): 0.6,
                         ('F1', 'L2'): 0.2,
                         ('F1', 'L3'): 0.2,
                         ('F2', 'L1'): 0.2,
                         ('F2', 'L2'): 0.6,
                         ('F2', 'L3'): 0.2,
                         ('F3', 'L1'): 0.3,
                         ('F3', 'L2'): 0.3,
                         ('F3', 'L3'): 0.4},
        'prior_tbl': {'F1': 0.5,
                      'F2': 0.25,
                      'F3': 0.25},
        'transition_tbl': {('F1', 'F1'): 0.5,
                           ('F1', 'F2'): 0.5,
                           ('F1', 'F3'): 0.0,
                           ('F2', 'F1'): 0.333333,
                           ('F2', 'F2'): 0.333333,
                           ('F2', 'F3'): 0.333334,
                           ('F3', 'F1'): 0.0,
                           ('F3', 'F2'): 0.5,
                           ('F3', 'F3'): 0.5}
    }
    return hmm
# the three lights
L1 = 'L1'
L2 = 'L2'
L3 = 'L3'
# and the three floors
F1 = 'F1'
F2 = 'F2'
F3 = 'F3'
# set of all possible states
STATES = [F1, F2, F3]
# set of all possible observations
OBS = [L1, L2, L3]
# create the HMM object
h = HMM(STATES, OBS)
# given fully observed training data, a list of (states, observations)
# pairs, the HMM model parameters could be initialised via h.init()
initial_data = [
((F3, F2, F1, F1, F2, F3, F2),
(L2, L2, L1, L1, L2, L2, L2)),
((F1, F1, F1, F2, F3, F3, F2),
(L1, L1, L1, L2, L2, L2, L2)),
((F1, F2, F3, F2, F1, F1),
(L1, L2, L2, L2, L2, L1)),
((F1, F2, F3, F3, F2, F1, F2, F1),
(L1, L2, L2, L2, L2, L1, L2, L1))]
# create our example HMM from given parameters
hmm = example_hmm()
# initialise from the initial data
# hmm = h.init(initial_data, add_const=0.01)
h.print_hmm(hmm, 'initial HMM')
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# define two test observations sequences
test_set = [[L1, L3, L3, L2],
[L1, L1, L2, L2, L2, L2, L1]]
# find the production probability for a sequence
# h.forward_backward(test_set[1], hmm, VERBOSE=True)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# define a training set
training_set = [[L1, L2, L3, L3, L1, L1],
[L2, L2, L2, L3, L2, L1],
[L1, L2, L1, L2, L1, L2],
[L2, L1, L1],
[L1, L1, L2, L2]]
# iterate Baum-Welch Training 10 times
for i in xrange(10):
    hmm = h.baum_welch_step(training_set, hmm, VERBOSE=False)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++
VERBOSE = True
# viterbi decoding of all test sequences
# for obs in test_set:
#     delta, seq1 = h.viterbi(obs, hmm)
#     print ('Viterbi: for observations "%s", the most '
#            'probable state sequence is: ' %
#            (','.join(obs))) + ','.join(seq1)
#     if VERBOSE:
#         print 'delta:'
#         h.print_table(delta, xrange(len(obs)), h.STATES)
#         print '----------------'