Skip to content

Instantly share code, notes, and snippets.

@metab0t
Created April 23, 2019 08:30
Show Gist options
  • Save metab0t/0fa3a1dfc3a204d606803bf6bd59fe50 to your computer and use it in GitHub Desktop.
Save metab0t/0fa3a1dfc3a204d606803bf6bd59fe50 to your computer and use it in GitHub Desktop.
Stride of multidimensional array
#include <vector>
#include <functional>
#include <numeric>
#include <cassert>
using std::vector;
/// Row-major (C-order) index arithmetic for a multidimensional array.
///
/// For a shape {d0, d1, ..., d(n-1)}, the stride of axis i is the product of
/// all dimensions after it, so the last axis has stride 1. This converts a
/// subscript tuple to a flat linear index (sub2ind) and back (ind2sub).
struct Stride {
  size_t n_dims, n_elem;   // number of axes; total element count (1 for an empty shape)
  vector<size_t> dims;     // the shape passed to the constructor
  vector<size_t> strides;  // strides[i] = product of dims[i+1 .. n_dims-1]

  /// @param shape extent of each axis, outermost first.
  Stride(const vector<size_t> &shape)
      // Initializers are listed in declaration order, which is the order they
      // actually run in. n_elem is computed from `shape` rather than `dims`
      // because `dims` is declared (and so initialized) after n_elem.
      // The size_t{1} seed makes accumulate multiply in size_t, not int.
      : n_dims(shape.size()),
        n_elem(std::accumulate(shape.begin(), shape.end(), size_t{1},
                               std::multiplies<size_t>())),
        dims(shape),
        strides(n_dims, 1) {
    // Walk from the innermost axis outward; the last stride stays 1.
    // Unsigned countdown that never goes below zero, so 0-D and 1-D shapes
    // simply skip the loop.
    for (size_t i = n_dims; i > 1; --i) {
      strides[i - 2] = strides[i - 1] * dims[i - 1];
    }
  }

  /// Convert a subscript tuple to its linear index.
  /// @param subscripts one coordinate per axis.
  /// @pre subscripts.size() == n_dims and subscripts[i] < dims[i] (asserted).
  /// @return sum of subscripts[i] * strides[i].
  size_t sub2ind(const vector<size_t> &subscripts) const {
    assert(n_dims == subscripts.size());
    size_t n = 0;
    for (size_t i = 0; i < n_dims; ++i) {
      assert(subscripts[i] < dims[i]);
      n += subscripts[i] * strides[i];
    }
    return n;
  }

  /// Convert a linear index back to its subscript tuple.
  /// @param index linear position in row-major order.
  /// @pre index < n_elem (asserted). n_elem is 0 whenever any axis has length
  ///      0, so this also keeps the division below away from a zero stride.
  /// @return one coordinate per axis.
  vector<size_t> ind2sub(size_t index) const {
    assert(index < n_elem);
    vector<size_t> subscripts(n_dims);
    for (size_t i = 0; i < n_dims; ++i) {
      subscripts[i] = index / strides[i];
      index %= strides[i];
    }
    return subscripts;
  }
};
int main(int argc, char const *argv[]) {
auto s = Stride({2, 3, 4});
assert(s.n_dims == 3);
assert(s.n_elem == 2 * 3 * 4);
assert(s.sub2ind({1, 2, 3}) == s.n_elem - 1);
assert(s.ind2sub(s.n_elem - 1) == vector<size_t>({1, 2, 3}));
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment