Skip to content

Instantly share code, notes, and snippets.

@raytroop
Last active September 18, 2018 04:20
Show Gist options
  • Save raytroop/36f969462e777c775b3248582ebf136b to your computer and use it in GitHub Desktop.
Save raytroop/36f969462e777c775b3248582ebf136b to your computer and use it in GitHub Desktop.
NumPy: computing the probability of an observation sequence under an HMM — brute-force (vanilla) method and forward algorithm, based on `BinRoot/TensorFlow-Book`
from itertools import product
import numpy as np
# 1) Vanilla (brute-force) method: enumerates every hidden-state path, so its cost grows exponentially with the sequence length.
def vanilla(initial_prob, trans_prob, obs_prob, observations):
    """Return P(observations) for an HMM by enumerating every hidden-state path.

    Sums, over all hidden-state sequences s_0..s_{T-1}, the joint probability
        initial_prob[s_0]
        * prod_t trans_prob[s_t, s_{t+1}]
        * prod_t obs_prob[s_t, observations[t]].
    Time and memory grow as N**T (N states, T observations), so this is only
    practical for tiny problems; it mainly serves as a reference for the
    forward algorithm.

    Args:
        initial_prob: initial state probabilities, shape (N, 1) or (N,).
        trans_prob: (N, N) matrix; trans_prob[i, j] is the probability of moving
            from state i to state j.
        obs_prob: (N, M) matrix; obs_prob[i, k] is the probability of emitting
            observation k in state i.
        observations: non-empty sequence of observation indices (columns of
            obs_prob).

    Returns:
        The total probability of the observation sequence.

    Raises:
        ValueError: if ``observations`` is empty.
    """
    observations = np.asarray(observations)
    if observations.size == 0:
        raise ValueError('observations must contain at least one symbol')
    n_states = trans_prob.shape[0]
    # One row per hidden-state path: shape (N**T, T).
    paths = np.array(list(product(range(n_states), repeat=len(observations))),
                     dtype=np.intp)
    # Emission term: obs_prob[paths[p, t], observations[t]], multiplied over t.
    prob_o = np.prod(obs_prob[paths, observations], axis=1)
    # Initial-state term must use the FIRST state of each path (column 0).
    prob_s = np.ravel(initial_prob)[paths[:, 0]]
    # Transition terms between consecutive states; for T == 1 the product over
    # an empty axis is 1.0.
    prob_t = np.prod(trans_prob[paths[:, :-1], paths[:, 1:]], axis=1)
    return np.sum(prob_s * prob_t * prob_o)
# 2) Forward algorithm: dynamic programming over time steps, O(T * N^2) instead of O(N^T).
def forward(initial_prob, trans_prob, obs_prob, observations):
    """Return P(observations) for an HMM using the forward algorithm.

    Keeps alpha[j] = P(o_0..o_t, s_t = j) and updates it with
        alpha <- (trans_prob.T @ alpha) * obs_prob[:, o_{t+1}],
    which costs O(T * N**2) instead of enumerating all N**T paths.

    Args:
        initial_prob: initial state probabilities, shape (N, 1) or (N,).
        trans_prob: (N, N) matrix; trans_prob[i, j] is the probability of moving
            from state i to state j.
        obs_prob: (N, M) matrix; obs_prob[i, k] is the probability of emitting
            observation k in state i.
        observations: non-empty sequence of observation indices (columns of
            obs_prob).

    Returns:
        The total probability of the observation sequence.

    Raises:
        ValueError: if ``observations`` is empty.
    """
    if len(observations) == 0:
        raise ValueError('observations must contain at least one symbol')
    # Flatten initial_prob: a 1-D (N,) input multiplied by an (N, 1) column
    # would otherwise broadcast into an (N, N) matrix and give a wrong result.
    alpha = np.ravel(initial_prob) * obs_prob[:, observations[0]]
    for obs in observations[1:]:
        # trans_prob.T @ alpha sums over the previous state i for each next state j.
        alpha = trans_prob.T.dot(alpha) * obs_prob[:, obs]
    return np.sum(alpha)
if __name__ == '__main__':
    # Classic "Rainy/Sunny" example HMM:
    #   states = ('Rainy', 'Sunny')
    #   observations = ('walk', 'shop', 'clean')
    #   start_probability = {'Rainy': 0.6, 'Sunny': 0.4}
    #   transition_probability = {
    #       'Rainy': {'Rainy': 0.7, 'Sunny': 0.3},
    #       'Sunny': {'Rainy': 0.4, 'Sunny': 0.6},
    #   }
    #   emission_probability = {
    #       'Rainy': {'walk': 0.1, 'shop': 0.4, 'clean': 0.5},
    #       'Sunny': {'walk': 0.6, 'shop': 0.3, 'clean': 0.1},
    #   }
    # State index 0 = Rainy, 1 = Sunny.
    # NOTE: the emission columns below are ordered (clean, shop, walk), i.e.
    # observation index 0 = clean, 1 = shop, 2 = walk.
    initial_prob = np.array([[0.6],
                             [0.4]])
    trans_prob = np.array([[0.7, 0.3],
                           [0.4, 0.6]])
    obs_prob = np.array([[0.5, 0.4, 0.1],
                         [0.1, 0.3, 0.6]])
    observations = np.array([0, 1, 1, 2, 1])  # clean, shop, shop, walk, shop
    print('vanilla: {:.6f}'.format(vanilla(initial_prob, trans_prob, obs_prob, observations)))
    print('forward: {:.6f}'.format(forward(initial_prob, trans_prob, obs_prob, observations)))
    # Both methods compute the same quantity, so they should print the same
    # value (0.004642 for this input); a mismatch indicates a bug.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment