public
Created

Viterbi algorithm

  • Download Gist
viterbi.clj
Clojure
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
(ns ident.viterbi
(:use [clojure.pprint]))
 
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Example
;; (def initpr (to-array [0.6 0.4]))
;; (def transpr (to-array-2d [[0.7 0.3][0.4 0.6]]))
;; (def emisspr (to-array-2d [[0.1 0.4 0.5][0.6 0.3 0.1]]))
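;; (def weather-hmm (make-hmm {:states ["rainy" "sunny"] :obs ["walk" "shop" "clean"] :init-probs initpr :emission-probs emisspr :state-transitions transpr}))
;; (named weather-hmm so that it does not redefine the `hmm` struct basis that make-hmm uses)
;; Observations are indices into :obs, so [2 1 0] means clean, shop, walk.
;; The result is [most likely state path (state indices), total probability of the observation sequence]:
;; (viterbi weather-hmm [2 1 0]) -> [(0 0 1) 0.03628]   ; path = rainy, rainy, sunny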
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 
;; Struct basis for a hidden Markov model (equivalent to `defstruct`).
;; make-hmm fills these keys and also stores :states and :obs, which are
;; not part of the basis and therefore live in the struct-map's extra keys.
(def hmm (create-struct :n :m :init-probs :emission-probs :state-transitions))
 
(defn make-hmm
  "Builds an `hmm` struct-map from a description map.

  Keys of the argument map:
    :states            - state labels; their count becomes :n
    :obs               - observation labels; their count becomes :m
    :init-probs        - initial probability of each state, length n
    :emission-probs    - n x m 2-D array; row = state, column = observation index
    :state-transitions - n x n 2-D array; entry [i j] is the probability of
                         moving from state i to state j

  The probability tables are read elsewhere in this namespace with `aget`,
  so they are expected to be Java arrays (e.g. built with `to-array` /
  `to-array-2d`), not Clojure vectors."
  [{:keys [states, obs, init-probs, emission-probs, state-transitions]}]
  (struct-map hmm
    :n (count states)
    :m (count obs)
    :states states
    :obs obs
    :init-probs init-probs                 ;; length n: (aget init-probs state)
    :emission-probs emission-probs         ;; n x m: (aget emission-probs state obs)
    :state-transitions state-transitions)) ;; n x n: (aget state-transitions from to)
 
(defn indexed
  "Pairs every element of `s` with its zero-based position.
  Returns a lazy seq of [position element] vectors."
  [s]
  (map-indexed vector s))
 
(defn argmax
  "Returns [index value] for the largest element of `coll`, or nil when
  `coll` is empty. Elements are compared with `>`, so when several elements
  share the maximum the earliest one wins."
  [coll]
  ;; Fold over [index value] pairs, keeping the best pair seen so far.
  ;; (Avoids naming a local `max`, which would shadow clojure.core/max.)
  (when-let [pairs (seq (map-indexed vector coll))]
    (reduce (fn [[_ best-v :as best] [_ v :as candidate]]
              (if (> v best-v) candidate best))
            pairs)))
 
(defn pprint-hmm
  "Prints the state and observation counts of `hmm`, then pretty-prints its
  initial, emission and transition probabilities, each after a label.
  Returns nil."
  [hmm]
  (println "number of states: " (:n hmm) " number of observations: " (:m hmm))
  (doseq [[label k] [["init probabilities: " :init-probs]
                     ["emission probs: " :emission-probs]
                     ["state-transitions: " :state-transitions]]]
    (print label)
    (pprint (k hmm))))
 
(defn init-alphas
  "Initial forward/Viterbi scores for the first observation index `obs`:
  for each state s, (init-probs s) * (emission-probs s obs).
  Returns a lazy seq with one entry per state."
  [hmm obs]
  (let [start-probs (:init-probs hmm)
        emit        (:emission-probs hmm)]
    (for [state (range (:n hmm))]
      (* (aget start-probs state) (aget emit state obs)))))
 
(defn forward
  "One step of the forward algorithm. For each target state j, sums
  alphas[i] * transition[i][j] over all source states i, then multiplies by
  the probability of emitting observation index `obs` from j.
  `alphas` is read with `nth`, so a vector gives constant-time lookups.
  Returns a lazy seq of the new alphas, one per state."
  [hmm alphas obs]
  (let [n     (:n hmm)
        trans (:state-transitions hmm)
        emit  (:emission-probs hmm)]
    (for [j (range n)]
      (let [incoming (reduce + 0 (map #(* (nth alphas %) (aget trans % j))
                                      (range n)))]
        (* incoming (aget emit j obs))))))
 
(defn delta-max
  "One step of the Viterbi recursion. For each target state j, takes the
  best max_i(deltas[i] * transition[i][j]) over source states i and
  multiplies it by the probability of emitting observation index `obs`
  from j. Returns a lazy seq of the new deltas, one per state."
  [hmm deltas obs]
  (let [states (range (:n hmm))
        trans  (:state-transitions hmm)
        emit   (:emission-probs hmm)]
    (for [j states]
      (let [scores (for [i states]
                     (* (nth deltas i) (aget trans i j)))]
        (* (apply max scores) (aget emit j obs))))))
 
(defn backtrack
  "Recovers the most likely state path from Viterbi back-pointers.

  `paths`  - one back-pointer sequence per time step after the first, oldest
             first; entry j of a step is the best predecessor of state j.
  `deltas` - the final step's Viterbi scores; the path ends at their argmax.

  Returns a seq of state indices in chronological order, one element longer
  than `paths`."
  [paths deltas]
  (let [final-state (first (argmax deltas))
        step-back   (fn [state pointers] (nth pointers state))]
    ;; Follow the pointers from the last step back to the first, collecting
    ;; every visited state, then restore chronological order.
    (reverse (reductions step-back final-state (reverse paths)))))
 
(defn update-paths
  "Computes Viterbi back-pointers for one step. For each target state j,
  returns the source state i that maximizes deltas[i] * transition[i][j]
  (the earliest i on ties). Returns a lazy seq with one index per state."
  [hmm deltas]
  (let [states (range (:n hmm))
        trans  (:state-transitions hmm)]
    (for [j states]
      (let [[best-i _] (argmax (for [i states]
                                 (* (nth deltas i) (aget trans i j))))]
        best-i))))
 
(defn viterbi
  "Runs the Viterbi and forward algorithms over `observs`, a seq of
  observation indices (columns of :emission-probs). `observs` must be
  non-empty; with no observations the first emission lookup fails.

  Returns [path prob]:
    path - seq of state indices, the most likely hidden-state sequence;
    prob - the sum of the final forward variables, as a float, i.e. the total
           probability of the observation sequence under `hmm` (not the
           probability of `path` alone).

  NOTE: probabilities are multiplied directly, with no scaling or log-space,
  so very long sequences may underflow to 0.0."
  [hmm observs]
  ;; Each step's alphas, deltas and back-pointers are realized eagerly with
  ;; `vec`. Left lazy, every step wraps the previous one, and forcing the last
  ;; step recurses through all earlier steps (StackOverflowError on long
  ;; inputs). `nth` on a lazy seq is also O(i), whereas on a vector it is O(1).
  (loop [obs    (rest observs)
         alphas (vec (init-alphas hmm (first observs)))
         deltas alphas
         paths  []]
    (if (empty? obs)
      [(backtrack paths deltas) (float (reduce + alphas))]
      (let [o (first obs)]
        ;; All recur arguments see the previous step's alphas and deltas.
        (recur (rest obs)
               (vec (forward hmm alphas o))
               (vec (delta-max hmm deltas o))
               (conj paths (vec (update-paths hmm deltas))))))))

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.