Created
August 18, 2016 12:36
-
-
Save zoq/202f18997afa5ec37ed689c95afd0ae5 to your computer and use it in GitHub Desktop.
connect_layer.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* @file connect_layer.hpp | |
* @author Nilay Jain | |
* Definition of the ConnectLayer class. | |
*/ | |
#ifndef MLPACK_METHODS_ANN_LAYER_CONNECT_LAYER_HPP | |
#define MLPACK_METHODS_ANN_LAYER_CONNECT_LAYER_HPP | |
#include <mlpack/core.hpp> | |
#include <mlpack/methods/ann/layer/layer_traits.hpp> | |
#include <mlpack/methods/ann/cnn.hpp> | |
#include <mlpack/methods/ann/network_util.hpp> | |
namespace mlpack { | |
namespace ann /** Artificial Neural Network. */ { | |
// two networks can emerge from a connect layer. | |
template< | |
typename NetworkTypeA, | |
typename NetworkTypeB, | |
typename InputDataType = arma::cube, | |
typename OutputDataType = arma::mat | |
> | |
class ConnectLayer | |
{ | |
public: | |
template<typename NetworkA, typename NetworkB> | |
ConnectLayer(NetworkA networkA, NetworkB networkB): | |
networkA(std::forward<NetworkA>(networkA)), | |
networkB(std::forward<NetworkB>(networkB)), | |
firstRun(true) | |
{ | |
static_assert(std::is_same<typename std::decay<NetworkA>::type, | |
NetworkTypeA>::value, | |
"The type of networkA must be NetworkTypeA."); | |
static_assert(std::is_same<typename std::decay<NetworkB>::type, | |
NetworkTypeB>::value, | |
"The type of networkB must be NetworkTypeB."); | |
networkASize = NetworkSize(networkA); | |
networkBSize = NetworkSize(networkB); | |
weights.set_size(networkASize + networkBSize, 1); | |
} | |
template<typename eT> | |
void Forward(const arma::Cube<eT>& input, arma::mat<eT>& output) | |
{ | |
if (firstRun) | |
{ | |
NetworkWeights(parameter, networkA); | |
NetworkWeights(parameter, networkB, networkASize); | |
firstRun = false; | |
} | |
networkA.Forward(input, networkA.Layers()); | |
networkB.Forward(input, networkB.Layers()); | |
} | |
template<typename eT> | |
void Backward(arma::Cube<eT>&, arma::Cube<eT>& error, arma::Cube<eT>& ) | |
{ | |
networkA.Backward(networkA.error, networkA.Layers()); | |
networkB.Backward(networkB.error, networkB.Layers()); | |
} | |
template<typename eT> | |
void Gradient(const arma::Cube<eT>&, arma::Cube<eT>& delta, arma::Cube<eT>& gradient) | |
{ | |
NetworkGradients(gradient, networkA); | |
NetworkGradients(gradient, networkB, networkASize); | |
networkA.UpdateGradients(networkA.Layers()); | |
networkB.UpdateGradients(networkB.Layers()); | |
} | |
//! Get the weights. | |
OutputDataType const& Weights() const { return weights; } | |
//! Modify the weights. | |
OutputDataType& Weights() { return weights; } | |
//! Get the input parameter. | |
InputDataType const& InputParameter() const { return inputParameter; } | |
//! Modify the input parameter. | |
InputDataType& InputParameter() { return inputParameter; } | |
//! Get the output parameter. | |
OutputDataType const& OutputParameter() const { return outputParameter; } | |
//! Modify the output parameter. | |
OutputDataType& OutputParameter() { return outputParameter; } | |
//! Get the delta. | |
OutputDataType const& Delta() const { return delta; } | |
//! Modify the delta. | |
OutputDataType& Delta() { return delta; } | |
//! Get the gradient. | |
OutputDataType const& Gradient() const { return gradient; } | |
//! Modify the gradient. | |
OutputDataType& Gradient() { return gradient; } | |
private: | |
NetworkTypeA networkA; | |
NetworkTypeB networkB; | |
size_t networkASize; | |
size_t networkBSize; | |
//! Locally-stored run parameter used to initalize the layer once. | |
bool firstRun; | |
//! Locally-stored weight object. | |
OutputDataType weights; | |
//! Locally-stored delta object. | |
OutputDataType delta; | |
//! Locally-stored gradient object. | |
OutputDataType gradient; | |
//! Locally-stored input parameter object. | |
InputDataType inputParameter; | |
//! Locally-stored output parameter object. | |
OutputDataType outputParameter; | |
}; | |
} // namespace ann | |
} // namespace mlpack | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment