Created
June 13, 2017 14:15
-
-
Save zoq/24f8b2e4826d837d604f9613615763bb to your computer and use it in GitHub Desktop.
patch.diff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/mlpack/methods/ann/layer/layer_traits.hpp b/src/mlpack/methods/ann/layer/layer_traits.hpp | |
index ff4fbf2d3..546dcc9ec 100644 | |
--- a/src/mlpack/methods/ann/layer/layer_traits.hpp | |
+++ b/src/mlpack/methods/ann/layer/layer_traits.hpp | |
@@ -98,7 +98,7 @@ HAS_MEM_FUNC(InputHeight, HasInputHeight); | |
// This gives us a HasRho<T, U> type (where U is a function pointer) we | |
// can use with SFINAE to catch when a type has a Rho() function. | |
-HAS_MEM_FUNC(InputHeight, HasRho); | |
+HAS_MEM_FUNC(Rho, HasRho); | |
} // namespace ann | |
} // namespace mlpack | |
diff --git a/src/mlpack/methods/ann/layer/lstm_impl.hpp b/src/mlpack/methods/ann/layer/lstm_impl.hpp | |
index 8db1f13e7..fad405269 100644 | |
--- a/src/mlpack/methods/ann/layer/lstm_impl.hpp | |
+++ b/src/mlpack/methods/ann/layer/lstm_impl.hpp | |
@@ -75,6 +75,10 @@ template<typename eT> | |
void LSTM<InputDataType, OutputDataType>::Forward( | |
arma::Mat<eT>&& input, arma::Mat<eT>&& output) | |
{ | |
+ // std::cout << "rho: " << rho << std::endl; | |
+ | |
+ // exit(0); | |
+ | |
if (!deterministic) | |
{ | |
cellParameter.push_back(prevCell); | |
diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp | |
index 4117b4f1d..0fb9af215 100644 | |
--- a/src/mlpack/methods/ann/rnn.hpp | |
+++ b/src/mlpack/methods/ann/rnn.hpp | |
@@ -197,6 +197,11 @@ class RNN | |
template<typename Archive> | |
void Serialize(Archive& ar, const unsigned int /* version */); | |
+ //! Get the maximum number of steps to backpropagate through time (BPTT). | |
+ size_t Rho() const { return rho; } | |
+ //! Modify the maximum number of steps to backpropagate through time (BPTT). | |
+ size_t& Rho() { return rho; } | |
+ | |
private: | |
// Helper functions. | |
/** | |
@@ -246,6 +251,9 @@ class RNN | |
//! Number of steps to backpropagate through time (BPTT). | |
size_t rho; | |
+ //! Value of rho when the layers were last updated; used to detect changes. | |
+ size_t prevRho; | |
+ | |
//! Instantiated outputlayer used to evaluate the network. | |
OutputLayerType outputLayer; | |
diff --git a/src/mlpack/methods/ann/rnn_impl.hpp b/src/mlpack/methods/ann/rnn_impl.hpp | |
index 758ac9fbb..0235174c8 100644 | |
--- a/src/mlpack/methods/ann/rnn_impl.hpp | |
+++ b/src/mlpack/methods/ann/rnn_impl.hpp | |
@@ -23,6 +23,7 @@ | |
#include "visitor/gradient_set_visitor.hpp" | |
#include "visitor/gradient_visitor.hpp" | |
#include "visitor/weight_set_visitor.hpp" | |
+#include "visitor/rho_set_visitor.hpp" | |
namespace mlpack { | |
namespace ann /** Artificial Neural Network. */ { | |
@@ -35,6 +36,7 @@ RNN<OutputLayerType, InitializationRuleType>::RNN( | |
OutputLayerType outputLayer, | |
InitializationRuleType initializeRule) : | |
rho(rho), | |
+ prevRho(0), | |
outputLayer(outputLayer), | |
initializeRule(initializeRule), | |
inputSize(0), | |
@@ -176,6 +178,17 @@ template<typename OutputLayerType, typename InitializationRuleType> | |
void RNN<OutputLayerType, InitializationRuleType>::SinglePredict( | |
const arma::mat& predictors, arma::mat& results) | |
{ | |
+ // Propagate rho to the layers and recompute the input size if the | |
+ // backpropagation through time (BPTT) length changed since it was last applied. | |
+ if (prevRho != rho) | |
+ { | |
+ for (size_t i = 1; i < network.size(); ++i) | |
+ boost::apply_visitor(RhoSetVisitor(rho), network[i]); | |
+ | |
+ inputSize = predictors.n_elem / rho; | |
+ prevRho = rho; | |
+ } | |
+ | |
for (size_t seqNum = 0; seqNum < rho; ++seqNum) | |
{ | |
currentInput = predictors.rows(seqNum * inputSize, | |
@@ -208,10 +221,16 @@ double RNN<OutputLayerType, InitializationRuleType>::Evaluate( | |
arma::mat target = arma::mat(responses.colptr(i), responses.n_rows, | |
1, false, true); | |
- if (!inputSize) | |
+ // Propagate rho to the layers and recompute the input and target sizes if the | |
+ // backpropagation through time (BPTT) length changed since it was last applied. | |
+ if (prevRho != rho) | |
{ | |
+ for (size_t i = 1; i < network.size(); ++i) | |
+ boost::apply_visitor(RhoSetVisitor(rho), network[i]); | |
+ | |
inputSize = input.n_elem / rho; | |
targetSize = target.n_elem / rho; | |
+ prevRho = rho; | |
} | |
double performance = 0; | |
diff --git a/src/mlpack/methods/ann/visitor/CMakeLists.txt b/src/mlpack/methods/ann/visitor/CMakeLists.txt | |
index f7b3d6ee2..f63b6da40 100644 | |
--- a/src/mlpack/methods/ann/visitor/CMakeLists.txt | |
+++ b/src/mlpack/methods/ann/visitor/CMakeLists.txt | |
@@ -37,6 +37,8 @@ set(SOURCES | |
reset_visitor_impl.hpp | |
reward_set_visitor.hpp | |
reward_set_visitor_impl.hpp | |
+ rho_set_visitor.hpp | |
+ rho_set_visitor_impl.hpp | |
save_output_parameter_visitor.hpp | |
save_output_parameter_visitor_impl.hpp | |
set_input_height_visitor.hpp | |
diff --git a/src/mlpack/tests/augmented_rnns_tasks_test.cpp b/src/mlpack/tests/augmented_rnns_tasks_test.cpp | |
index d9f063472..3dab6758c 100644 | |
--- a/src/mlpack/tests/augmented_rnns_tasks_test.cpp | |
+++ b/src/mlpack/tests/augmented_rnns_tasks_test.cpp | |
@@ -267,10 +267,7 @@ BOOST_AUTO_TEST_CASE(LSTMBaselineTest) | |
model.Add<IdentityLayer<> >(); | |
model.Add<Linear<> >(inputSize, 20); | |
- | |
- LayerTypes lstm = new LSTM<>(20, 7, maxRho); | |
- model.Add(lstm); | |
- | |
+ model.Add<LSTM<> >(20, 7, maxRho); | |
model.Add<Linear<> >(7, outputSize); | |
model.Add<SigmoidLayer<> >(); | |
@@ -300,6 +297,11 @@ BOOST_AUTO_TEST_CASE(LSTMBaselineTest) | |
std::cout << response.t() << std::endl; | |
}*/ | |
+ // Update the sequence length. This is only necessary when the length | |
+ // changed, but for convenience we set it every time and let the model | |
+ // check whether it actually changed. | |
+ model.Rho() = predictor.n_elem; | |
+ | |
model.Train(predictor, response, opt); | |
} | |
} | |
@@ -309,6 +311,10 @@ BOOST_AUTO_TEST_CASE(LSTMBaselineTest) | |
arma::field<arma::colvec> modelOutput(testSize); | |
for (size_t example = 0; example < testSize; ++example) { | |
arma::colvec softOutput; | |
+ | |
+ // Update the sequence length for the prediction part. | |
+ model.Rho() = testPredictor.at(example).n_elem; | |
+ | |
model.Predict( | |
testPredictor.at(example), | |
softOutput); | |
@@ -317,13 +323,19 @@ BOOST_AUTO_TEST_CASE(LSTMBaselineTest) | |
modelOutput.at(example).at(i) = | |
(modelOutput.at(example).at(i)) < 0.5 ? 0 : 1; | |
} | |
- /*std::cout << "Input:\n"; | |
+ | |
+ // std::cout << "prediction:" << modelOutput.at(example).t() << std::endl; | |
+ // std::cout << "target:" << testResponse.at(example).t() << std::endl; | |
+ std::cout << "Input:\n"; | |
std::cout << testPredictor.at(example).t() << std::endl; | |
std::cout << "Model output:\n"; | |
std::cout << modelOutput.at(example).t() << std::endl; | |
std::cout << "True output:\n"; | |
- std::cout << testResponse.at(example).t() << std::endl;*/ | |
+ std::cout << testResponse.at(example).t() << std::endl; | |
+ | |
+ std::cout << "=======================================\n"; | |
} | |
+ | |
std::cout << "Final score: " | |
<< SequencePrecision<arma::colvec>(testResponse, modelOutput) | |
<< "\n"; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment