|
from __future__ import division |
|
from collections import Counter, defaultdict |
|
import operator |
|
|
|
from pprint import pprint |
|
|
|
# Module-level verbosity flag.  The HMM methods below take their own
# ``VERBOSE`` keyword argument and do not read this global.
VERBOSE = False
|
|
|
|
|
class HMM:
    """Discrete hidden Markov model: forward/backward, Baum-Welch, Viterbi.

    The model parameters are not stored on the instance; every method
    receives them as a plain dict ``hmm`` with these entries:

    * ``'transition_tbl'`` -- ``{(from_state, to_state): probability}``
    * ``'emission_tbl'``   -- ``{(state, observation): probability}``
    * ``'prior_tbl'``      -- ``{state: probability}``
    * ``'prod_prob'``      -- optional, written by ``baum_welch_step``

    The instance only carries the state and observation alphabets.
    """

    def __init__(self, STATES, OBS):
        """Remember the alphabets used to index the probability tables.

        STATES -- sequence of hidden-state labels
        OBS    -- sequence of observation-symbol labels
        """
        self.STATES = STATES
        self.OBS = OBS
        # Never read by any method of this class; kept for compatibility.
        self.hmm = {}

    def forward(self, obs, hmm):
        """Compute the forward variables.

        Returns ``alpha_tbl`` with ``alpha_tbl[t][s]`` being the joint
        probability of ``obs[0..t]`` and being in state ``s`` at time ``t``.
        """
        transition_tbl = hmm['transition_tbl']  # a
        emission_tbl = hmm['emission_tbl']      # b
        prior_tbl = hmm['prior_tbl']            # Pi

        alpha_tbl = defaultdict(dict)
        for t, ob in enumerate(obs):
            for s in self.STATES:
                if t == 0:
                    alpha_tbl[t][s] = prior_tbl[s] * emission_tbl[(s, ob)]
                else:
                    alpha_tbl[t][s] = emission_tbl[(s, ob)] * sum(
                        alpha_tbl[t - 1][ps] * transition_tbl[(ps, s)]
                        for ps in self.STATES)
        return alpha_tbl

    def backward(self, obs, hmm):
        """Compute the backward variables.

        Returns ``beta_tbl`` with ``beta_tbl[t][s]`` being the probability
        of ``obs[t+1..]`` given state ``s`` at time ``t``; the last time
        step is initialised to 1 for every state.
        """
        transition_tbl = hmm['transition_tbl']  # a
        emission_tbl = hmm['emission_tbl']      # b
        beta_tbl = defaultdict(dict)

        last = len(obs) - 1
        for t in range(last, -1, -1):
            for s in self.STATES:
                if t == last:
                    beta_tbl[t][s] = 1
                else:
                    ob = obs[t + 1]
                    beta_tbl[t][s] = sum(
                        emission_tbl[(ns, ob)] *
                        beta_tbl[t + 1][ns] *
                        transition_tbl[(s, ns)]
                        for ns in self.STATES)
        return beta_tbl

    def forward_backward(self, obs, hmm, VERBOSE=False):
        """Run the forward and the backward pass over one sequence.

        Input: observations and the HMM model parameters
        Output: ``(alpha_tbl, beta_tbl, prod_prob)`` where ``prod_prob``
        is the production probability of ``obs`` taken from the forward
        pass.

        Raises ValueError for an empty observation sequence and
        AssertionError if the forward and backward production
        probabilities differ by 1e-6 or more.
        """
        if len(obs) == 0:
            raise ValueError('forward_backward needs at least one '
                             'observation')

        b = hmm['emission_tbl']
        pi = hmm['prior_tbl']

        # forward
        alpha_tbl = self.forward(obs, hmm)
        prod_prob_f = sum(alpha_tbl[len(obs) - 1][ns] for ns in self.STATES)

        # backward
        beta_tbl = self.backward(obs, hmm)
        prod_prob_b = sum(pi[s] * b[(s, obs[0])] * beta_tbl[0][s]
                          for s in self.STATES)

        if VERBOSE:
            steps = range(len(obs))
            print("obs: %s" % ','.join(obs))

            print("alpha_tbl")
            self.print_table(alpha_tbl, steps, self.STATES)
            print('')

            print("beta_tbl")
            self.print_table(beta_tbl, steps, self.STATES)
            print('')

            print('prod_prob (forward) = %f' % prod_prob_f)
            print('prod_prob (backward) = %f' % prod_prob_b)

        # Internal consistency check: both passes must yield the same
        # production probability.
        assert abs(prod_prob_f - prod_prob_b) < 1e-6

        return alpha_tbl, beta_tbl, prod_prob_f

    def compute_gammas(self, obs, hmm, alpha_tbl, beta_tbl, prod_prob):
        """Compute the posteriors required for Baum-Welch training.

        Input: observations and the HMM model parameters, plus the output
        of ``forward_backward`` for the same observations
        Output: ``(gamma, gamma_f)`` where

        * ``gamma[(t, s, ns)]`` is the posterior of the transition
          ``s -> ns`` between time ``t`` and ``t + 1`` (defined for
          ``t < len(obs) - 1``);
        * ``gamma_f[(t, s)]`` is the posterior of being in ``s`` at ``t``.
        """
        a = hmm['transition_tbl']
        b = hmm['emission_tbl']

        gamma = {}
        gamma_f = {}

        last = len(obs) - 1
        for t in range(len(obs)):
            for s in self.STATES:
                if t < last:
                    ob = obs[t + 1]
                    for ns in self.STATES:
                        gamma[(t, s, ns)] = (alpha_tbl[t][s] *
                                             a[(s, ns)] *
                                             b[(ns, ob)] *
                                             beta_tbl[t + 1][ns]) / prod_prob
                gamma_f[(t, s)] = (alpha_tbl[t][s] *
                                   beta_tbl[t][s]) / prod_prob

        return gamma, gamma_f

    def baum_welch_step(self, obss, hmm, VERBOSE=False):
        """Do one step of Baum-Welch training on a set of observations.

        Input: set of observation sequences and the HMM model parameters
        Output: a new HMM dict with re-estimated tables and the mean
        production probability of ``obss`` under it as ``'prod_prob'``

        Side effects: stores the mean production probability under the
        old parameters in ``hmm['prod_prob']`` and always prints the old
        and the new model.

        A state whose expected occupancy is zero keeps its previous
        transition / emission row, because the re-estimation formula is
        undefined (0 / 0) for it.  Raises ValueError when ``obss`` is
        empty.
        """
        if len(obss) == 0:
            raise ValueError('baum_welch_step needs at least one '
                             'observation sequence')
        n_seq = len(obss)

        sum_gamma = defaultdict(float)
        sum_gamma_f = defaultdict(float)
        sum_gamma_f_t1 = defaultdict(float)
        sum_gamma_f_obs = defaultdict(float)
        sum_prod_prob = 0.0
        pin = defaultdict(float)

        for obs in obss:
            (alpha_tbl, beta_tbl,
             prod_prob) = self.forward_backward(obs, hmm, VERBOSE=VERBOSE)

            sum_prod_prob += prod_prob / n_seq

            gamma, gamma_f = self.compute_gammas(obs, hmm, alpha_tbl,
                                                 beta_tbl, prod_prob)
            if VERBOSE:
                print("gamma_f:")
                self.print_dict_table(gamma_f,
                                      range(len(obs) - 1), self.STATES)

            all_steps = range(len(obs))
            transition_steps = range(len(obs) - 1)
            for s in self.STATES:
                pin[s] += gamma_f[(0, s)] / n_seq
                sum_gamma_f_t1[s] += sum(gamma_f[(t, s)]
                                         for t in transition_steps) / n_seq
                sum_gamma_f[s] += sum(gamma_f[(t, s)]
                                      for t in all_steps) / n_seq
                for ns in self.STATES:
                    sum_gamma[(s, ns)] += sum(
                        gamma[(t, s, ns)]
                        for t in transition_steps) / n_seq
                for k in self.OBS:
                    # only the time steps at which symbol k was observed
                    sum_gamma_f_obs[(s, k)] += sum(
                        gamma_f[(t, s)]
                        for t, o in enumerate(obs) if o == k) / n_seq

        a_old = hmm['transition_tbl']
        b_old = hmm['emission_tbl']
        an = {}
        bn = {}

        for s in self.STATES:
            for ns in self.STATES:
                if sum_gamma_f_t1[s] > 0:
                    an[(s, ns)] = sum_gamma[(s, ns)] / sum_gamma_f_t1[s]
                else:
                    an[(s, ns)] = a_old[(s, ns)]
            for k in self.OBS:
                if sum_gamma_f[s] > 0:
                    bn[(s, k)] = sum_gamma_f_obs[(s, k)] / sum_gamma_f[s]
                else:
                    bn[(s, k)] = b_old[(s, k)]

        hmm['prod_prob'] = sum_prod_prob

        hmm_hat = {
            'transition_tbl': an,
            'emission_tbl': bn,
            # plain dict, like the other tables
            'prior_tbl': dict(pin),
        }
        sum_prod_prob = 0.0
        for obs in obss:
            _, _, prod_prob = self.forward_backward(obs, hmm_hat)
            sum_prod_prob += prod_prob / n_seq
        hmm_hat['prod_prob'] = sum_prod_prob

        self.print_hmm(hmm, "old HMM")
        self.print_hmm(hmm_hat, "new HMM")

        return hmm_hat

    def viterbi(self, obs, hmm):
        """The Viterbi algorithm.

        Returns ``(delta, state_sequence)``: ``delta[t][s]`` is the
        probability of the best path ending in ``s`` at time ``t`` and
        ``state_sequence`` is the most probable state path for ``obs``.
        """
        a, b, pi = hmm['transition_tbl'], hmm['emission_tbl'], hmm['prior_tbl']

        delta = defaultdict(dict)
        bp = defaultdict(dict)  # back pointers: best predecessor of s at t

        for t, ob in enumerate(obs):
            for s in self.STATES:
                if t == 0:
                    delta[t][s] = pi[s] * b[(s, ob)]
                else:
                    candidates = [(ps,
                                   delta[t - 1][ps] * a[(ps, s)] * b[(s, ob)])
                                  for ps in self.STATES]
                    bp[t][s], delta[t][s] = max(candidates,
                                                key=operator.itemgetter(1))

        # get the most probable last state
        t = len(obs) - 1
        state, _ = max(delta[t].items(), key=operator.itemgetter(1))
        state_sequence = [state]

        # back tracing; bp has no entry for t == 0, which ends the loop
        while t in bp and state in bp[t]:
            state = bp[t][state]
            t -= 1
            state_sequence.append(state)

        return delta, state_sequence[::-1]

    def init(self, initial_data, add_const=0.01):
        """Estimate HMM parameters from fully observed training data.

        initial_data -- list of ``(state_sequence, observation_sequence)``
                        pairs
        add_const    -- additive smoothing constant for the transition
                        and emission counts

        The prior uses fixed add-one smoothing, independent of
        ``add_const``.  Prints the emission and prior tables and returns
        the HMM dict.
        """
        states_all = self.STATES
        obs_all = self.OBS

        state_sequences = [row[0] for row in initial_data]

        # the transition probs
        state_pairs = [(states[j], states[j + 1])
                       for states in state_sequences
                       for j in range(len(states) - 1)]
        state_pairs_freq = Counter(state_pairs)

        # the last state of a sequence has no outgoing transition
        from_states_freq = Counter(state
                                   for states in state_sequences
                                   for state in states[:-1])

        transition_tbl = {
            (from_state, to_state):
                (state_pairs_freq[(from_state, to_state)] + add_const) /
                (from_states_freq[from_state] + add_const * len(states_all))
            for from_state in states_all
            for to_state in states_all
        }

        # the emission probs
        state_obs_freq = Counter(pair
                                 for row in initial_data
                                 for pair in zip(*row))
        states_freq = Counter(state
                              for states in state_sequences
                              for state in states)

        emission_tbl = {
            (state, obs):
                (state_obs_freq[(state, obs)] + add_const) /
                (states_freq[state] + add_const * len(obs_all))
            for state in states_all
            for obs in obs_all
        }

        print('emission probability')
        self.print_dict_table(emission_tbl, states_all, obs_all)
        print('')

        # the prior probs
        starting_states = [row[0][0] for row in initial_data]
        starting_states_freq = Counter(starting_states)

        prior_tbl = {
            s: (starting_states_freq[s] + 1) /
               (len(starting_states) + len(states_all))
            for s in states_all
        }

        print('prior probability')
        pprint(prior_tbl)
        print('')

        return {
            'transition_tbl': transition_tbl,
            'emission_tbl': emission_tbl,
            'prior_tbl': prior_tbl,
        }

    def print_table(self, tbl, rows, cols):
        """Util function: print the nested table ``tbl[row][col]``.

        Rows are formatted as integers, values with five decimals.
        """
        header = '\t'.join([''] + list(cols)) + '\n'
        print(header)
        body = '\n'.join('\t'.join(['%d' % r] +
                                   ['%.5f' % tbl[r][c] for c in cols])
                         for r in rows)
        print(body)

    def print_dict_table(self, tbl, rows, cols):
        """Util function: print the flat table ``tbl[(row, col)]``.

        Values are formatted with five decimals.
        """
        header = '\t'.join([''] + list(cols)) + '\n'
        print(header)
        body = '\n'.join('\t'.join(['%s' % r] +
                                   ['%.5f' % tbl[(r, c)] for c in cols])
                         for r in rows)
        print(body)

    def print_hmm(self, hmm, title="HMM"):
        """Print all tables of ``hmm`` (and ``'prod_prob'`` if present)."""
        print('============================================================')
        print(title)
        print('------------------------------------------------------------')
        print('initial state probability (Pi):')
        pprint(hmm['prior_tbl'])
        print('------------------------------------------------------------')

        print('state transition probability (A):')
        self.print_dict_table(hmm['transition_tbl'], self.STATES, self.STATES)
        print('------------------------------------------------------------')

        print('emmissions probability (B):')
        self.print_dict_table(hmm['emission_tbl'], self.STATES, self.OBS)

        if 'prod_prob' in hmm:
            print('--------------------------------------------------------')
            print("production probability = %.6f" % hmm['prod_prob'])

        print('============================================================')
|
|
|
# +++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|
# +++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|
# +++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|
|
|
|
|
def example_hmm():
    """Return a hand-specified three-state, three-symbol example HMM.

    The result is a dict with the tables ``'emission_tbl'``
    (``{(state, observation): p}``), ``'prior_tbl'`` (``{state: p}``) and
    ``'transition_tbl'`` (``{(from_state, to_state): p}``) over the states
    ``'F1'``..``'F3'`` and the observations ``'L1'``..``'L3'``.
    """
    hmm = {
        'emission_tbl': {
            ('F1', 'L1'): 0.6,
            ('F1', 'L2'): 0.2,
            ('F1', 'L3'): 0.2,
            ('F2', 'L1'): 0.2,
            ('F2', 'L2'): 0.6,
            ('F2', 'L3'): 0.2,
            ('F3', 'L1'): 0.3,
            ('F3', 'L2'): 0.3,
            ('F3', 'L3'): 0.4,
        },
        'prior_tbl': {
            'F1': 0.5,
            'F2': 0.25,
            'F3': 0.25,
        },
        'transition_tbl': {
            ('F1', 'F1'): 0.5,
            ('F1', 'F2'): 0.5,
            ('F1', 'F3'): 0.0,
            # six-digit thirds, adjusted so the row sums to 1
            ('F2', 'F1'): 0.333333,
            ('F2', 'F2'): 0.333333,
            ('F2', 'F3'): 0.333334,
            ('F3', 'F1'): 0.0,
            ('F3', 'F2'): 0.5,
            ('F3', 'F3'): 0.5,
        },
    }
    return hmm
|
|
|
# the three lights (observation symbols), represented as string labels
L1 = 'L1'
L2 = 'L2'
L3 = 'L3'

# and the three floors (hidden states)
F1 = 'F1'
F2 = 'F2'
F3 = 'F3'

# set of all possible states
STATES = [F1, F2, F3]

# set of all possible observations
OBS = [L1, L2, L3]

# create the HMM object
h = HMM(STATES, OBS)

# Fully observed training data: a list of (states, observations) pairs
# that can be used to initialise the HMM model parameters via h.init().
initial_data = [
    ((F3, F2, F1, F1, F2, F3, F2),
     (L2, L2, L1, L1, L2, L2, L2)),
    ((F1, F1, F1, F2, F3, F3, F2),
     (L1, L1, L1, L2, L2, L2, L2)),
    ((F1, F2, F3, F2, F1, F1),
     (L1, L2, L2, L2, L2, L1)),
    ((F1, F2, F3, F3, F2, F1, F2, F1),
     (L1, L2, L2, L2, L2, L1, L2, L1))]

# create our example HMM from given parameters
hmm = example_hmm()

# alternative: initialise from the initial data
# hmm = h.init(initial_data, add_const=0.01)

h.print_hmm(hmm, 'initial HMM')

# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++

# define two test observation sequences
test_set = [[L1, L3, L3, L2],
            [L1, L1, L2, L2, L2, L2, L1]]

# find the production probability for a sequence
# h.forward_backward(test_set[1], hmm, VERBOSE=True)

# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++

# define a training set
training_set = [[L1, L2, L3, L3, L1, L1],
                [L2, L2, L2, L3, L2, L1],
                [L1, L2, L1, L2, L1, L2],
                [L2, L1, L1],
                [L1, L1, L2, L2]]

# iterate Baum-Welch training 10 times; each step prints the old and
# the new model and returns the re-estimated one
for i in range(10):
    hmm = h.baum_welch_step(training_set, hmm, VERBOSE=False)

# +++++++++++++++++++++++++++++++++++++++++++++++++++++
# +++++++++++++++++++++++++++++++++++++++++++++++++++++

VERBOSE = True

# viterbi decoding of all test sequences (disabled)

# for obs in test_set:
#     delta, seq1 = h.viterbi(obs, hmm)
#     print(('Viterbi: for observations "%s", the most '
#            'probable state sequence is: ' %
#            (','.join(obs))) + ','.join(seq1))
#     if VERBOSE:
#         print('delta:')
#         h.print_table(delta, range(len(obs)), h.STATES)
#     print('----------------')