Skip to content

Instantly share code, notes, and snippets.

@unixpickle
Created July 20, 2017 01:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save unixpickle/a54f82735cc16bddf899c1732343ed25 to your computer and use it in GitHub Desktop.
Save unixpickle/a54f82735cc16bddf899c1732343ed25 to your computer and use it in GitHub Desktop.
HMM p(z_i | x)
package main
import (
"fmt"
"math/rand"
"time"
"github.com/unixpickle/anyvec"
"github.com/unixpickle/anyvec/anyvec64"
"github.com/unixpickle/approb"
)
// The model is a hidden Markov model whose observation sequences
// end as soon as TerminalEmission is emitted (see sampleSequence).
//
// Probability matrices are column-stochastic: column j is the
// conditional distribution given input value j, and row i is the
// output value. Multiplying a matrix by a column probability vector
// (see matMul) therefore maps a distribution over inputs to a
// distribution over outputs.
//
// The literal data below must be kept in sync with NumStates and
// NumOutputs by hand; the Rows/Cols fields are derived from them.
var (
	// NumStates is the number of hidden states.
	NumStates = 3
	// NumOutputs is the number of distinct output symbols,
	// including TerminalEmission.
	NumOutputs = 4
	// TerminalEmission is the output symbol that ends a sequence.
	TerminalEmission = NumOutputs - 1

	// TransProbs[next][cur] = p(z_{t+1} = next | z_t = cur).
	TransProbs = &anyvec.Matrix{
		Data: anyvec64.MakeVectorData([]float64{
			0.3, 0.6, 0.3,
			0.6, 0.2, 0.3,
			0.1, 0.2, 0.4,
		}),
		Rows: NumStates,
		Cols: NumStates,
	}

	// EmitProbs[out][state] = p(x_t = out | z_t = state).
	// The last row is TerminalEmission; state 0 never emits it.
	EmitProbs = &anyvec.Matrix{
		Data: anyvec64.MakeVectorData([]float64{
			0.3, 0.01, 0.1,
			0.2, 0.49, 0.1,
			0.5, 0.25, 0.1,
			0, 0.25, 0.7,
		}),
		Rows: NumOutputs,
		Cols: NumStates,
	}

	// InitProbs is the distribution of the first hidden state,
	// stored as a NumStates×1 column.
	InitProbs = &anyvec.Matrix{
		Data: anyvec64.MakeVectorData([]float64{
			0.5, 0.4, 0.1,
		}),
		Rows: NumStates,
		Cols: 1,
	}
)
// main checks ExactHidden against rejection sampling. For every
// position of a fixed observation sequence, it compares hidden-state
// samples drawn by SampleHidden with samples drawn from the exact
// posterior returned by ExactHidden. It prints PASS or FAIL depending
// on the value reported by approb.Correlation.
func main() {
	rand.Seed(time.Now().UnixNano())

	// The sequence must end with TerminalEmission (3) and contain it
	// nowhere else. sampleSequence only produces such sequences, so
	// SampleHidden would otherwise never return.
	observations := []int{2, 1, 0, 3}

	// Cover every position instead of a hard-coded count, so the check
	// stays complete (and in bounds) if observations is edited.
	for idx := range observations {
		fmt.Println("Testing index", idx)
		exact := ExactHidden(observations, idx)
		// NOTE(review): presumably Correlation draws 5000 samples from
		// each sampler and measures how well their distributions agree;
		// confirm the meaning of the 0.5 argument in the approb docs.
		corr := approb.Correlation(5000, 0.5, func() float64 {
			return float64(SampleHidden(observations)[idx])
		}, func() float64 {
			return float64(sampleOutcome(exact))
		})
		if corr < 0.995 {
			fmt.Println("FAIL: correlation was", corr, "(expected ~1.0)")
		} else {
			fmt.Println("PASS")
		}
	}
}
// SampleHidden draws a hidden-state sequence from the posterior
// over states given observations, using rejection sampling. It keeps
// sampling (observations, states) pairs from the model and returns
// the states of the first pair whose observations match exactly.
//
// It loops forever if sampleSequence can never produce observations.
func SampleHidden(observations []int) []int {
	gotObs, gotStates := sampleSequence()
	for !seqsEqual(gotObs, observations) {
		gotObs, gotStates = sampleSequence()
	}
	return gotStates
}
// ExactHidden computes the exact posterior distribution of the hidden
// state at position index, given observations. The result is a
// probability vector with one entry per state.
//
// For each candidate state s, it computes the joint probability
// p(observations, z_index = s) with a forward pass, then normalizes
// the joints to sum to one. If every joint is zero (observations
// impossible under the model), the normalization divides by zero.
func ExactHidden(observations []int, index int) anyvec.Vector {
	joints := make([]float64, NumStates)
	for s := range joints {
		joints[s] = jointProbWithState(observations, index, s)
	}
	dist := anyvec64.MakeVectorData(joints)
	total := anyvec.Sum(dist).(float64)
	dist.Scale(1 / total)
	return dist
}

// jointProbWithState returns p(observations, z_index = state) by
// running the forward recursion with the hidden state at position
// index clamped to state.
func jointProbWithState(observations []int, index, state int) float64 {
	// belief is the distribution of the current hidden state given
	// all earlier observations (and the clamp, once it has been applied).
	belief := InitProbs
	joint := 1.0
	for t, obs := range observations {
		pObs, posterior := probAccumulationStep(belief, obs)
		joint *= pObs
		if t != index {
			belief = matMul(TransProbs, posterior)
		} else {
			// Weight by p(z_t = state | observations so far), then
			// propagate forward from that single state only.
			joint *= posterior.Data.Data().([]float64)[state]
			belief = matMul(TransProbs, oneHot(NumStates, state))
		}
		if joint == 0 {
			// Every later factor is finite, so the product stays zero.
			break
		}
	}
	return joint
}
// probAccumulationStep performs one forward-algorithm update.
//
// stateIn is the distribution over the current hidden state, given
// as a NumStates×1 column (it is multiplied by EmitProbs). The
// function returns two values:
//   - prob: the probability of emitting output under stateIn.
//   - statePosterior: a new column with the state distribution
//     conditioned on having emitted output (Bayes' rule).
//
// When prob is zero the posterior is undefined, and stateIn itself
// (not a copy) is returned as statePosterior.
func probAccumulationStep(stateIn *anyvec.Matrix, output int) (prob float64,
	statePosterior *anyvec.Matrix) {
	outputDist := matMul(EmitProbs, stateIn).Data.Data().([]float64)
	pOut := outputDist[output]
	if pOut == 0 {
		return pOut, stateIn
	}

	// Row `output` of EmitProbs: p(output | state) for every state.
	likelihood := EmitProbs.Data.Slice(output*NumStates, (output+1)*NumStates)

	posterior := &anyvec.Matrix{
		Data: stateIn.Data.Copy(),
		Rows: stateIn.Rows,
		Cols: stateIn.Cols,
	}
	// posterior ∝ prior × likelihood, normalized by p(output).
	posterior.Data.Mul(likelihood)
	posterior.Data.Scale(1 / pOut)
	return pOut, posterior
}
// sampleSequence draws one (observations, hidden states) pair from
// the model. The first state comes from InitProbs. Each state emits
// a symbol via EmitProbs, and the chain advances via TransProbs until
// TerminalEmission is emitted. The terminal symbol, and the state
// that emitted it, are included in the returned slices.
func sampleSequence() (obs, states []int) {
	current := sampleOutcome(InitProbs.Data)
	for {
		indicator := oneHot(NumStates, current)
		emitted := sampleOutcome(matMul(EmitProbs, indicator).Data)
		obs = append(obs, emitted)
		states = append(states, current)
		if emitted == TerminalEmission {
			return obs, states
		}
		current = sampleOutcome(matMul(TransProbs, indicator).Data)
	}
}
// sampleOutcome draws an index i with probability vec[i]. vec is
// expected to hold non-negative float64 values that sum to
// (approximately) one.
//
// If rounding leaves some probability mass unassigned after the scan,
// the draw goes to the last index with positive probability, so an
// impossible outcome is never returned. If vec has no positive entry,
// vec.Len()-1 is returned.
func sampleOutcome(vec anyvec.Vector) int {
	offset := rand.Float64()
	probs := vec.Data().([]float64)
	for i, p := range probs {
		offset -= p
		if offset < 0 {
			return i
		}
	}
	// Only reachable when the components sum to slightly less than the
	// drawn offset because of floating-point rounding. Blindly returning
	// the final index could pick a zero-probability outcome (e.g. the
	// TerminalEmission entry for state 0).
	for i := len(probs) - 1; i >= 0; i-- {
		if probs[i] > 0 {
			return i
		}
	}
	return vec.Len() - 1
}
// matMul returns the matrix product m1 × m2 in a newly allocated
// float64 matrix with m1.Rows rows and m2.Cols columns.
func matMul(m1, m2 *anyvec.Matrix) *anyvec.Matrix {
	product := &anyvec.Matrix{Rows: m1.Rows, Cols: m2.Cols}
	product.Data = anyvec64.MakeVector(product.Rows * product.Cols)
	// product = 1*(m1 × m2) + 0*product, with neither operand transposed.
	product.Product(false, false, 1.0, m1, m2, 0.0)
	return product
}
// oneHot returns a length×1 column that is 1 at index value and 0
// everywhere else, i.e. a point-mass distribution on value.
// It panics if value is outside [0, length).
func oneHot(length, value int) *anyvec.Matrix {
	// Build the entries in a plain slice instead of mutating a Slice()
	// of a zero vector. That way correctness does not depend on whether
	// anyvec's Slice aliases or copies its parent.
	data := make([]float64, length)
	data[value] = 1
	return &anyvec.Matrix{
		Data: anyvec64.MakeVectorData(data),
		Rows: length,
		Cols: 1,
	}
}
// seqsEqual reports whether s1 and s2 contain the same integers in
// the same order. A nil slice and an empty slice compare equal.
func seqsEqual(s1, s2 []int) bool {
	same := len(s1) == len(s2)
	for i := 0; same && i < len(s1); i++ {
		same = s1[i] == s2[i]
	}
	return same
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment