Created
June 15, 2017 18:54
-
-
Save kris-singh/5d347fc74475262e958f01ae6764885d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef MLPACK_CORE_OPTIMIZERS_CDK_CDK_HPP | |
#define MLPACK_CORE_OPTIMIZERS_CDK_CDK_HPP | |
#define MLPACK_CORE_OPTIMIZERS_CDK_CDK_HPP | |
#include <mlpack/prereqs.hpp> | |
namespace mlpack { | |
namespace ann /** Artificial Neural Network. */ { | |
/* | |
* The cdk algorithm for training RBM. | |
* rbmtype is function that implements | |
* the following functions | |
* n_functions: number of data points | |
* gradients: gradient function for calculating the positive and negative gradient at the given point | |
* gibbs: gibbs sampler for obtaining the samples for given variable after k-steps | |
* | |
* @tparam: RBMType: Type of RBM being used | |
*/ | |
template<typename RBMType> | |
class CDK | |
{ | |
public: | |
/** | |
* The default constructor for the CD-k aglorithm | |
* | |
* @tparam: RBMType: RBM for which we want to train the algorithm | |
* @param: epoch: Number of training steps | |
* @param: k: chain length of gibbs sampler | |
* @param: persistent: PCD-k or CD-k | |
*/ | |
CDK(RBMType& rbm, | |
const size_t k=1, | |
const double stepSize = 0.01, | |
const size_t maxIterations = 100000, | |
const bool shuffle = true, | |
const bool persistent = false) | |
/** | |
* Optimize the given function using cd-k. The given | |
* starting point will be modified to store the finishing point of the | |
* algorithm, and the final objective value is returned. | |
* | |
* @param iterate Starting point (will be modified). | |
* @return Objective value of the final point. | |
*/ | |
double Optimize(arma::mat& iterate); | |
//! Get the instantiated function to be optimized. | |
const RBMType& RBM() const { return rbm; } | |
//! Modify the instantiated function. | |
RBMType& RBM() { return rbm; } | |
//! Get the step size. | |
double StepSize() const { return stepSize; } | |
//! Modify the step size. | |
double& StepSize() { return stepSize; } | |
//! Get the maximum number of iterations (0 indicates no limit). | |
size_t MaxIterations() const { return maxIterations; } | |
//! Modify the maximum number of iterations (0 indicates no limit). | |
size_t& MaxIterations() { return maxIterations; } | |
//! Get the tolerance for termination. | |
double Tolerance() const { return tolerance; } | |
//! Modify the tolerance for termination. | |
double& Tolerance() { return tolerance; } | |
//! Get whether or not the individual functions are shuffled. | |
bool Shuffle() const { return shuffle; } | |
//! Modify whether or not the individual functions are shuffled. | |
bool& Shuffle() { return shuffle; } | |
//! Get whether or not the individual functions are shuffled. | |
bool Persistent() const { return persistent; } | |
//! Modify whether or not the individual functions are shuffled. | |
bool& Persistent() { return persistent; } | |
//! Get whether or not the individual functions are shuffled. | |
bool NumSteps() const { return k; } | |
//! Modify whether or not the individual functions are shuffled. | |
bool& NumSteps() { return k; } | |
private: | |
//! The instantiated function. | |
RBMType& rbm; | |
//! The step size for each example. | |
double stepSize; | |
//! The maximum number of allowed iterations. | |
size_t maxIterations; | |
//! The tolerance for termination. | |
double tolerance; | |
//! Controls whether or not the individual functions are shuffled when | |
//! iterating. | |
bool shuffle; | |
// Persistent: THe gibbs sampling using persistent state or not | |
bool persistent; | |
// k: The size of gibbs sampling chain | |
bool k; | |
// negative_sample: The negative sample | |
arma::mat negative_sample | |
} | |
} // namespace ann
} // namespace mlpack
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef MLPACK_CORE_OPTIMIZERS_CDK_CDK_IMPL_HPP | |
#define MLPACK_CORE_OPTIMIZERS_CDK_CDK_IMPL_HPP | |
#include <mlpack/prereqs.hpp> | |
namespace mlpack { | |
namespace ann /** Artificial Neural Network. */ { | |
template<typename RBMType> | |
CDK<RBMType>::CDK( | |
RBMType& rbm, | |
const size_t k, | |
const double stepSize, | |
const size_t maxIterations, | |
const bool shuffle, | |
const bool persistence) : | |
rbm(rbm), | |
k(k), | |
stepSize(stepSize), | |
maxIterations(maxIterations), | |
shuffle(shuffle), | |
persistence(persistence) | |
{ | |
// Nothing to do here | |
} | |
template<typename RBMType>CDK<RBMType>::Optimise(arma::mat& iterate) | |
{ | |
// Find the number of functions to use. | |
const size_t numFunctions = rbm.NumFunctions(); | |
// This is used only if shuffle is true. | |
arma::Col<size_t> visitationOrder; | |
if (shuffle) | |
{ | |
visitationOrder = arma::shuffle(arma::linspace<arma::Col<size_t>>(0, | |
(numFunctions - 1), numFunctions)); | |
} | |
// Now iterate! | |
arma::mat gradient(iterate.n_rows, iterate.n_cols); | |
for (size_t i = 1; i != maxIterations; ++i, ++currentFunction) | |
{ | |
// Is this iteration the start of a sequence? | |
if ((currentFunction % numFunctions) == 0) | |
{ | |
if (shuffle) // Determine order of visitation. | |
visitationOrder = arma::shuffle(visitationOrder); | |
} | |
// Evaluate the gradient for this iteration. | |
if (shuffle) | |
rbm.Gradient(iterate, neg_samples, visitationOrder[currentFunction], gradient); | |
else | |
rbm.Gradient(iterate, neg_samples, currentFunction, gradient); | |
// Use the update policy to take a step. | |
iterate += stepSize * gradient; | |
} | |
Log::Info << "CDK: maximum iterations (" << maxIterations << ") reached; " | |
<< "terminating optimization." << std::endl; | |
} | |
} // namespace ann
} // namespace mlpack | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment