layerstring.hpp
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/layer/layer_types.hpp>

#include <iostream>
#include <string>
#include <type_traits>

using namespace mlpack::ann;

//! Visitor that returns a short string describing the type of the given layer.
class TestVisitor : public boost::static_visitor<std::string>
{
 public:
  //! Create the visitor.
  TestVisitor()
  {
  }

  //! Return the string for the given layer.
  template<typename LayerType>
  std::string operator()(LayerType* layer) const
  {
    return LayerString(layer);
  }

 private:
  //! Selected if the layer is a Linear<> layer.
  template<typename T>
  typename std::enable_if<
      std::is_same<T, Linear<> >::value, std::string>::type
  LayerString(T* /* layer */) const
  {
    return "linear";
  }

  //! Selected for every other layer type.
  template<typename T>
  typename std::enable_if<
      !std::is_same<T, Linear<> >::value, std::string>::type
  LayerString(T* /* layer */) const
  {
    return "not linear";
  }
};

int main()
{
  LayerTypes<> layerA = new Linear<>(20, 30);
  LayerTypes<> layerB = new LinearNoBias<>(20, 30);

  // Prints "not linear", since layerB holds a LinearNoBias<> layer.
  std::cout << boost::apply_visitor(TestVisitor(), layerB) << std::endl;

  // The variants hold raw pointers, so the layers must be freed explicitly.
  delete boost::get<Linear<>*>(layerA);
  delete boost::get<LinearNoBias<>*>(layerB);
}
When you use the code, make sure to add these includes:
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/layer/layer_types.hpp>
before you include anything else. Let me know if that solves the compilation issue.
Thanks @zoq, I will keep these include statements in mind. I tried out your suggestion and it worked! The one thing that bothers me is the length of each function prototype: with 10-15 layers it would get really long and might be confusing to someone else. Still, it seems to be the best way out, so I am going with it.
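One possible way to keep the prototypes short as the number of layers grows (this is not something proposed in the thread) is to drop the enable_if pairs and use ordinary overloads plus a single template fallback: for an exact pointer match, C++ overload resolution prefers a non-template function over a function template. The sketch below is self-contained under stated assumptions: `Linear`, `LinearNoBias` and `Dropout` are hypothetical empty stand-in structs rather than the mlpack classes, and `std::variant`/`std::visit` (C++17) stand in for `boost::variant`/`boost::apply_visitor`. With mlpack, the overloads would presumably take pointers such as `Linear<>*` instead, but that adaptation is untested here.

```cpp
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for mlpack layer classes, so the sketch compiles
// without mlpack or Boost.
struct Linear {};
struct LinearNoBias {};
struct Dropout {};

// Stand-in for LayerTypes<>: a variant of raw layer pointers
// (std::variant is used here in place of boost::variant).
using Layer = std::variant<Linear*, LinearNoBias*, Dropout*>;

// One short, non-template overload per layer that needs its own string.
// On an exact match, overload resolution prefers these non-template
// functions over the template fallback, so no enable_if is required.
inline std::string LayerString(Linear*) { return "linear"; }
inline std::string LayerString(LinearNoBias*) { return "linear no bias"; }

// Fallback for every layer type without a dedicated overload.
template<typename T>
std::string LayerString(T*) { return "other"; }

// Visitor that forwards the held pointer to the LayerString() overload set.
struct TestVisitor
{
  template<typename LayerType>
  std::string operator()(LayerType* layer) const
  {
    return LayerString(layer);
  }
};

// Returns the string for each layer, in order.
inline std::vector<std::string> DescribeAll(const std::vector<Layer>& layers)
{
  std::vector<std::string> names;
  for (const Layer& layer : layers)
    names.push_back(std::visit(TestVisitor(), layer));
  return names;
}
```

Calling `DescribeAll({&a, &b, &d})` with one object of each stand-in type returns `{"linear", "linear no bias", "other"}`. Supporting another layer then costs one line per overload, and the fallback never needs updating.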
I finished it a while ago and it is working fine. I'll write a test and push it :)
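Unfortunately, you have to define every case explicitly, as in the example above.
Make sure you also have a case where every condition is false (a catch-all overload); otherwise, layer types that match none of the conditions have no viable LayerString() overload and the code does not compile.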
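The advice about the catch-all case can be sketched with three layer types: one enable_if overload per case, plus a catch-all whose condition is the conjunction of all the negated conditions, so every type enables exactly one overload. `Linear`, `LinearNoBias` and `Dropout` are hypothetical empty stand-ins for the mlpack layers, used only so the code compiles on its own; the visitor is called directly instead of through `boost::apply_visitor`.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-ins for mlpack layer types.
struct Linear {};
struct LinearNoBias {};
struct Dropout {};

struct LayerNameVisitor
{
  template<typename LayerType>
  std::string operator()(LayerType* layer) const
  {
    return LayerString(layer);
  }

 private:
  // Case 1: enabled only for Linear.
  template<typename T>
  typename std::enable_if<std::is_same<T, Linear>::value, std::string>::type
  LayerString(T* /* layer */) const { return "linear"; }

  // Case 2: enabled only for LinearNoBias.
  template<typename T>
  typename std::enable_if<std::is_same<T, LinearNoBias>::value,
                          std::string>::type
  LayerString(T* /* layer */) const { return "linear no bias"; }

  // Catch-all: enabled only when every condition above is false. Without it,
  // calling the visitor with e.g. Dropout* would not compile, because no
  // LayerString() overload would be viable.
  template<typename T>
  typename std::enable_if<!std::is_same<T, Linear>::value &&
                          !std::is_same<T, LinearNoBias>::value,
                          std::string>::type
  LayerString(T* /* layer */) const { return "other"; }
};
```

Each new case has to be added twice: once as its own overload and once, negated, in the catch-all condition. That duplication is what makes the prototypes grow with the number of layers.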