Created
July 20, 2017 01:10
-
-
Save unixpickle/a54f82735cc16bddf899c1732343ed25 to your computer and use it in GitHub Desktop.
HMM p(z_i | x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"math/rand" | |
"time" | |
"github.com/unixpickle/anyvec" | |
"github.com/unixpickle/anyvec/anyvec64" | |
"github.com/unixpickle/approb" | |
) | |
// Column is input, row is output. | |
// That way, vector application works column-wise. | |
var ( | |
NumStates = 3 | |
NumOutputs = 4 | |
TerminalEmission = NumOutputs - 1 | |
TransProbs = &anyvec.Matrix{ | |
Data: anyvec64.MakeVectorData([]float64{ | |
0.3, 0.6, 0.3, | |
0.6, 0.2, 0.3, | |
0.1, 0.2, 0.4, | |
}), | |
Rows: 3, | |
Cols: 3, | |
} | |
EmitProbs = &anyvec.Matrix{ | |
Data: anyvec64.MakeVectorData([]float64{ | |
0.3, 0.01, 0.1, | |
0.2, 0.49, 0.1, | |
0.5, 0.25, 0.1, | |
0, 0.25, 0.7, | |
}), | |
Rows: 4, | |
Cols: 3, | |
} | |
InitProbs = &anyvec.Matrix{ | |
Data: anyvec64.MakeVectorData([]float64{ | |
0.5, 0.4, 0.1, | |
}), | |
Rows: 3, | |
Cols: 1, | |
} | |
) | |
func main() { | |
rand.Seed(time.Now().UnixNano()) | |
observations := []int{2, 1, 0, 3} | |
for idx := 0; idx < 4; idx++ { | |
fmt.Println("Testing index", idx) | |
exact := ExactHidden(observations, idx) | |
corr := approb.Correlation(5000, 0.5, func() float64 { | |
return float64(SampleHidden(observations)[idx]) | |
}, func() float64 { | |
return float64(sampleOutcome(exact)) | |
}) | |
if corr < 0.995 { | |
fmt.Println("FAIL: correlation was", corr, "(expected ~1.0)") | |
} else { | |
fmt.Println("PASS") | |
} | |
} | |
} | |
// SampleHidden samples a hidden sequence given the | |
// observations. | |
func SampleHidden(observations []int) []int { | |
for { | |
sampledObs, sampledStates := sampleSequence() | |
if seqsEqual(sampledObs, observations) { | |
return sampledStates | |
} | |
} | |
} | |
// ExactHidden computes the exact distribution of the | |
// indexed hidden state given the observations. | |
// The distribution is represented as a probability vector. | |
func ExactHidden(observations []int, index int) anyvec.Vector { | |
probs := make([]float64, NumStates) | |
for state := 0; state < NumStates; state++ { | |
curState := InitProbs | |
weight := 1.0 | |
for i, obs := range observations { | |
outProb, statePosterior := probAccumulationStep(curState, obs) | |
weight *= outProb | |
if i == index { | |
weight *= statePosterior.Data.Data().([]float64)[state] | |
curState = matMul(TransProbs, oneHot(NumStates, state)) | |
} else { | |
curState = matMul(TransProbs, statePosterior) | |
} | |
if weight == 0 { | |
break | |
} | |
} | |
probs[state] = weight | |
} | |
res := anyvec64.MakeVectorData(probs) | |
res.Scale(1 / anyvec.Sum(res).(float64)) | |
return res | |
} | |
func probAccumulationStep(stateIn *anyvec.Matrix, output int) (prob float64, | |
statePosterior *anyvec.Matrix) { | |
prob = matMul(EmitProbs, stateIn).Data.Data().([]float64)[output] | |
if prob == 0 { | |
return prob, stateIn | |
} | |
emitRow := EmitProbs.Data.Slice(output*NumStates, (output+1)*NumStates) | |
statePosterior = &anyvec.Matrix{ | |
Data: stateIn.Data.Copy(), | |
Rows: stateIn.Rows, | |
Cols: stateIn.Cols, | |
} | |
statePosterior.Data.Mul(emitRow) | |
statePosterior.Data.Scale(1 / prob) | |
return | |
} | |
func sampleSequence() (obs, states []int) { | |
state := sampleOutcome(InitProbs.Data) | |
for { | |
stateVec := oneHot(NumStates, state) | |
out := sampleOutcome(matMul(EmitProbs, stateVec).Data) | |
obs = append(obs, out) | |
states = append(states, state) | |
if out == TerminalEmission { | |
break | |
} | |
state = sampleOutcome(matMul(TransProbs, stateVec).Data) | |
} | |
return | |
} | |
func sampleOutcome(vec anyvec.Vector) int { | |
offset := rand.Float64() | |
for i, comp := range vec.Data().([]float64) { | |
offset -= comp | |
if offset < 0 { | |
return i | |
} | |
} | |
return vec.Len() - 1 | |
} | |
func matMul(m1, m2 *anyvec.Matrix) *anyvec.Matrix { | |
rows := m1.Rows | |
cols := m2.Cols | |
res := &anyvec.Matrix{ | |
Data: anyvec64.MakeVector(rows * cols), | |
Rows: rows, | |
Cols: cols, | |
} | |
res.Product(false, false, 1.0, m1, m2, 0.0) | |
return res | |
} | |
func oneHot(length, value int) *anyvec.Matrix { | |
vec := anyvec64.MakeVector(length) | |
vec.Slice(value, value+1).AddScalar(1.0) | |
return &anyvec.Matrix{Data: vec, Rows: length, Cols: 1} | |
} | |
// seqsEqual reports whether s1 and s2 have the same length and the
// same element at every position. A nil slice and an empty slice are
// considered equal.
func seqsEqual(s1, s2 []int) bool {
	n := len(s1)
	if n != len(s2) {
		return false
	}
	for i := 0; i < n; i++ {
		if s1[i] != s2[i] {
			return false
		}
	}
	return true
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment