connect_layer.hpp
/**
 * @file connect_layer.hpp
 * @author Nilay Jain
 *
 * Definition of the ConnectLayer class.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_CONNECT_LAYER_HPP
#define MLPACK_METHODS_ANN_LAYER_CONNECT_LAYER_HPP

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/layer_traits.hpp>
#include <mlpack/methods/ann/cnn.hpp>
#include <mlpack/methods/ann/network_util.hpp>

#include <type_traits>
#include <utility>
namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
 * Implementation of the ConnectLayer class. A ConnectLayer forwards its
 * input to two subnetworks, so two networks can emerge from a single point
 * in a larger network.
 */
template<
    typename NetworkTypeA,
    typename NetworkTypeB,
    typename InputDataType = arma::cube,
    typename OutputDataType = arma::mat
>
class ConnectLayer
{
 public:
  /**
   * Create the ConnectLayer object. The networks are perfectly forwarded
   * into the layer, so both lvalue and rvalue networks can be passed in.
   *
   * @param networkA First network that emerges from the connect layer.
   * @param networkB Second network that emerges from the connect layer.
   */
  template<typename NetworkA, typename NetworkB>
  ConnectLayer(NetworkA&& networkA, NetworkB&& networkB) :
      networkA(std::forward<NetworkA>(networkA)),
      networkB(std::forward<NetworkB>(networkB)),
      firstRun(true)
  {
    static_assert(std::is_same<typename std::decay<NetworkA>::type,
        NetworkTypeA>::value, "The type of networkA must be NetworkTypeA.");
    static_assert(std::is_same<typename std::decay<NetworkB>::type,
        NetworkTypeB>::value, "The type of networkB must be NetworkTypeB.");

    // Use the stored members here; the constructor arguments may have been
    // moved from by std::forward above.
    networkASize = NetworkSize(this->networkA);
    networkBSize = NetworkSize(this->networkB);
    weights.set_size(networkASize + networkBSize, 1);
  }
  /**
   * Feed the input forward through both subnetworks.
   *
   * @param input Input data used for evaluating both networks.
   * @param output Resulting output activations.
   */
  template<typename eT>
  void Forward(const arma::Cube<eT>& input, arma::Mat<eT>& /* output */)
  {
    if (firstRun)
    {
      // Alias the subnetwork weights with the layer's weight matrix; the
      // weights of networkB start at offset networkASize.
      NetworkWeights(weights, networkA);
      NetworkWeights(weights, networkB, networkASize);
      firstRun = false;
    }

    networkA.Forward(input, networkA.Layers());
    networkB.Forward(input, networkB.Layers());
  }
  /**
   * Feed the stored errors backward through both subnetworks.
   */
  template<typename eT>
  void Backward(arma::Cube<eT>& /* input */,
                arma::Cube<eT>& /* error */,
                arma::Cube<eT>& /* delta */)
  {
    networkA.Backward(networkA.error, networkA.Layers());
    networkB.Backward(networkB.error, networkB.Layers());
  }
  /**
   * Update the gradients of both subnetworks; the gradients of networkB are
   * written at offset networkASize into the shared gradient matrix.
   */
  template<typename eT>
  void Gradient(const arma::Cube<eT>& /* input */,
                arma::Cube<eT>& /* delta */,
                arma::Cube<eT>& gradient)
  {
    NetworkGradients(gradient, networkA);
    NetworkGradients(gradient, networkB, networkASize);

    networkA.UpdateGradients(networkA.Layers());
    networkB.UpdateGradients(networkB.Layers());
  }
  //! Get the weights.
  OutputDataType const& Weights() const { return weights; }
  //! Modify the weights.
  OutputDataType& Weights() { return weights; }

  //! Get the input parameter.
  InputDataType const& InputParameter() const { return inputParameter; }
  //! Modify the input parameter.
  InputDataType& InputParameter() { return inputParameter; }

  //! Get the output parameter.
  OutputDataType const& OutputParameter() const { return outputParameter; }
  //! Modify the output parameter.
  OutputDataType& OutputParameter() { return outputParameter; }

  //! Get the delta.
  OutputDataType const& Delta() const { return delta; }
  //! Modify the delta.
  OutputDataType& Delta() { return delta; }

  //! Get the gradient.
  OutputDataType const& Gradient() const { return gradient; }
  //! Modify the gradient.
  OutputDataType& Gradient() { return gradient; }
 private:
  //! The first network that emerges from the connect layer.
  NetworkTypeA networkA;

  //! The second network that emerges from the connect layer.
  NetworkTypeB networkB;

  //! The number of parameters in networkA.
  size_t networkASize;

  //! The number of parameters in networkB.
  size_t networkBSize;

  //! Locally-stored run parameter used to initialize the layer once.
  bool firstRun;

  //! Locally-stored weight object.
  OutputDataType weights;

  //! Locally-stored delta object.
  OutputDataType delta;

  //! Locally-stored gradient object.
  OutputDataType gradient;

  //! Locally-stored input parameter object.
  InputDataType inputParameter;

  //! Locally-stored output parameter object.
  OutputDataType outputParameter;
}; // class ConnectLayer
} // namespace ann
} // namespace mlpack

#endif
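
Aside: the constructor above takes forwarding references (NetworkA&&, NetworkB&&) and pins the decayed argument types to the class template parameters via static_assert. Below is a minimal, self-contained sketch of that pattern, assuming nothing from mlpack; NetA, NetB, and Connect are toy stand-ins invented for illustration.

#include <iostream>
#include <type_traits>
#include <utility>

// Toy stand-ins for two fully specified network types.
struct NetA { int size() const { return 3; } };
struct NetB { int size() const { return 5; } };

template<typename NetworkTypeA, typename NetworkTypeB>
class Connect
{
 public:
  // Forwarding references bind to both lvalues (copied) and rvalues
  // (moved); the static_asserts reject any argument whose decayed type
  // differs from the corresponding class template parameter.
  template<typename NetworkA, typename NetworkB>
  Connect(NetworkA&& networkA, NetworkB&& networkB) :
      networkA(std::forward<NetworkA>(networkA)),
      networkB(std::forward<NetworkB>(networkB))
  {
    static_assert(std::is_same<typename std::decay<NetworkA>::type,
        NetworkTypeA>::value, "The type of networkA must be NetworkTypeA.");
    static_assert(std::is_same<typename std::decay<NetworkB>::type,
        NetworkTypeB>::value, "The type of networkB must be NetworkTypeB.");
  }

  int TotalSize() const { return networkA.size() + networkB.size(); }

 private:
  NetworkTypeA networkA;
  NetworkTypeB networkB;
};

int main()
{
  NetA a;
  Connect<NetA, NetB> connect(a, NetB()); // lvalue and rvalue both bind.
  std::cout << connect.TotalSize() << "\n"; // prints 8
}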
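The other idea worth illustrating is the shared flat parameter buffer: the layer allocates networkASize + networkBSize weights in a single matrix, and NetworkWeights / NetworkGradients make each subnetwork view a slice of it, with networkB offset by networkASize. A hedged, mlpack-free sketch of that aliasing scheme follows; ToyNet and its product-of-weights forward pass are made up purely for illustration.

#include <cstddef>
#include <iostream>
#include <vector>

// Toy subnetwork: its "parameters" are a view into a shared flat buffer,
// which is the role NetworkWeights plays for the real subnetworks.
struct ToyNet
{
  std::size_t size;
  double* weights = nullptr;

  // A stand-in forward pass: multiply the input by every parameter.
  double Forward(double input) const
  {
    double out = input;
    for (std::size_t i = 0; i < size; ++i)
      out *= weights[i];
    return out;
  }
};

int main()
{
  ToyNet netA{3}, netB{5};

  // One flat buffer of networkASize + networkBSize parameters; netB's
  // view starts at offset netA.size, mirroring
  // NetworkWeights(weights, networkB, networkASize) above.
  std::vector<double> parameter(netA.size + netB.size, 1.0);
  netA.weights = parameter.data();
  netB.weights = parameter.data() + netA.size;

  // The connect layer forwards the same input to both subnetworks.
  const double input = 2.0;
  std::cout << netA.Forward(input) << " " << netB.Forward(input) << "\n";
}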