Skip to content

Instantly share code, notes, and snippets.

@pesterie
Created January 23, 2020 16:04
Show Gist options
  • Save pesterie/2801cc64e46225992190f22009fc9c95 to your computer and use it in GitHub Desktop.
Save pesterie/2801cc64e46225992190f22009fc9c95 to your computer and use it in GitHub Desktop.
namespace xt
{
/// Moves a single axis of a tensor expression to a new position.
///
/// Builds the permutation in which axis `source` ends up at index
/// `destination` while all other axes keep their relative order, then
/// hands that permutation to xt::transpose.
///
/// @param t            Tensor expression; forwarded to xt::transpose.
/// @param source       Index of the axis to move; must be < t.dimension().
/// @param destination  Index the axis is moved to; must be < t.dimension().
/// @return             The result of xt::transpose for the computed permutation.
/// @throws std::runtime_error if `source` or `destination` is out of range,
///         or if the permutation container cannot be resized to t.dimension().
template<typename Tensor>
inline auto moveaxis(Tensor&& t, std::size_t source, std::size_t destination)
{
    using shape_type = typename std::decay_t<Tensor>::shape_type;

    const auto rank = t.dimension();
    if(source >= rank || destination >= rank)
    {
        throw std::runtime_error("Can't move axis: source or destination is out of range for the tensor dimension.");
    }

    shape_type perm{};
    if(!xt::resize_container(perm, rank))
    {
        throw std::runtime_error("Cant't resize shape type.");
    }

    // Start from the identity permutation 0, 1, ..., rank - 1.
    std::iota(std::begin(perm), std::end(perm), 0);

    const auto first = std::begin(perm);
    const auto src = first + static_cast<std::ptrdiff_t>(source);
    const auto dst = first + static_cast<std::ptrdiff_t>(destination);

    // Rotating the closed index range between the two axes by one position
    // puts `source` at `destination` and shifts the axes in between by one,
    // leaving everything outside that range untouched.
    if(source > destination)
    {
        std::rotate(dst, src, src + 1); // `source` jumps to the front of the range
    }
    else
    {
        std::rotate(src, src + 1, dst + 1); // `source` drops to the back (no-op when equal)
    }

    return xt::transpose(std::forward<Tensor>(t), perm);
}
/// Unfolds a tensor expression into a two-dimensional view along one axis.
///
/// Axis `axe` is first moved to the front with moveaxis; the result is then
/// wrapped in an xt::reshape_view of shape {extent of `axe`, product of all
/// the other extents}.
///
/// @param t    Tensor expression; forwarded to moveaxis.
/// @param axe  Axis to unfold along; must be < t.dimension().
/// @return     The xt::reshape_view built over the axis-moved expression.
/// @throws std::runtime_error if `axe` is not a valid axis of `t`.
template<typename Tensor>
inline auto unfold(Tensor&& t, std::size_t axe)
{
    using size_type = typename std::decay_t<Tensor>::size_type;

    const auto rank = t.dimension();
    if(!(axe < rank))
    {
        throw std::runtime_error("Unfold : Axe to unfold through is out of tensor dimension.");
    }

    const auto& extents = t.shape();
    const size_type rows = extents.at(axe);

    // Number of columns: product of every extent except the unfolded one.
    size_type cols{1};
    for(std::size_t d = 0; d < rank; ++d)
    {
        if(d != axe)
        {
            cols = cols * extents.at(d);
        }
    }

    return xt::reshape_view(moveaxis(std::forward<Tensor>(t), axe, 0), {rows, cols});
}
// Here
template<typename Matrix, typename Shape>
inline auto fold(Matrix&& m, std::size_t axe, Shape shape)
{
using size_type = typename Shape::size_type;
if(axe >= shape.size())
{
throw std::runtime_error("Fold : Folding axe out of shape range.");
}
Shape new_shape;
if(!xt::resize_container(new_shape, shape.size()))
{
throw std::runtime_error("Fold : Cant't resize shape type.");
}
new_shape.at(0) = shape.at(axe);
size_type j{1};
for(size_type i = 0; i < shape.size(); ++i)
{
if(i != axe)
{
new_shape.at(j) = shape.at(i);
j++;
}
}
using result_of_reshaped = decltype(xt::reshape_view(std::forward<Matrix>(m), new_shape));
return moveaxis(
std::forward<result_of_reshaped>(xt::reshape_view(std::forward<Matrix>(m), std::move(new_shape))), 0, axe);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment