Skip to content

Instantly share code, notes, and snippets.

@pesterie
Created December 4, 2019 14:58
Show Gist options
  • Save pesterie/06305d8a547350a3686354d509f67142 to your computer and use it in GitHub Desktop.
Save pesterie/06305d8a547350a3686354d509f67142 to your computer and use it in GitHub Desktop.
/// Move axis `source` of `t` to position `destination`, keeping every other
/// axis in its original relative order (the single-axis numpy.moveaxis rule,
/// non-negative indices only).
///
/// @param t           tensor expression; perfectly forwarded to xt::transpose
/// @param source      index of the axis to move; must be < t.dimension()
/// @param destination index the moved axis ends up at; must be < t.dimension()
/// @return the result of xt::transpose(t, order), where `order` is the
///         permutation described above
/// @throws std::runtime_error if `source` or `destination` is not a valid axis
template<typename Tensor>
inline auto moveaxis(Tensor&& t, std::size_t source, std::size_t destination)
{
    using size_type = typename std::decay_t<Tensor>::size_type;
    const auto rank = t.dimension();
    if(source >= rank || destination >= rank)
    {
        throw std::runtime_error("Can't move axis, source or destination is out of tensor dimension.");
    }

    // Identity permutation with `source` removed ...
    std::vector<size_type> order{};
    order.reserve(rank);
    for(size_type n = 0; n < rank; ++n)
    {
        if(n != source)
        {
            order.push_back(n);
        }
    }
    // ... then re-insert it at `destination`. `order` holds rank-1 entries and
    // destination < rank, so the insert position is always in [begin, end].
    order.insert(order.begin() + static_cast<std::ptrdiff_t>(destination), source);
    return xt::transpose(std::forward<Tensor>(t), order);
}
/// Unfold (matricize) `t` along axis `axe`.
///
/// The axis `axe` is first moved to the front with moveaxis (the other axes
/// keep their relative order). The remaining axes are then collapsed into one
/// with xt::reshape_view, giving a 2-D result of shape
/// {t.shape()[axe], product of all other extents}. The flattening order is
/// whatever xt::reshape_view uses by default (presumably row-major; confirm
/// against the xtensor configuration).
///
/// @param t   tensor expression; perfectly forwarded into the returned view
/// @param axe axis to unfold along; must be < t.dimension()
/// @return a reshape view over moveaxis(t, axe, 0)
/// @throws std::runtime_error if `axe` is not a valid axis of `t`
template<typename Tensor>
inline auto unfold(Tensor&& t, std::size_t axe)
{
    using size_type = typename std::decay_t<Tensor>::size_type;
    const auto rank = t.dimension();
    if(axe >= rank)
    {
        throw std::runtime_error("Axe to unfold through is out of tensor dimension.");
    }

    // Read the extents before `t` is forwarded (possibly moved) below.
    const auto& shape = t.shape();
    const size_type rows = shape[axe];
    // Product of every extent except `axe`. Skipping the index avoids copying
    // and mutating the shape, so immutable/fixed shape types also work.
    size_type cols = 1;
    for(std::size_t i = 0; i < rank; ++i)
    {
        if(i != axe)
        {
            cols *= shape[i];
        }
    }

    // Bring `axe` to the front, then collapse the remaining axes into columns.
    return xt::reshape_view(moveaxis(std::forward<Tensor>(t), axe, 0), {rows, cols});
}
template<typename Matrix, typename Shape>
inline auto fold(Matrix&& m, std::size_t axe, Shape shape)
{
using size_type = typename Shape::size_type;
std::vector<size_type> new_shape;
new_shape.push_back(shape[axe]);
for(std::size_t i = 0; i < shape.size(); ++i)
{
if(i != axe)
{
new_shape.push_back(shape[i]);
}
}
using result_of_reshaped = decltype(xt::reshape_view(std::forward<Matrix>(m), new_shape));
return moveaxis(std::forward<result_of_reshaped>(xt::reshape_view(std::forward<Matrix>(m), std::move(new_shape))),
0, axe);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment