Skip to content

Instantly share code, notes, and snippets.

@zoq
Created June 13, 2017 14:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zoq/24f8b2e4826d837d604f9613615763bb to your computer and use it in GitHub Desktop.
patch.diff
diff --git a/src/mlpack/methods/ann/layer/layer_traits.hpp b/src/mlpack/methods/ann/layer/layer_traits.hpp
index ff4fbf2d3..546dcc9ec 100644
--- a/src/mlpack/methods/ann/layer/layer_traits.hpp
+++ b/src/mlpack/methods/ann/layer/layer_traits.hpp
@@ -98,7 +98,7 @@ HAS_MEM_FUNC(InputHeight, HasInputHeight);
// This gives us a HasRho<T, U> type (where U is a function pointer) we
// can use with SFINAE to catch when a type has a Rho() function.
-HAS_MEM_FUNC(InputHeight, HasRho);
+HAS_MEM_FUNC(Rho, HasRho);
} // namespace ann
} // namespace mlpack
diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index 4117b4f1d..0fb9af215 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -197,6 +197,11 @@ class RNN
template<typename Archive>
void Serialize(Archive& ar, const unsigned int /* version */);
+ //! Get the number of time steps to backpropagate through time (BPTT).
+ size_t Rho() const { return rho; }
+ //! Modify the BPTT length; it takes effect on the next predict/evaluate.
+ size_t& Rho() { return rho; }
+
private:
// Helper functions.
/**
@@ -246,6 +251,9 @@ class RNN
//! Number of steps to backpropagate through time (BPTT).
size_t rho;
+ //! The rho value the layers and input/target sizes were last set up for.
+ size_t prevRho;
+
//! Instantiated outputlayer used to evaluate the network.
OutputLayerType outputLayer;
diff --git a/src/mlpack/methods/ann/rnn_impl.hpp b/src/mlpack/methods/ann/rnn_impl.hpp
index 758ac9fbb..0235174c8 100644
--- a/src/mlpack/methods/ann/rnn_impl.hpp
+++ b/src/mlpack/methods/ann/rnn_impl.hpp
@@ -23,6 +23,7 @@
#include "visitor/gradient_set_visitor.hpp"
#include "visitor/gradient_visitor.hpp"
#include "visitor/weight_set_visitor.hpp"
+#include "visitor/rho_set_visitor.hpp"
namespace mlpack {
namespace ann /** Artificial Neural Network. */ {
@@ -35,6 +36,7 @@ RNN<OutputLayerType, InitializationRuleType>::RNN(
OutputLayerType outputLayer,
InitializationRuleType initializeRule) :
rho(rho),
+ prevRho(0),
outputLayer(outputLayer),
initializeRule(initializeRule),
inputSize(0),
@@ -176,6 +178,20 @@ template<typename OutputLayerType, typename InitializationRuleType>
void RNN<OutputLayerType, InitializationRuleType>::SinglePredict(
const arma::mat& predictors, arma::mat& results)
{
+ // Reconfigure the layers and the input length if the backpropagate through
+ // time parameter changed since the last pass.
+ if (prevRho != rho)
+ {
+ for (size_t l = 1; l < network.size(); ++l)
+ boost::apply_visitor(RhoSetVisitor(rho), network[l]);
+
+ inputSize = predictors.n_elem / rho;
+ // There are no responses here to derive targetSize from; reset it so that
+ // Evaluate() recomputes it instead of reusing a value from another rho.
+ targetSize = 0;
+ prevRho = rho;
+ }
+
for (size_t seqNum = 0; seqNum < rho; ++seqNum)
{
currentInput = predictors.rows(seqNum * inputSize,
@@ -208,10 +224,16 @@ double RNN<OutputLayerType, InitializationRuleType>::Evaluate(
arma::mat target = arma::mat(responses.colptr(i), responses.n_rows,
1, false, true);
- if (!inputSize)
+ // Update the layers, input length and target length if the backpropagate
+ // through time parameter changed or the target length is unset.
+ if (prevRho != rho || !targetSize)
{
+ for (size_t l = 1; l < network.size(); ++l)
+ boost::apply_visitor(RhoSetVisitor(rho), network[l]);
+
inputSize = input.n_elem / rho;
targetSize = target.n_elem / rho;
+ prevRho = rho;
}
double performance = 0;
diff --git a/src/mlpack/methods/ann/visitor/CMakeLists.txt b/src/mlpack/methods/ann/visitor/CMakeLists.txt
index f7b3d6ee2..f63b6da40 100644
--- a/src/mlpack/methods/ann/visitor/CMakeLists.txt
+++ b/src/mlpack/methods/ann/visitor/CMakeLists.txt
@@ -37,6 +37,8 @@ set(SOURCES
reset_visitor_impl.hpp
reward_set_visitor.hpp
reward_set_visitor_impl.hpp
+ rho_set_visitor.hpp
+ rho_set_visitor_impl.hpp
save_output_parameter_visitor.hpp
save_output_parameter_visitor_impl.hpp
set_input_height_visitor.hpp
diff --git a/src/mlpack/tests/augmented_rnns_tasks_test.cpp b/src/mlpack/tests/augmented_rnns_tasks_test.cpp
index d9f063472..3dab6758c 100644
--- a/src/mlpack/tests/augmented_rnns_tasks_test.cpp
+++ b/src/mlpack/tests/augmented_rnns_tasks_test.cpp
@@ -267,10 +267,7 @@ BOOST_AUTO_TEST_CASE(LSTMBaselineTest)
model.Add<IdentityLayer<> >();
model.Add<Linear<> >(inputSize, 20);
-
- LayerTypes lstm = new LSTM<>(20, 7, maxRho);
- model.Add(lstm);
-
+ model.Add<LSTM<> >(20, 7, maxRho);
model.Add<Linear<> >(7, outputSize);
model.Add<SigmoidLayer<> >();
@@ -300,6 +297,11 @@ BOOST_AUTO_TEST_CASE(LSTMBaselineTest)
std::cout << response.t() << std::endl;
}*/
+ // Update the sequence length (time steps of inputSize elements each). It
+ // only has to change when the length does, but setting it every time is
+ // simpler: the model reconfigures itself only if the value differs.
+ model.Rho() = predictor.n_elem / inputSize;
+
model.Train(predictor, response, opt);
}
}
@@ -309,6 +311,10 @@ BOOST_AUTO_TEST_CASE(LSTMBaselineTest)
arma::field<arma::colvec> modelOutput(testSize);
for (size_t example = 0; example < testSize; ++example) {
arma::colvec softOutput;
+
+ // Update the sequence length for the prediction part.
+ model.Rho() = testPredictor.at(example).n_elem / inputSize;
+
model.Predict(
testPredictor.at(example),
softOutput);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment