Created January 23, 2020 16:04
-
-
Save pesterie/2801cc64e46225992190f22009fc9c95 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
namespace xt | |
{ | |
template<typename Tensor> | |
inline auto moveaxis(Tensor&& t, std::size_t source, std::size_t destination) | |
{ | |
using shape_type = typename std::decay_t<Tensor>::shape_type; | |
if(source >= t.dimension() || destination >= t.dimension()) | |
{ | |
throw std::runtime_error("Cant't move axis, source has not the same length as destination."); | |
} | |
shape_type perm{}; | |
shape_type iota_{}; | |
if(!(xt::resize_container(perm, t.dimension()) && xt::resize_container(iota_, t.dimension()))) | |
{ | |
throw std::runtime_error("Cant't resize shape type."); | |
} | |
std::iota(std::begin(iota_), std::end(iota_), 0); | |
std::iota(std::begin(perm), std::end(perm), 0); | |
if(source > destination) | |
{ | |
std::copy(std::begin(iota_) + destination, std::begin(iota_) + source, std::begin(perm) + destination + 1); | |
perm.at(destination) = source; | |
return xt::transpose(std::forward<Tensor>(t), perm); | |
} | |
else | |
{ | |
std::copy(std::begin(iota_) + source + 1, std::begin(iota_) + destination + 1, std::begin(perm) + source); | |
perm.at(destination) = source; | |
return xt::transpose(std::forward<Tensor>(t), perm); | |
} | |
} | |
template<typename Tensor> | |
inline auto unfold(Tensor&& t, std::size_t axe) | |
{ | |
using tensor_type = std::decay_t<Tensor>; | |
using size_type = typename tensor_type::size_type; | |
if(axe >= t.dimension()) | |
{ | |
throw std::runtime_error("Unfold : Axe to unfold through is out of tensor dimension."); | |
} | |
auto shape = t.shape(); | |
size_type slice = shape.at(axe); | |
std::swap(shape.at(0), shape.at(axe)); | |
size_type leading_dim = | |
std::accumulate(std::begin(shape) + 1, std::end(shape), size_type(1), [](auto a, auto b) { return a * b; }); | |
using result_of_moveaxis = decltype(moveaxis(std::forward<Tensor>(t), axe, 0)); | |
return xt::reshape_view(std::forward<result_of_moveaxis>(moveaxis(std::forward<Tensor>(t), axe, 0)), | |
{slice, leading_dim}); | |
} | |
// Here | |
template<typename Matrix, typename Shape> | |
inline auto fold(Matrix&& m, std::size_t axe, Shape shape) | |
{ | |
using size_type = typename Shape::size_type; | |
if(axe >= shape.size()) | |
{ | |
throw std::runtime_error("Fold : Folding axe out of shape range."); | |
} | |
Shape new_shape; | |
if(!xt::resize_container(new_shape, shape.size())) | |
{ | |
throw std::runtime_error("Fold : Cant't resize shape type."); | |
} | |
new_shape.at(0) = shape.at(axe); | |
size_type j{1}; | |
for(size_type i = 0; i < shape.size(); ++i) | |
{ | |
if(i != axe) | |
{ | |
new_shape.at(j) = shape.at(i); | |
j++; | |
} | |
} | |
using result_of_reshaped = decltype(xt::reshape_view(std::forward<Matrix>(m), new_shape)); | |
return moveaxis( | |
std::forward<result_of_reshaped>(xt::reshape_view(std::forward<Matrix>(m), std::move(new_shape))), 0, axe); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.