@kris-singh
Last active March 4, 2017 14:45
/**
* @file fanin_visitor_impl.hpp
* @author KrishnaKant Singh
*
* Implementation of the FanIn visitor abstraction.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_VISITOR_FANIN_VISITOR_IMPL_HPP
#define MLPACK_METHODS_ANN_VISITOR_FANIN_VISITOR_IMPL_HPP
// In case it hasn't been included yet.
#include "fanin_visitor.hpp"
namespace mlpack {
namespace ann {
//! FanInVisitor visitor class.
template<typename LayerType>
inline size_t FanInVisitor::operator()(LayerType* layer) const
{
  return LayerFanIn(layer);
}

template<typename T, typename P>
inline typename std::enable_if<
    !HasInputWidth<T, size_t&(T::*)()>::value &&
    !HasParametersCheck<T, P&(T::*)()>::value, size_t>::type
FanInVisitor::LayerFanIn(T* /* layer */) const
{
  // Neither InputWidth() nor Parameters() is available: no fan-in to report.
  return 0;
}

template<typename T, typename P>
inline typename std::enable_if<
    HasInputWidth<T, size_t&(T::*)()>::value &&
    !HasParametersCheck<T, P&(T::*)()>::value, size_t>::type
FanInVisitor::LayerFanIn(T* layer) const
{
  return layer->InputWidth();
}

template<typename T, typename P>
inline typename std::enable_if<
    !HasInputWidth<T, size_t&(T::*)()>::value &&
    HasParametersCheck<T, P&(T::*)()>::value, size_t>::type
FanInVisitor::LayerFanIn(T* layer) const
{
  // n_rows is an Armadillo member variable, not a function.
  return layer->Parameters().n_rows;
}

template<typename T, typename P>
inline typename std::enable_if<
    HasInputWidth<T, size_t&(T::*)()>::value &&
    HasParametersCheck<T, P&(T::*)()>::value, size_t>::type
FanInVisitor::LayerFanIn(T* layer) const
{
  // Prefer the layer's input width; fall back to the parameter rows.
  size_t fanIn = layer->InputWidth();
  if (fanIn == 0)
  {
    fanIn = layer->Parameters().n_rows;
  }
  return fanIn;
}
} // namespace ann
} // namespace mlpack
#endif
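
For reference, a minimal usage sketch of the visitor (the Linear layer, its
sizes, and the fanInVisitor instance name are illustrative, assuming the usual
mlpack LayerTypes boost::variant of layer pointers):

// Hypothetical usage: query the fan-in of a single layer.
LayerTypes layer = new Linear<>(10, 20);

FanInVisitor fanInVisitor;
const size_t fanin = boost::apply_visitor(fanInVisitor, layer);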

-------------------------------------------------------------------
Xavier Init
-------------------------------------------------------------------
template<typename T>
void Initialize(T& model)
{
  for (LayerTypes& layer : model.Model())
  {
    const double fanin = boost::apply_visitor(FanInVisitor(), layer);
    const double fanout = boost::apply_visitor(FanOutVisitor(), layer);
    // Xavier: the weights are drawn with variance 2 / (fanin + fanout).
    layer.Weight = 2.0 / (fanin + fanout);
  }
}

template<typename T>
void Initialize(T& layer)
{
  const double fanin = boost::apply_visitor(FanInVisitor(), layer);
  layer.Weight = 2.0 / fanin;
}
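
The sketch above only computes the variance; a fuller (still hypothetical)
version would actually sample the weights. A minimal sketch of Glorot-uniform
sampling, assuming fanin, fanout, and the weight dimensions rows x cols are
already known for the layer:

// Glorot-uniform: draw from U(-b, b) with b = sqrt(6 / (fanin + fanout)).
const double bound = std::sqrt(6.0 / (fanin + fanout));
arma::mat weights = bound * (2.0 * arma::randu<arma::mat>(rows, cols) - 1.0);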
@kris-singh (Author)

Do we first have to check, using std::enable_if, whether the layer has weight parameters and only then set them? Because a layer may also not have the weight parameters.

zoq commented Mar 4, 2017

Looks good. In FanIn = layer->Parameters.n_rows(); you missed the '()' for Parameters.

zoq commented Mar 4, 2017

What I would do is to figure out the complete model size with:

size_t weights = 0;
for (size_t i = 0; i < network.size(); ++i)
{
  weights += boost::apply_visitor(weightSizeVisitor, network[i]);
}

parameter.set_size(weights, 1);

Now parameter is contiguous memory that we can use for the weight initialization.

Here comes the initialization step for a single layer:

template<typename T>
void Initialize(T layer, arma::mat& parameter, const size_t offset)
{
  size_t layerWeightSize = boost::apply_visitor(weightSizeVisitor, layer);

  const size_t fanin = boost::apply_visitor(FanInVisitor(), layer);
  const size_t fanout = boost::apply_visitor(FanOutVisitor(), layer);

  arma::mat xavierWeights = ....

  parameter.submat(offset, 0, offset + layerWeightSize - 1, 0) =
      arma::vectorise(xavierWeights);
}
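
To make the slicing step concrete, here is a tiny self-contained Armadillo
example of writing one layer's vectorised weights into a slice of the flat
parameter column (all sizes are made up):

#include <armadillo>

int main()
{
  arma::mat parameter(12, 1, arma::fill::zeros);   // flat parameter memory
  arma::mat layerWeights(3, 4, arma::fill::randu); // one layer's 3x4 weights

  const size_t offset = 0;
  const size_t layerWeightSize = layerWeights.n_elem; // 12

  // Copy the layer's weights into rows [offset, offset + size) of column 0.
  parameter.submat(offset, 0, offset + layerWeightSize - 1, 0) =
      arma::vectorise(layerWeights);

  parameter.print("parameter");
}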

Now it's time to initialize the model with the weights:

size_t offset = 0;
for (size_t i = 0; i < network.size(); ++i)
{
  Initialize(network[i], parameter, offset); // <-- the per-layer step above

  offset += boost::apply_visitor(WeightSetVisitor(std::move(parameter),
      offset), network[i]);

  boost::apply_visitor(resetVisitor, network[i]);
}
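
(If I read the visitors right, WeightSetVisitor returns the number of weight
elements it assigned, so offset advances to the start of the next layer's
slice, and the resetVisitor pass lets each layer re-bind its internal weight
aliases to the shared parameter memory.)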
