Skip to content

Instantly share code, notes, and snippets.

@dkohlsdorf
Last active March 13, 2019 21:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dkohlsdorf/ce47dfe9da8127f3926d to your computer and use it in GitHub Desktop.
Save dkohlsdorf/ce47dfe9da8127f3926d to your computer and use it in GitHub Desktop.
Hidden Markov Model Inference
package processing.inference;
import java.util.ArrayList;
import processing.model.HiddenMarkovModel;
import processing.utils.ProbabilityUtils;
/**
* Inference algorithms used for training and decoding hidden Markov models.
* Includes marginals (forward-backward) and Viterbi, as well as probability
* calculations used for Baum-Welch.<br>
*
* [1] Wendy Holmes: Speech Synthesis and Recognition, 2nd ed.<br>
* [2] Tobias Mann: Numerically Stable Hidden Markov Model Implementation<br>
* [3] Holger Wunsch: Der Baum-Welch-Algorithmus für Hidden Markov Models, ein
* Spezialfall des EM-Algorithmus<br>
* [4] Lawrence Rabiner: A Tutorial on Hidden Markov Models and Selected
* Applications in Speech Recognition<br>
*
* @author Daniel Kohlsdorf
*/
public class Inference {

    /**
     * Runs the forward (sum-product) pass over an observation sequence.
     *
     * All quantities are combined through {@code ProbabilityUtils}
     * ({@code lg}, {@code sum}, {@code ZERO}), which appears to implement
     * log-domain arithmetic for numerical stability (cf. refs [2], [4]).
     *
     * @param observation the observation sequence, one feature vector per frame
     * @param model       the hidden Markov model to evaluate
     * @return matrix of forward probabilities, indexed [frame][state]
     */
    public static double[][] forward(ArrayList<double[]> observation, HiddenMarkovModel model) {
        final int states = model.getInitials().length;
        final int frames = observation.size();
        double[][] alpha = new double[frames][states];
        // Initialization: initial state probability times emission of frame 0.
        for (int state = 0; state < states; state++) {
            alpha[0][state] = ProbabilityUtils.lg(model.getInitials()[state])
                    + model.getObservations().get(state).logLikelihood(observation.get(0));
        }
        // Recursion: sum over all predecessors, then multiply by the emission.
        for (int t = 1; t < frames; t++) {
            for (int cur = 0; cur < states; cur++) {
                double logSum = ProbabilityUtils.ZERO;
                for (int prev = 0; prev < states; prev++) {
                    logSum = ProbabilityUtils.sum(logSum,
                            alpha[t - 1][prev] + ProbabilityUtils.lg(model.getTransitions()[prev][cur]));
                }
                alpha[t][cur] = logSum
                        + model.getObservations().get(cur).logLikelihood(observation.get(t));
            }
        }
        // Termination: fold the exit probabilities into the last frame.
        for (int state = 0; state < states; state++) {
            alpha[frames - 1][state] += ProbabilityUtils.lg(model.getEnd()[state]);
        }
        return alpha;
    }

    /**
     * Runs the backward (sum-product) pass over an observation sequence.
     *
     * Counterpart of {@link #forward(ArrayList, HiddenMarkovModel)}; all
     * values are accumulated via {@code ProbabilityUtils} in the log domain.
     *
     * @param observation the observation sequence, one feature vector per frame
     * @param model       the hidden Markov model to evaluate
     * @return matrix of backward probabilities, indexed [frame][state]
     */
    public static double[][] backward(ArrayList<double[]> observation, HiddenMarkovModel model) {
        final int states = model.getInitials().length;
        final int frames = observation.size();
        double[][] beta = new double[frames][states];
        // Initialization at the final frame: exit probabilities only.
        for (int state = 0; state < states; state++) {
            beta[frames - 1][state] = ProbabilityUtils.lg(model.getEnd()[state]);
        }
        // Recursion, walking backwards through time.
        for (int t = frames - 2; t >= 0; t--) {
            for (int cur = 0; cur < states; cur++) {
                double logSum = ProbabilityUtils.ZERO;
                for (int next = 0; next < states; next++) {
                    // transition * emission at t+1 * remaining suffix probability
                    double term = ProbabilityUtils.lg(model.getTransitions()[cur][next])
                            + model.getObservations().get(next).logLikelihood(observation.get(t + 1))
                            + beta[t + 1][next];
                    logSum = ProbabilityUtils.sum(logSum, term);
                }
                beta[t][cur] = logSum;
            }
        }
        // Termination: fold the initial state distribution into frame 0.
        for (int state = 0; state < states; state++) {
            beta[0][state] += ProbabilityUtils.lg(model.getInitials()[state]);
        }
        return beta;
    }

    /**
     * Computes the pairwise state posteriors P(t: Si AND t+1: Sj) — the
     * "zeta" quantities of Baum-Welch — from precomputed forward and
     * backward matrices, normalized per time slice.
     *
     * NOTE(review): the returned cube has {@code observation.size()} slices
     * but only the first {@code size() - 1} are filled; the last slice stays
     * at its default of 0.0 — confirm callers expect this.
     *
     * @param observation the observation sequence
     * @param model       the hidden Markov model
     * @param forward     forward probability matrix, indexed [frame][state]
     * @param backward    backward probability matrix, indexed [frame][state]
     * @return normalized transition posteriors, indexed [frame][from][to]
     */
    public static double[][][] state_probabilities_time(ArrayList<double[]> observation,
            HiddenMarkovModel model, double[][] forward, double[][] backward) {
        final int states = model.getInitials().length;
        final int frames = observation.size();
        double[][][] zeta = new double[frames][states][states];
        for (int t = 0; t + 1 < frames; t++) {
            double norm = ProbabilityUtils.ZERO;
            // alpha_t(from) * trans[from][to] * emission_{t+1}(to) * beta_{t+1}(to)
            for (int from = 0; from < states; from++) {
                for (int to = 0; to < states; to++) {
                    double p = forward[t][from]
                            + ProbabilityUtils.lg(model.getTransitions()[from][to])
                            + model.getObservations().get(to).logLikelihood(observation.get(t + 1))
                            + backward[t + 1][to];
                    zeta[t][from][to] = p;
                    norm = ProbabilityUtils.sum(norm, p);
                }
            }
            // Normalize the slice so it forms a distribution over (from, to).
            for (int from = 0; from < states; from++) {
                for (int to = 0; to < states; to++) {
                    zeta[t][from][to] -= norm;
                }
            }
        }
        return zeta;
    }

    /**
     * Computes the per-frame state posteriors (the "gamma" quantities of
     * Baum-Welch): alpha * beta, normalized over all states of the frame.
     *
     * @param observation the observation sequence
     * @param model       the hidden Markov model
     * @param forward     forward probability matrix, indexed [frame][state]
     * @param backward    backward probability matrix, indexed [frame][state]
     * @return normalized state posteriors, indexed [frame][state]
     */
    public static double[][] state_probabilities(ArrayList<double[]> observation,
            HiddenMarkovModel model, double[][] forward, double[][] backward) {
        final int states = model.getInitials().length;
        final int frames = observation.size();
        double[][] gamma = new double[frames][states];
        for (int t = 0; t < frames; t++) {
            // Accumulate the per-frame normalizer while filling the slice.
            double norm = ProbabilityUtils.ZERO;
            for (int state = 0; state < states; state++) {
                gamma[t][state] = forward[t][state] + backward[t][state];
                norm = ProbabilityUtils.sum(norm, gamma[t][state]);
            }
            for (int state = 0; state < states; state++) {
                gamma[t][state] -= norm;
            }
        }
        return gamma;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment