Created
December 4, 2019 14:58
-
-
Save pesterie/06305d8a547350a3686354d509f67142 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
template<typename Tensor> | |
inline auto moveaxis(Tensor&& t, std::size_t source, std::size_t destination) | |
{ | |
using size_type = typename std::decay_t<Tensor>::size_type; | |
if(source >= t.dimension() || destination >= t.dimension()) | |
{ | |
throw std::runtime_error("Cant't move axis, source has not the same length as destination."); | |
} | |
std::vector<size_type> order{}; | |
for(size_type n = 0; n < t.dimension(); ++n) | |
{ | |
if(n != source) | |
{ | |
order.push_back(n); | |
} | |
} | |
auto it = std::begin(order) + destination; | |
order.insert(it, source); | |
return xt::transpose(std::forward<Tensor>(t), order); | |
} | |
template<typename Tensor> | |
inline auto unfold(Tensor&& t, std::size_t axe) | |
{ | |
using tensor_type = std::decay_t<Tensor>; | |
using size_type = typename tensor_type::size_type; | |
if(axe >= t.dimension()) | |
{ | |
throw std::runtime_error("Axe to unfold through is out of tensor dimension."); | |
} | |
auto shape = t.shape(); | |
size_type slice = shape[axe]; | |
std::swap(shape[0], shape[axe]); | |
std::vector resulting_dimensions(std::begin(shape) + 1, std::end(shape)); | |
size_type leading_dim = | |
std::accumulate(std::begin(shape) + 1, std::end(shape), size_type(1), [](auto a, auto b) { return a * b; }); | |
using result_of_moveaxis = decltype(moveaxis(std::forward<Tensor>(t), axe, 0)); | |
return xt::reshape_view(std::forward<result_of_moveaxis>(moveaxis(std::forward<Tensor>(t), axe, 0)), | |
{slice, leading_dim}); | |
} | |
template<typename Matrix, typename Shape> | |
inline auto fold(Matrix&& m, std::size_t axe, Shape shape) | |
{ | |
using size_type = typename Shape::size_type; | |
std::vector<size_type> new_shape; | |
new_shape.push_back(shape[axe]); | |
for(std::size_t i = 0; i < shape.size(); ++i) | |
{ | |
if(i != axe) | |
{ | |
new_shape.push_back(shape[i]); | |
} | |
} | |
using result_of_reshaped = decltype(xt::reshape_view(std::forward<Matrix>(m), new_shape)); | |
return moveaxis(std::forward<result_of_reshaped>(xt::reshape_view(std::forward<Matrix>(m), std::move(new_shape))), | |
0, axe); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment