Created July 12, 2017 22:33
-
-
Save zoq/253280cf0f494208749906eba71bf841 to your computer and use it in GitHub Desktop.
ffn_gradient_function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Current implementation. | |
template<typename OutputLayerType, typename InitializationRuleType> | |
void FFN<OutputLayerType, InitializationRuleType>::Gradient() | |
{ | |
boost::apply_visitor(GradientVisitor(std::move(currentInput), std::move( | |
boost::apply_visitor(deltaVisitor, network[1]))), network.front()); | |
for (size_t i = 1; i < network.size() - 1; ++i) | |
{ | |
boost::apply_visitor(GradientVisitor(std::move(boost::apply_visitor( | |
outputParameterVisitor, network[i - 1])), std::move( | |
boost::apply_visitor(deltaVisitor, network[i + 1]))), network[i]); | |
} | |
boost::apply_visitor(GradientVisitor(std::move(boost::apply_visitor( | |
outputParameterVisitor, network[network.size() - 2])), std::move(error)), | |
network[network.size() - 1]); | |
} | |
// Slightly modified implementation to only update a single layer. | |
// Suppose we have the index of the layer we like to update, | |
// that could be set before training. | |
template<typename OutputLayerType, typename InitializationRuleType> | |
void FFN<OutputLayerType, InitializationRuleType>::Gradient() | |
{ | |
if (index == 0) // The first layer. | |
{ | |
boost::apply_visitor(GradientVisitor(std::move(currentInput), std::move( | |
boost::apply_visitor(deltaVisitor, network[1]))), network.front()); | |
} | |
else if (index == network.size() - 1) // The last layer. | |
{ | |
boost::apply_visitor(GradientVisitor(std::move(boost::apply_visitor( | |
outputParameterVisitor, network[network.size() - 2])), std::move(error)), | |
network[network.size() - 1]); | |
} | |
else | |
{ | |
boost::apply_visitor(GradientVisitor(std::move(boost::apply_visitor( | |
outputParameterVisitor, network[index - 1])), std::move( | |
boost::apply_visitor(deltaVisitor, network[index + 1]))), network[index]); | |
} | |
} | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.