Skip to content

Instantly share code, notes, and snippets.

@jwmerrill
Created September 21, 2014 18:02
Show Gist options
  • Save jwmerrill/ff422bf00593e006c1a4 to your computer and use it in GitHub Desktop.
Save jwmerrill/ff422bf00593e006c1a4 to your computer and use it in GitHub Desktop.
Viterbi algorithm
# Hidden-state labels; state index i is row i of the probability matrices below.
const states = ["Healthy", "Fever"]

# Observation labels; obsnames[k] names column k of `emission_probability`.
const obsnames = ["normal", "cold", "dizzy"]

# start_probability[i]: probability that the chain starts in states[i].
const start_probability = [0.6, 0.4]

# transition_probability[i, j]: probability of moving from states[i] to states[j].
const transition_probability = [0.7 0.3; 0.4 0.6]

# emission_probability[i, k]: probability that states[i] emits obsnames[k].
const emission_probability = [0.5 0.4 0.1; 0.1 0.3 0.6]
"""
    first_index(arr, elt)

Return the index of the first element of `arr` that is `==` to `elt`, or `0`
if `elt` does not occur in `arr`.

This is a linear scan, which is faster than a hash-table lookup for the small
label arrays it is used with here.
"""
function first_index(arr, elt)
    # `==(elt)` is the predicate `x -> x == elt`, matching the original comparison
    # (unlike `isequal`, it treats -0.0 == 0.0 and NaN != NaN).
    i = findfirst(==(elt), arr)
    return i === nothing ? 0 : i
end
"""
    viterbi(obs, states, obsnames, start_p, trans_p, emit_p)

Find the most likely sequence of hidden states for the observation sequence
`obs` under a hidden Markov model, using the Viterbi algorithm.

# Arguments
- `obs`: observation sequence; every element must occur in `obsnames`.
- `states`: hidden-state labels; state index `y` is row `y` of `trans_p`/`emit_p`.
- `obsnames`: observation labels; `obsnames[k]` names column `k` of `emit_p`.
- `start_p`: `start_p[y]` is the probability of starting in state `y`.
- `trans_p`: `trans_p[y0, y]` is the probability of moving from state `y0` to `y`.
- `emit_p`: `emit_p[y, k]` is the probability that state `y` emits `obsnames[k]`.

# Returns
A tuple `(prob, path)`. `prob` is the probability of the best state sequence
jointly with `obs`, and `path` is that sequence as a vector of elements of
`states`. Ties are broken in favor of the lowest state index.

# Throws
- `ArgumentError` if `obs` or `states` is empty, or if `obs` contains a value
  that does not occur in `obsnames`.

Probabilities are multiplied directly (not in log space), so for long
observation sequences they can underflow to `0.0`. At that point competing
paths can no longer be distinguished.
"""
function viterbi(obs, states, obsnames, start_p, trans_p, emit_p)
    nstates = length(states)
    nobs = length(obs)
    nobs > 0 || throw(ArgumentError("observation sequence must not be empty"))
    nstates > 0 || throw(ArgumentError("no hidden states given"))

    # Column of emit_p for each observation, looked up once per time step rather
    # than once per (state, time) pair. An unknown observation would otherwise
    # surface as an opaque BoundsError on emit_p[y, 0].
    obsindices = Vector{Int}(undef, nobs)
    for t in 1:nobs
        k = first_index(obsnames, obs[t])
        k == 0 && throw(ArgumentError("unknown observation $(repr(obs[t]))"))
        obsindices[t] = k
    end

    # V[y, t]: highest probability of any state path ending in state y at time t
    # (jointly with obs[1:t]). maxindices[y, t]: predecessor of y on that path.
    V = Matrix{Float64}(undef, nstates, nobs)
    maxindices = Matrix{Int}(undef, nstates, nobs)

    # Initialize base cases (t == 1); there is no predecessor, so each state
    # points to itself.
    for y in 1:nstates
        V[y, 1] = start_p[y] * emit_p[y, obsindices[1]]
        maxindices[y, 1] = y
    end

    # Run Viterbi for t > 1
    for t in 2:nobs
        k = obsindices[t]
        for y in 1:nstates
            ep = emit_p[y, k]
            # Start below every probability so a predecessor is always chosen,
            # even when all candidates are 0.0. Otherwise a 0 back-pointer would
            # crash `path`. Strict `>` keeps the first maximum on ties.
            maxprob = -Inf
            maxindex = 0
            for y0 in 1:nstates
                prob = V[y0, t - 1] * trans_p[y0, y]
                if prob > maxprob
                    maxprob = prob
                    maxindex = y0
                end
            end
            V[y, t] = ep * maxprob
            maxindices[y, t] = maxindex
        end
    end

    # Most probable final state (findmax returns the first maximum on ties),
    # then backtrack through the predecessors.
    maxprob, maxindex = findmax(@view V[:, nobs])
    return (maxprob, path(states, maxindices, maxindex))
end
"""
    path(states, maxindices, lastindex)

Reconstruct a state sequence by following back-pointers from the last time step.

# Arguments
- `states`: state labels; `states[y]` is the label of state index `y`.
- `maxindices`: back-pointer matrix; `maxindices[y, t]` is the state index at
  time `t - 1` that precedes state `y` at time `t`. Column 1 is read but unused.
- `lastindex`: state index at the final time step.

# Returns
A `Vector{eltype(states)}` of length `size(maxindices, 2)`, with labels in time
order.
"""
function path(states, maxindices, lastindex)
    nobs = size(maxindices, 2)
    # `Array(T, n)` was removed in Julia 1.0; use the `undef` constructor instead.
    out = Vector{eltype(states)}(undef, nobs)
    y = lastindex
    # Walk backwards in time: record the current state, then step to its
    # predecessor.
    for t in nobs:-1:1
        out[t] = states[y]
        y = maxindices[y, t]
    end
    return out
end
"""
    example()

Run `viterbi` on the module-level two-state HMM, using the observation sequence
`obsnames` itself (`["normal", "cold", "dizzy"]`).
"""
function example()
    # Reusing the observation labels as the sequence is a convenience; a longer
    # sequence would exercise the recurrence more thoroughly.
    return viterbi(obsnames, states, obsnames,
                   start_probability, transition_probability, emission_probability)
end
"""
    benchmark_example(n)

Call `example()` `n` times, discarding each result. Used as a timing workload.
"""
function benchmark_example(n)
    for _ in 1:n
        example()
    end
    return nothing
end
println(example())
# Warm-up call: the first call of a function includes its JIT compilation, so
# run benchmark_example once before timing. @time then reports steady-state cost.
benchmark_example(1)
@time benchmark_example(1000000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment