Skip to content

Instantly share code, notes, and snippets.

@dkohlsdorf
Last active March 13, 2019 21:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dkohlsdorf/2c0ecb71a47670cd6357 to your computer and use it in GitHub Desktop.
Save dkohlsdorf/2c0ecb71a47670cd6357 to your computer and use it in GitHub Desktop.
Hidden Markov Model Training
package processing.inference;
import java.util.ArrayList;
import java.util.Vector;
import processing.model.Gaussian;
import processing.model.GaussianMixture;
import processing.model.HiddenMarkovModel;
import processing.model.ProbabilityDistibution;
import processing.utils.ProbabilityUtils;
/**
* Training and initialization of a Hidden Markov Model<br>
*
* [1] Wendy Holmes: Speech Synthesis and Recognition, 2nd ed.<br>
* [2] Tobias Mann: Numerically Stable Hidden Markov Model Implementation<br>
* [3] Holger Wunsch: Der Baum-Welch Algorithmus für Hidden Markov Models, ein
* Spezialfall des EM-Algorithmus<br>
* [4] Lawrence Rabiner: A tutorial on Hidden Markov Models and Selected
* Applications in Speech Recognition<br>
* [5] David Minnen: Discovering Multivariate Motifs using Subsequence Density
* Estimation and Greedy Mixture Learning<br>
*
* @author Daniel Kohlsdorf
*/
public class Training {
    /**
     * Train a HMM by running Baum-Welch re-estimation until the total
     * (log) likelihood of the training set converges, or until a maximum
     * number of iterations is reached.
     *
     * @param training_set data to estimate model parameters from
     * @param model        an initialized model to estimate (updated in place)
     * @param iter         maximum number of iterations
     */
    public static void BaumWelchConverge(ArrayList<ArrayList<double[]>> training_set, HiddenMarkovModel model, int iter) {
        double last = 0;
        for (int i = 0; i < iter; i++) {
            BaumWelchReestimation(training_set, model);
            // Score the current model as the forward (log) likelihood summed
            // over all training sequences.
            double score = ProbabilityUtils.ZERO;
            for (ArrayList<double[]> seq : training_set) {
                double fwd[][] = Inference.forward(seq, model);
                int T = fwd.length - 1;
                double ll = ProbabilityUtils.ZERO;
                for (int s = 0; s < model.getInitials().length; s++) {
                    ll = ProbabilityUtils.sum(ll, fwd[T][s]);
                }
                score = ProbabilityUtils.sum(score, ll);
            }
            // Convergence check: stop once the relative change in score is
            // negligible. Skip the first iteration, where `last` still holds
            // the dummy value 0 and the ratio is meaningless (the original
            // only avoided a spurious break here because score / 0 happens
            // to yield an infinite ratio).
            if (i > 0 && Math.abs((score / last) - 1) < 1e-6) {
                break;
            }
            System.out.println(" - Baum Welch Iter: " + i + " " + score);
            last = score;
        }
    }

    /**
     * Initialize a model with a left-to-right topology: one Gaussian per
     * state estimated from an equal slice of every sequence (flat
     * segmentation), and transitions derived from the expected state
     * duration.
     *
     * @param data  a set of time series
     * @param model the Hidden Markov Model to initialize (updated in place)
     */
    public static void initial(ArrayList<ArrayList<double[]>> data, HiddenMarkovModel model) {
        int num_states = model.getInitials().length;
        // Left-to-right model: always start in the first state and end in
        // the last one.
        double init[] = new double[num_states];
        double end[] = new double[num_states];
        init[0] = 1;
        end[num_states - 1] = 1;
        // The expected duration of a state is T / N frames, so the
        // probability of leaving a state is N / T. (The original computed
        // 1.0 / (T / N) with a truncating integer division inside.)
        double p_trans = (double) num_states / data.get(0).size();
        double p_stay = 1.0 - p_trans;
        double trans[][] = new double[num_states][num_states];
        for (int i = 0; i < num_states; i++) {
            trans[i][i] = p_stay;
            if (i + 1 < num_states) {
                trans[i][i + 1] = p_trans;
            }
        }
        // Estimate one Gaussian per state from the i-th slice of every
        // sequence. Note: b truncates, so up to num_states - 1 trailing
        // samples per sequence are ignored.
        ArrayList<ProbabilityDistibution> obs = new ArrayList<ProbabilityDistibution>();
        System.out.println(data.size() + " " + data.get(0).size());
        int D = data.get(0).get(0).length;
        for (int i = 1; i < num_states + 1; i++) {
            Gaussian gaussian = new Gaussian(D);
            ArrayList<double[]> samples = new ArrayList<>();
            for (int j = 0; j < data.size(); j++) {
                int b = data.get(j).size() / num_states;
                samples.addAll(data.get(j).subList((i - 1) * b, i * b));
            }
            gaussian.estimate(samples);
            obs.add(gaussian);
        }
        model.setInitials(init);
        model.setEnd(end);
        model.setTransitions(trans);
        model.setObservations(obs);
    }

    /**
     * One Baum-Welch (EM) re-estimation pass using all available examples,
     * accumulating in log space for numerical stability.
     *
     * [1] Wendy Holmes: Speech Synthesis and Recognition, 2nd ed., Chapter 9.
     * [2] Tobias Mann: Numerically Stable Hidden Markov Model Implementation.
     *
     * @param training_set data to estimate model parameters from
     * @param model        the model whose transition matrix and observation
     *                     distributions are replaced by the re-estimates
     */
    public static void BaumWelchReestimation(ArrayList<ArrayList<double[]>> training_set, HiddenMarkovModel model) {
        // E-step: per sequence, compute the state posteriors gamma[t][i] and
        // the transition posteriors zeta[t][i][j] from the forward and
        // backward (log) probabilities.
        ArrayList<double[][]> gamma_per_observation = new ArrayList<double[][]>();
        ArrayList<double[][][]> zeta_per_observation = new ArrayList<double[][][]>();
        for (ArrayList<double[]> observation : training_set) {
            double forward[][] = Inference.forward(observation, model);
            double backward[][] = Inference.backward(observation, model);
            double gamma[][] = Inference.state_probabilities(observation, model, forward, backward);
            double zeta[][][] = Inference.state_probabilities_time(observation, model, forward, backward);
            gamma_per_observation.add(gamma);
            zeta_per_observation.add(zeta);
        }
        int num_states = model.getInitials().length;
        int dim = training_set.get(0).get(0).length;
        // M-step, transitions: a_ij = #(Si -> Sj) / #(Si), pooled over all
        // examples; see [1] eq. 9.37.
        double new_trans[][] = new double[num_states][num_states];
        for (int i = 0; i < num_states; i++) {
            for (int j = 0; j < num_states; j++) {
                double num = ProbabilityUtils.ZERO;
                double denom = ProbabilityUtils.ZERO;
                for (int e = 0; e < zeta_per_observation.size(); e++) {
                    for (int t = 0; t < zeta_per_observation.get(e).length - 1; t++) {
                        num = ProbabilityUtils.sum(num, zeta_per_observation.get(e)[t][i][j]);
                        denom = ProbabilityUtils.sum(denom, gamma_per_observation.get(e)[t][i]);
                    }
                }
                // num and denom are log probabilities, so their ratio is
                // exp(num - denom).
                new_trans[i][j] = ProbabilityUtils.exp(num - denom);
                if (Double.isNaN(new_trans[i][j])) {
                    System.err.println("NAN ERROR Transitions " + i + " " + j);
                    System.exit(-1);
                }
            }
        }
        // M-step, observations: every vector contributes to every state's
        // Gaussian, weighted by the posterior gamma[t][i]; [1] eq. 9.36.
        ArrayList<ProbabilityDistibution> observations = new ArrayList<ProbabilityDistibution>();
        for (int i = 0; i < num_states; i++) {
            if (model.getObservations().get(i) instanceof GaussianMixture) {
                // TODO No mixture training in HMM so far; mixtures are
                // carried over unchanged.
                observations.add(model.getObservations().get(i));
            } else {
                double mean[] = new double[dim];
                double variance[] = new double[dim];
                for (int d = 0; d < dim; d++) {
                    // Weighted mean: gamma is converted out of log space, so
                    // each weight lies in [0, 1].
                    double scale = 0;
                    // ArrayList replaces the legacy Vector the original used;
                    // the list only serves the diagnostic print below.
                    ArrayList<Double> weights = new ArrayList<Double>();
                    for (int e = 0; e < gamma_per_observation.size(); e++) {
                        for (int t = 0; t < gamma_per_observation.get(e).length; t++) {
                            double weight = ProbabilityUtils.exp(gamma_per_observation.get(e)[t][i]);
                            scale += weight;
                            weights.add(weight);
                            mean[d] += weight * training_set.get(e).get(t)[d];
                        }
                    }
                    if (scale == 0) {
                        System.err.println("scale error " + weights);
                    }
                    mean[d] /= scale;
                    // Weighted variance around the new mean. Iterate over the
                    // same range as the mean loop — the original bounded this
                    // loop by the sequence length instead of gamma's length,
                    // which risks an out-of-bounds read if they differ.
                    for (int e = 0; e < gamma_per_observation.size(); e++) {
                        for (int t = 0; t < gamma_per_observation.get(e).length; t++) {
                            double weight = ProbabilityUtils.exp(gamma_per_observation.get(e)[t][i]);
                            variance[d] += weight * Math.pow(training_set.get(e).get(t)[d] - mean[d], 2);
                        }
                    }
                    variance[d] /= scale;
                    if (Double.isNaN(mean[d]) || Double.isNaN(variance[d])) {
                        System.err.println("NAN ERROR Observations " + i + " " + d);
                        System.err.println(mean[d] + " " + variance[d] + " " + scale);
                        System.exit(-1);
                    }
                }
                observations.add(new Gaussian(mean, variance));
            }
        }
        model.setTransitions(new_trans);
        model.setObservations(observations);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment