Skip to content

Instantly share code, notes, and snippets.

@jf99
Last active April 5, 2024 22:04
Show Gist options
  • Save jf99/5b79b76a8ed721cc6b1a6d24ccb4f6ad to your computer and use it in GitHub Desktop.
Save jf99/5b79b76a8ed721cc6b1a6d24ccb4f6ad to your computer and use it in GitHub Desktop.
A way to concatenate a std::vector of xt::xarray objects along a given axis
#ifndef XTCONCATENATE_HPP
#define XTCONCATENATE_HPP
#include <cstddef>
#include <stdexcept>
#include <vector>
#include <xtensor/xarray.hpp>
/// Row-major multi-dimensional counter over every index tuple of a shape.
///
/// `index` starts at all zeros. Each successful increment() advances it to
/// the next tuple, with the last dimension varying fastest.
///
/// @tparam S  shape type; must provide size() and operator[] yielding
///            non-negative integral extents (e.g. std::vector<size_t>, or the
///            shape returned by an xt::xarray).
/// @note Only a reference to the shape is kept, so the shape object must
///       outlive this ShapeIndex.
template <typename S>
struct ShapeIndex
{
    const S& shape;             ///< extents being iterated (not owned)
    std::vector<size_t> index;  ///< current index tuple, one entry per dimension

    /// Starts at the all-zero index tuple of `shape`.
    explicit ShapeIndex(const S& shape);

    /// Advances `index` to the next tuple in row-major order.
    /// @return true if `index` was advanced. false if it already held the last
    ///         tuple, or if the shape has a zero extent and so has no tuples at
    ///         all. In both false cases `index` is left unchanged.
    bool increment();
};

template <typename S>
ShapeIndex<S>::ShapeIndex(const S& shape)
    : shape(shape), index(shape.size(), 0)
{}

template <typename S>
bool ShapeIndex<S>::increment()
{
    // A zero-length dimension means the shape contains no index tuples, so
    // there is nothing to advance to.
    for(size_t dim = 0; dim < index.size(); ++dim) {
        if(shape[dim] == 0)
            return false;
    }
    // Find the least significant dimension that has not reached its maximum,
    // bump it, and reset every less significant dimension to zero (carry).
    // The test uses `index + 1 >= extent` rather than `index >= extent - 1`
    // so that unsigned extents can never underflow.
    for(size_t dim = index.size(); dim-- > 0; ) {
        if(index[dim] + 1 >= shape[dim])
            continue;
        ++index[dim];
        for(size_t lsd = dim + 1; lsd < index.size(); ++lsd) // lsd = less significant dimension
            index[lsd] = 0;
        return true;
    }
    return false;
}
/// Concatenates a list of arrays along `axis`.
///
/// All arrays must have the same number of dimensions and identical extents
/// in every dimension except `axis`. The result's extent along `axis` is the
/// sum of the inputs' extents along `axis`, and the inputs appear in `v` order.
///
/// @param v     arrays to join, in output order
/// @param axis  dimension to join along (default 0)
/// @return the concatenated array, or a default-constructed xt::xarray<T>
///         when `v` is empty
/// @throws std::out_of_range if `axis` is not a dimension of the first array
/// @throws std::logic_error  if the arrays' dimensionalities differ, or their
///                           extents differ in a dimension other than `axis`
template <typename T>
xt::xarray<T> concatenate(const std::vector<xt::xarray<T>>& v, const size_t axis = 0)
{
    if(v.empty())
        return {};

    auto resShape = v.front().shape();
    if(axis >= resShape.size())
        throw std::out_of_range("axis is not a dimension of shape");

    // Validate shapes and accumulate the output extent along `axis`.
    for(size_t i = 1; i < v.size(); ++i) {
        if(v[i].shape().size() != resShape.size())
            throw std::logic_error("shapes have different dimensionalities");
        for(size_t dim = 0; dim < resShape.size(); ++dim) {
            if(dim == axis)
                continue;
            if(v[i].shape(dim) != resShape[dim])
                throw std::logic_error("incompatible shapes");
        }
        resShape[axis] += v[i].shape(axis);
    }

    xt::xarray<T> result(resShape);
    std::vector<size_t> dstIndex(resShape.size(), 0);
    size_t axisOffset = 0;  // start position of the current input along `axis`
    for(const auto& src : v) {
        // Inputs with no elements contribute nothing and must be skipped. The
        // do/while below always visits the all-zero index once, and that visit
        // would read out of bounds of an empty array.
        if(src.size() != 0) {
            ShapeIndex srcIndex(src.shape());
            do {
                // Same position as in `src`, shifted along `axis`.
                dstIndex = srcIndex.index;
                dstIndex[axis] += axisOffset;
                result[dstIndex] = src[srcIndex.index];
            }
            while(srcIndex.increment());
        }
        axisOffset += src.shape(axis);
    }
    return result;
}
#endif // XTCONCATENATE_HPP
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment