Created April 23, 2019 08:30
-
-
Save metab0t/0fa3a1dfc3a204d606803bf6bd59fe50 to your computer and use it in GitHub Desktop.
Stride of a multidimensional array
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

using std::vector;
/// Row-major (C order) index arithmetic for a multidimensional array.
///
/// For a shape {d0, d1, ..., d(k-1)} the stride of axis i is the product of
/// the extents of all axes after it, so the last axis has stride 1.
/// sub2ind() maps a subscript tuple to a flat offset; ind2sub() inverts it.
struct Stride {
  std::size_t n_dims;                // number of axes (shape.size())
  std::size_t n_elem;                // total element count (1 for an empty shape)
  std::vector<std::size_t> dims;     // extent of each axis, outermost first
  std::vector<std::size_t> strides;  // flat-offset step of each axis

  /// @param shape extent of each axis, outermost first. An empty shape
  ///              yields n_elem == 1 and no strides.
  // Initializers are listed in declaration order, which is the order the
  // language actually runs them in.
  Stride(const std::vector<std::size_t> &shape)
      : n_dims(shape.size()),
        // Seed with size_t: an int seed makes accumulate keep an int running
        // product, silently truncating element counts above INT_MAX.
        n_elem(std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                               std::multiplies<std::size_t>())),
        dims(shape),
        strides(shape.size(), 1) {
    // Fill right-to-left: strides[i-1] = strides[i] * dims[i].
    // The unsigned countdown avoids mixing signed/unsigned indices and is a
    // no-op for 0- or 1-dimensional shapes.
    for (std::size_t i = n_dims; i-- > 1;) {
      strides[i - 1] = strides[i] * dims[i];
    }
  }

  /// Flat row-major offset of a subscript tuple.
  /// @param subscripts one subscript per axis; size must equal n_dims
  ///        (asserted). Values are not bounds-checked against dims.
  /// @return sum over i of subscripts[i] * strides[i].
  std::size_t sub2ind(const std::vector<std::size_t> &subscripts) const {
    assert(n_dims == subscripts.size());
    std::size_t n = 0;
    for (std::size_t i = 0; i < n_dims; ++i) {
      n += subscripts[i] * strides[i];
    }
    return n;
  }

  /// Inverse of sub2ind: subscript tuple for a flat row-major offset.
  /// @param index flat offset; must be < n_elem (asserted). For shapes with a
  ///        zero-extent inner axis some strides are 0, so an out-of-range
  ///        index would otherwise divide by zero.
  /// @return vector of n_dims subscripts, outermost axis first.
  std::vector<std::size_t> ind2sub(std::size_t index) const {
    assert(index < n_elem);
    std::vector<std::size_t> subscripts(n_dims);
    for (std::size_t i = 0; i < n_dims; ++i) {
      subscripts[i] = index / strides[i];
      index %= strides[i];
    }
    return subscripts;
  }
};
int main(int argc, char const *argv[]) { | |
auto s = Stride({2, 3, 4}); | |
assert(s.n_dims == 3); | |
assert(s.n_elem == 2 * 3 * 4); | |
assert(s.sub2ind({1, 2, 3}) == s.n_elem - 1); | |
assert(s.ind2sub(s.n_elem - 1) == vector<size_t>({1, 2, 3})); | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.