Last active
March 13, 2019 21:03
-
-
Save dkohlsdorf/2c0ecb71a47670cd6357 to your computer and use it in GitHub Desktop.
Hidden Markov Model Training
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package processing.inference; | |
import java.util.ArrayList; | |
import java.util.Vector; | |
import processing.model.Gaussian; | |
import processing.model.GaussianMixture; | |
import processing.model.HiddenMarkovModel; | |
import processing.model.ProbabilityDistibution; | |
import processing.utils.ProbabilityUtils; | |
/** | |
* Training and initialization of a Hidden Markov Model<br> | |
* | |
* [1] Wendy Holmes: Speech Synthesis and Recognition, 2nd ed.<br> | |
 * [2] Tobias Mann: Numerically Stable Hidden Markov Model Implementation<br> | |
 * [3] Holger Wunsch: Der Baum-Welch Algorithmus für Hidden Markov Models, ein | |
 * Spezialfall des EM-Algorithmus<br> | |
 * [4] Lawrence Rabiner: A tutorial on Hidden Markov Models and Selected | |
 * Applications in Speech Recognition<br> | |
* [5] David Minnen: Discovering Multivariate Motifs using Subsequence Density | |
* Estimation and Greedy Mixture Learning<br> | |
* | |
* @author Daniel Kohlsdorf | |
*/ | |
public class Training { | |
/** | |
* Train a HMM by running Baum Welch until convergence | |
* | |
* @param training_set data to estimate model parameters from | |
* @param model an initialized model model to estimate | |
* @param iter maximum number of iterations | |
*/ | |
public static void BaumWelchConverge(ArrayList<ArrayList<double[]>> training_set, HiddenMarkovModel model, int iter) { | |
double last = 0; | |
for (int i = 0; i < iter; i++) { | |
BaumWelchReestimation(training_set, model); | |
// compute score for current model as the posterior probability | |
double score = ProbabilityUtils.ZERO; | |
for (ArrayList<double[]> seq : training_set) { | |
double fwd[][] = Inference.forward(seq, model); | |
int T = fwd.length - 1; | |
double ll = ProbabilityUtils.ZERO; | |
for (int s = 0; s < model.getInitials().length; s++) { | |
ll = ProbabilityUtils.sum(ll, fwd[T][s]); | |
} | |
score = ProbabilityUtils.sum(score, ll); | |
} | |
// convergence check | |
if (Math.abs((score / last) - 1) < 1e-6) { | |
break; | |
} | |
System.out.println(" - Baum Welch Iter: " + i + " " + score); | |
last = score; | |
} | |
} | |
/** | |
* Initialize a model using global mean and variances and initial | |
* transitions from data. | |
* | |
* @param data A set of time series | |
* @param model The Hidden Markov Models | |
*/ | |
public static void initial(ArrayList<ArrayList<double[]>> data, HiddenMarkovModel model) { | |
int num_states = model.getInitials().length; | |
// always start in state 1 | |
double init[] = new double[num_states]; | |
double end[] = new double[num_states]; | |
init[0] = 1; | |
end[num_states - 1] = 1; | |
// initialize transition matrix | |
double p_trans = 1.0 / (data.get(0).size() / num_states); | |
double p_stay = 1.0 - p_trans; | |
double trans[][] = new double[num_states][num_states]; | |
for (int i = 0; i < num_states; i++) { | |
trans[i][i] = p_stay; | |
if (i + 1 < num_states) { | |
trans[i][i + 1] = p_trans; | |
} | |
} | |
// estimate gaussians from data | |
ArrayList<ProbabilityDistibution> obs = new ArrayList<ProbabilityDistibution>(); | |
System.out.println(data.size() + " " + data.get(0).size()); | |
int D = data.get(0).get(0).length; | |
for (int i = 1; i < num_states + 1; i++) { | |
Gaussian gaussian = new Gaussian(D); | |
ArrayList<double[]> samples = new ArrayList<>(); | |
for (int j = 0; j < data.size(); j++) { | |
int b = data.get(j).size() / num_states; | |
samples.addAll(data.get(j).subList((i - 1) * b, i * b)); | |
} | |
gaussian.estimate(samples); | |
obs.add(gaussian); | |
} | |
// set model | |
model.setInitials(init); | |
model.setEnd(end); | |
model.setTransitions(trans); | |
model.setObservations(obs); | |
} | |
/** | |
* Baum Welch Reestimation using all available examples. | |
* | |
* [1] Wendy Holmes: Speech Synthesis and Recognition, 2nd ed. Chapter 9 [2] | |
* Tobias Mann, Numerically Stable Hidden Markov Model Implementation. | |
* | |
*/ | |
public static void BaumWelchReestimation(ArrayList<ArrayList<double[]>> training_set, HiddenMarkovModel model) { | |
// expectation | |
ArrayList<double[][]> gamma_per_observation = new ArrayList<double[][]>(); | |
ArrayList<double[][][]> zeta_per_observation = new ArrayList<double[][][]>(); | |
for (ArrayList<double[]> observation : training_set) { | |
double forward[][] = Inference.forward(observation, model); | |
double backward[][] = Inference.backward(observation, model); | |
// compute state / time probabilities | |
double gamma[][] = Inference.state_probabilities(observation, model, forward, backward); | |
double zeta[][][] = Inference.state_probabilities_time(observation, model, forward, backward); | |
gamma_per_observation.add(gamma); | |
zeta_per_observation.add(zeta); | |
} | |
int num_states = model.getInitials().length; | |
int dim = training_set.get(0).get(0).length; | |
// re estimate transitions aij | |
// using all examples | |
double new_trans[][] = new double[num_states][num_states]; | |
for (int i = 0; i < num_states; i++) { | |
for (int j = 0; j < num_states; j++) { | |
// #(Si -> Sj) / #(Si) | |
double num = ProbabilityUtils.ZERO; | |
double denom = ProbabilityUtils.ZERO; | |
for (int e = 0; e < zeta_per_observation.size(); e++) { | |
for (int t = 0; t < zeta_per_observation.get(e).length - 1; t++) { | |
num = ProbabilityUtils.sum(num, zeta_per_observation.get(e)[t][i][j]); | |
denom = ProbabilityUtils.sum(denom, gamma_per_observation.get(e)[t][i]); | |
} | |
} | |
new_trans[i][j] = ProbabilityUtils.exp(num - denom); | |
if (Double.isNaN(new_trans[i][j])) { | |
System.err.println("NAN ERROR Transitions " + i + " " + j); | |
System.exit(-1); | |
} | |
} | |
} | |
// re estimate observation probabilities | |
// all vectors contribute to each Gaussian. | |
// How much a vector contributes to a Gaussian is | |
// given by the probability gamma[t][i] | |
ArrayList<ProbabilityDistibution> observations = new ArrayList<ProbabilityDistibution>(); | |
for (int i = 0; i < num_states; i++) { | |
if (model.getObservations().get(i) instanceof GaussianMixture) { | |
// TODO No mixture training in HMM so far | |
observations.add(model.getObservations().get(i)); | |
} else { | |
double mean[] = new double[dim]; | |
double variance[] = new double[dim]; | |
for (int d = 0; d < dim; d++) { | |
// compute mean | |
// [1] equation 9.36 | |
double scale = 0; | |
Vector<Double> weights = new Vector<Double>(); | |
for (int e = 0; e < gamma_per_observation.size(); e++) { | |
for (int t = 0; t < gamma_per_observation.get(e).length; t++) { | |
// gamma here a scaling factor between 0 and 1 | |
double weight = ProbabilityUtils.exp(gamma_per_observation.get(e)[t][i]); | |
scale += weight; | |
weights.add(weight); | |
mean[d] += weight * training_set.get(e).get(t)[d]; | |
} | |
} | |
if (scale == 0) { | |
System.err.println("scale error " + weights); | |
} | |
mean[d] /= scale; | |
// compute variance | |
// [1] equation 9.36 | |
for (int e = 0; e < gamma_per_observation.size(); e++) { | |
for (int t = 0; t < training_set.get(e).size(); t++) { | |
double weight = ProbabilityUtils.exp(gamma_per_observation.get(e)[t][i]); | |
variance[d] += weight * Math.pow(training_set.get(e).get(t)[d] - mean[d], 2); | |
} | |
} | |
variance[d] /= scale; | |
if (Double.isNaN(mean[d]) || Double.isNaN(variance[d])) { | |
System.err.println("NAN ERROR Observations " + i + " " + d); | |
System.err.println(mean[d] + " " + variance[d] + " " + scale); | |
System.exit(-1); | |
} | |
} | |
observations.add(new Gaussian(mean, variance)); | |
} | |
} | |
model.setTransitions(new_trans); | |
model.setObservations(observations); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment