Skip to content

Instantly share code, notes, and snippets.

@kris-singh
Created June 15, 2017 18:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kris-singh/5d347fc74475262e958f01ae6764885d to your computer and use it in GitHub Desktop.
Save kris-singh/5d347fc74475262e958f01ae6764885d to your computer and use it in GitHub Desktop.
#ifndef MLPACK_CORE_OPTIMIZERS_CDK_CDK_HPP
#define MLPACK_CORE_OPTIMIZERS_CDK_CDK_HPP
#define MLPACK_CORE_OPTIMIZERS_CDK_CDK_HPP
#include <mlpack/prereqs.hpp>
namespace mlpack {
namespace ann /** Artificial Neural Network. */ {
/**
 * The CD-k (contrastive divergence) algorithm for training an RBM.
 *
 * RBMType must implement the following interface (as used by Optimize()):
 *  - NumFunctions(): the number of data points.
 *  - Gradient(...): the gradient (positive minus negative phase) at the
 *    given data point.
 *  - NOTE(review): a Gibbs sampler producing samples after k steps is also
 *    assumed — exact signature is not visible in this file; confirm.
 *
 * @tparam RBMType Type of RBM being trained.
 */
template<typename RBMType>
class CDK
{
 public:
  /**
   * Construct the CD-k algorithm.
   *
   * @param rbm The RBM to be trained.
   * @param k Chain length of the Gibbs sampler.
   * @param stepSize Step size for each gradient update.
   * @param maxIterations Maximum number of iterations (0 indicates no limit).
   * @param shuffle If true, data points are visited in shuffled order.
   * @param persistent If true, use PCD-k (persistent chains); else CD-k.
   */
  CDK(RBMType& rbm,
      const size_t k = 1,
      const double stepSize = 0.01,
      const size_t maxIterations = 100000,
      const bool shuffle = true,
      const bool persistent = false);

  /**
   * Optimize the given function using CD-k. The given starting point will be
   * modified to store the finishing point of the algorithm, and the final
   * objective value is returned.
   *
   * @param iterate Starting point (will be modified).
   * @return Objective value of the final point.
   */
  double Optimize(arma::mat& iterate);

  //! Get the instantiated RBM to be trained.
  const RBMType& RBM() const { return rbm; }
  //! Modify the instantiated RBM.
  RBMType& RBM() { return rbm; }

  //! Get the step size.
  double StepSize() const { return stepSize; }
  //! Modify the step size.
  double& StepSize() { return stepSize; }

  //! Get the maximum number of iterations (0 indicates no limit).
  size_t MaxIterations() const { return maxIterations; }
  //! Modify the maximum number of iterations (0 indicates no limit).
  size_t& MaxIterations() { return maxIterations; }

  //! Get the tolerance for termination.
  double Tolerance() const { return tolerance; }
  //! Modify the tolerance for termination.
  double& Tolerance() { return tolerance; }

  //! Get whether or not the individual functions are shuffled.
  bool Shuffle() const { return shuffle; }
  //! Modify whether or not the individual functions are shuffled.
  bool& Shuffle() { return shuffle; }

  //! Get whether persistent (PCD-k) sampling is used.
  bool Persistent() const { return persistent; }
  //! Modify whether persistent (PCD-k) sampling is used.
  bool& Persistent() { return persistent; }

  //! Get the number of Gibbs steps (chain length k).
  size_t NumSteps() const { return k; }
  //! Modify the number of Gibbs steps (chain length k).
  size_t& NumSteps() { return k; }

 private:
  //! The instantiated RBM being trained.
  RBMType& rbm;
  //! The step size for each example.
  double stepSize;
  //! The maximum number of allowed iterations.
  size_t maxIterations;
  //! The tolerance for termination.
  double tolerance;
  //! Controls whether or not the individual functions are shuffled when
  //! iterating.
  bool shuffle;
  //! Whether Gibbs sampling uses a persistent state (PCD-k) or not (CD-k).
  bool persistent;
  //! The length of the Gibbs sampling chain.
  size_t k;
  //! Holds the negative sample produced by the Gibbs chain.
  arma::mat negative_sample;
};
} // namespace ann
} // namespace mlpack
#endif
#ifndef MLPACK_CORE_OPTIMIZERS_CDK_CDK_IMPL_HPP
#define MLPACK_CORE_OPTIMIZERS_CDK_CDK_IMPL_HPP
#include <mlpack/prereqs.hpp>
namespace mlpack {
namespace ann /** Artificial Neural Network. */ {
/**
 * Construct the CD-k algorithm; all members are simply stored.
 *
 * @param rbm The RBM to be trained.
 * @param k Chain length of the Gibbs sampler.
 * @param stepSize Step size for each gradient update.
 * @param maxIterations Maximum number of iterations (0 indicates no limit).
 * @param shuffle If true, data points are visited in shuffled order.
 * @param persistent If true, use PCD-k (persistent chains); else CD-k.
 */
template<typename RBMType>
CDK<RBMType>::CDK(
    RBMType& rbm,
    const size_t k,
    const double stepSize,
    const size_t maxIterations,
    const bool shuffle,
    const bool persistent) :
    // Initializer list ordered to match member declaration order.
    rbm(rbm),
    stepSize(stepSize),
    maxIterations(maxIterations),
    tolerance(0), // Not a constructor parameter; initialize explicitly.
    shuffle(shuffle),
    persistent(persistent),
    k(k)
{
  // Nothing to do here.
}
template<typename RBMType>CDK<RBMType>::Optimise(arma::mat& iterate)
{
// Find the number of functions to use.
const size_t numFunctions = rbm.NumFunctions();
// This is used only if shuffle is true.
arma::Col<size_t> visitationOrder;
if (shuffle)
{
visitationOrder = arma::shuffle(arma::linspace<arma::Col<size_t>>(0,
(numFunctions - 1), numFunctions));
}
// Now iterate!
arma::mat gradient(iterate.n_rows, iterate.n_cols);
for (size_t i = 1; i != maxIterations; ++i, ++currentFunction)
{
// Is this iteration the start of a sequence?
if ((currentFunction % numFunctions) == 0)
{
if (shuffle) // Determine order of visitation.
visitationOrder = arma::shuffle(visitationOrder);
}
// Evaluate the gradient for this iteration.
if (shuffle)
rbm.Gradient(iterate, neg_samples, visitationOrder[currentFunction], gradient);
else
rbm.Gradient(iterate, neg_samples, currentFunction, gradient);
// Use the update policy to take a step.
iterate += stepSize * gradient;
}
Log::Info << "CDK: maximum iterations (" << maxIterations << ") reached; "
<< "terminating optimization." << std::endl;
}
} // namespace ann
} // namespace mlpack
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment