Skip to content

Instantly share code, notes, and snippets.

@rcurtin
Created January 26, 2015 15:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rcurtin/daf960aa6ad545f58402 to your computer and use it in GitHub Desktop.
Save rcurtin/daf960aa6ad545f58402 to your computer and use it in GitHub Desktop.
#include <mlpack/core.hpp>
#include <mlpack/methods/gmm/gmm.hpp>
// Required string command-line option (short alias "i") holding the path of
// the dataset that main() loads and fits a GaussianDistribution to.
PARAM_STRING_REQ("input_file", "Input dataset.", "i");
using namespace mlpack;
using namespace mlpack::distribution;
using namespace std;
int main(int argc, char** argv)
{
CLI::ParseCommandLine(argc, argv);
// Load the dataset.
const string inputFile = CLI::GetParam<string>("input_file");
arma::mat dataset;
data::Load(inputFile, dataset, true);
Timer::Start("estimate");
GaussianDistribution g(dataset.n_rows);
g.Estimate(dataset);
Timer::Stop("estimate");
// Make some random points and get their probability.
arma::mat random;
random.randu(dataset.n_rows, 100000);
Timer::Start("probability_batch");
arma::vec probabilities;
g.Probability(random, probabilities);
Timer::Stop("probability_batch");
Timer::Start("probability_individual");
for (size_t i = 0; i < random.n_cols; ++i)
probabilities[i] = g.Probability(random.unsafe_col(i));
Timer::Stop("probability_individual");
// Generate random observations.
Timer::Start("random");
for (size_t i = 0; i < random.n_cols; ++i)
random.col(i) = g.Random();
Timer::Stop("random");
// Do something kind of like GMM training.
Timer::Start("gmm_training_imitation");
for (size_t i = 0; i < 50; ++i)
{
g.Probability(dataset, probabilities);
arma::mat covariance(dataset.n_rows, dataset.n_rows);
covariance.randu();
covariance = covariance.t() * covariance;
g.Covariance(covariance);
}
Timer::Stop("gmm_training_imitation");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment