Last active
March 13, 2019 21:03
-
-
Save dkohlsdorf/ce47dfe9da8127f3926d to your computer and use it in GitHub Desktop.
Hidden Markov Model Inference
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package processing.inference; | |
import java.util.ArrayList; | |
import processing.model.HiddenMarkovModel; | |
import processing.utils.ProbabilityUtils; | |
/** | |
* Inference algorithms used for training and decoding hidden markov models. | |
* Includes marginals (fwd- bwd) and viterbi as well as probability calculations | |
* used for baum welch.<br> | |
* | |
* [1] Wendy Holmes: Speech Synthesis and Recognition, 2nd ed.<br> | |
* [2] Thomas Mann: Numerically Stable Hidden Markov Model Implementation<br> | |
* [3] Holger Wunsch: Der Baum-Welch Algorithmus fur Hidden Markov Models, ein | |
* Spezialfall des EM-Algorithmus<br> | |
* [4] Lawrence RabinerL A tutorial on Hidden Markov Models and Selected | |
* Applications in Speech Recognition<br> | |
* | |
* @author Daniel Kohlsdorf | |
*/ | |
public class Inference { | |
/** | |
* Compute the marginal probability using the forward algorithm (sum | |
* product) | |
* | |
* @param observation an observation sequence | |
* @param model the model | |
* @return forward probabilities | |
*/ | |
public static double[][] forward(ArrayList<double[]> observation, HiddenMarkovModel model) { | |
int num_states = model.getInitials().length; | |
double alpha[][] = new double[observation.size()][num_states]; | |
// initial probabilities | |
for (int i = 0; i < num_states; i++) { | |
alpha[0][i] = ProbabilityUtils.lg(model.getInitials()[i]) + model.getObservations().get(i).logLikelihood(observation.get(0)); | |
} | |
// inference | |
for (int t = 1; t < observation.size(); t++) { | |
// sum-product | |
for (int i = 0; i < num_states; i++) { | |
// sum | |
double sum = ProbabilityUtils.ZERO; | |
for (int j = 0; j < num_states; j++) { | |
sum = ProbabilityUtils.sum(sum, | |
alpha[t - 1][j] + ProbabilityUtils.lg(model.getTransitions()[j][i])); | |
} | |
// product | |
alpha[t][i] = sum + model.getObservations().get(i).logLikelihood(observation.get(t)); | |
} | |
} | |
for (int i = 0; i < num_states; i++) { | |
alpha[observation.size() - 1][i] += ProbabilityUtils.lg(model.getEnd()[i]); | |
} | |
return alpha; | |
} | |
/** | |
* Compute the marginal probability using the backward algorithm (sum | |
* product) | |
* | |
* @param observation an observation sequence | |
* @param model the model | |
* @return backward probabilities | |
*/ | |
public static double[][] backward(ArrayList<double[]> observation, HiddenMarkovModel model) { | |
int num_states = model.getInitials().length; | |
double beta[][] = new double[observation.size()][num_states]; | |
// initial | |
for (int i = 0; i < num_states; i++) { | |
beta[observation.size() - 1][i] = ProbabilityUtils.lg(model.getEnd()[i]); | |
} | |
// inference | |
for (int t = observation.size() - 2; t >= 0; t--) { | |
for (int i = 0; i < num_states; i++) { | |
double sum = ProbabilityUtils.ZERO; | |
for (int j = 0; j < num_states; j++) { | |
double prob = ProbabilityUtils.lg(model.getTransitions()[i][j]); | |
prob += model.getObservations().get(j).logLikelihood(observation.get(t + 1)); | |
prob += beta[t + 1][j]; | |
sum = ProbabilityUtils.sum(sum, prob); | |
} | |
beta[t][i] = sum; | |
} | |
} | |
for (int i = 0; i < num_states; i++) { | |
beta[0][i] += ProbabilityUtils.lg(model.getInitials()[i]); | |
} | |
return beta; | |
} | |
/** | |
* Compute the probabilities: | |
* | |
* P(t: Si AND t+1: Sj) | |
* | |
* for a given observation with forward and backward probabilities already | |
* computed. | |
* | |
* @param observation the observation sequence | |
* @param model the model | |
* @return state probabilities for each time slice | |
*/ | |
public static double[][][] state_probabilities_time(ArrayList<double[]> observation, | |
HiddenMarkovModel model, double forward[][], double backward[][]) { | |
int num_states = model.getInitials().length; | |
// Being in state i at time t transitioning to j at t + 1 | |
double zeta[][][] = new double[observation.size()][num_states][num_states]; | |
for (int t = 0; t < observation.size() - 1; t++) { | |
double norm = ProbabilityUtils.ZERO; | |
// compute forward[i] * trans[i][j] * obs_j * beta_j | |
for (int i = 0; i < num_states; i++) { | |
for (int j = 0; j < num_states; j++) { | |
zeta[t][i][j] = forward[t][i] + ProbabilityUtils.lg(model.getTransitions()[i][j]); | |
zeta[t][i][j] += model.getObservations().get(j).logLikelihood(observation.get(t + 1)); | |
zeta[t][i][j] += backward[t + 1][j]; | |
norm = ProbabilityUtils.sum(norm, zeta[t][i][j]); | |
} | |
} | |
// normalize | |
for (int i = 0; i < num_states; i++) { | |
for (int j = 0; j < num_states; j++) { | |
zeta[t][i][j] -= norm; | |
} | |
} | |
} | |
return zeta; | |
} | |
/** | |
* Compute the state probabilities for a given observation with forward and | |
* backward probabilities already computed | |
* | |
* @param observation the observation sequence | |
* @param model the model | |
* @param forward forward probability matrix | |
* @param backward backward probability matrix | |
* @return state probabilities for each time slice | |
*/ | |
public static double[][] state_probabilities(ArrayList<double[]> observation, | |
HiddenMarkovModel model, double forward[][], double backward[][]) { | |
int num_states = model.getInitials().length; | |
// state probabilities | |
double gamma[][] = new double[observation.size()][num_states]; | |
// alpha * beta / SUM_j alpha_j * beta_j | |
for (int t = 0; t < observation.size(); t++) { | |
// SUM alpha_ti beta_ti | |
double norm = ProbabilityUtils.ZERO; | |
for (int i = 0; i < num_states; i++) { | |
gamma[t][i] = forward[t][i] + backward[t][i]; | |
norm = ProbabilityUtils.sum(norm, gamma[t][i]); | |
} | |
// normalize | |
for (int i = 0; i < num_states; i++) { | |
gamma[t][i] -= norm; | |
} | |
} | |
return gamma; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment