Skip to content

Instantly share code, notes, and snippets.

@PatWie
Last active June 3, 2023 17:06
Show Gist options
  • Save PatWie/07b2962f75446250c138f686a6e3da0d to your computer and use it in GitHub Desktop.
MultiIndex
g++ main.cc -std=c++17 && ./a.out
#ifndef INDEX_H_
#define INDEX_H_
#include <cstddef>
#include <tuple>
namespace internal {

/**
 * Product of a contiguous run of extents in `dimensions_`.
 *
 * - TSkip == 0: multiplies dimensions_[TPos] .. dimensions_[TPos + TRemaining - 1].
 * - TSkip  > 0: the run starts at TPos + TSkip and extends to the last axis
 *               (TRemaining is not consulted in this case).
 * An empty run yields 1.
 */
template <size_t TRank, size_t TSkip, size_t TPos, size_t TRemaining>
struct pitch_helper {
  constexpr size_t call(const size_t dimensions_[TRank]) const {
    constexpr size_t first = (TSkip == 0) ? TPos : TPos + TSkip;
    constexpr size_t count = (TSkip == 0) ? TRemaining : TRank - first;
    size_t product = 1;
    for (size_t offset = 0; offset < count; ++offset) {
      product *= dimensions_[first + offset];
    }
    return product;
  }
};

/**
 * Flattened offset of the coordinate (v, is...).
 *
 * `v` is the coordinate on axis TRank - TRemaining. Its stride is the product
 * of all extents after that axis. The last coordinate is added unscaled.
 */
template <size_t TRank, size_t TRemaining, class T, class... Ts>
struct position_helper {
  constexpr size_t call(const size_t dimensions_[TRank], T v, Ts... is) const {
    if constexpr (sizeof...(Ts) == 0) {
      return v;
    } else {
      const size_t stride =
          pitch_helper<TRank, TRank - TRemaining + 1, 0, TRank>().call(
              dimensions_);
      const size_t tail =
          position_helper<TRank, TRemaining - 1, Ts...>().call(dimensions_,
                                                               is...);
      return v * stride + tail;
    }
  }
};

/**
 * Split a flattened offset into per-axis coordinates, starting at axis TPos.
 *
 * Coordinates already computed for earlier axes are carried in `indices`. The
 * final step packs everything into a std::tuple, outermost axis first.
 */
template <size_t TRank, size_t TPos, size_t TRemaining>
struct unflatten_helper {
  template <class... Ts>
  static constexpr auto call(const size_t dimensions_[TRank],
                             size_t flattenedIndex, Ts &...indices) noexcept {
    if constexpr (TRemaining == 1) {
      // Last axis: whatever is left over is its coordinate.
      return std::make_tuple(indices..., flattenedIndex);
    } else {
      const size_t stride =
          pitch_helper<TRank, 1, TPos, TRank - 1>().call(dimensions_);
      const size_t coordinate = flattenedIndex / stride;
      return unflatten_helper<TRank, TPos + 1, TRemaining - 1>::call(
          dimensions_, flattenedIndex % stride, indices..., coordinate);
    }
  }
};

}  // namespace internal
template <size_t TRank> struct BaseNdIndex {
protected:
  // Extent of every axis; axis 0 is the one with the largest stride.
  size_t dimensions_[TRank];

public:
  /**
   * Store the extent of every axis, axis 0 first.
   *
   * Trailing extents may be any integral type. They are converted to size_t
   * explicitly because brace-initialising the size_t array directly from,
   * e.g., plain `int` arguments is a narrowing error.
   */
  template <class... Ts>
  explicit constexpr inline BaseNdIndex(size_t i0, Ts... is) noexcept
      : dimensions_{i0, static_cast<size_t>(is)...} {}

  /**
   * Check whether given coordinate is in range.
   *
   * @return true iff every coordinate is strictly smaller than the extent of
   *         its axis. Coordinates are compared as size_t, so negative signed
   *         values count as out of range.
   */
  template <class... Ts>
  constexpr inline bool valid(size_t i0, Ts... is) const {
    static_assert(size_t(1) + sizeof...(Ts) == TRank,
                  "Number of dimensions does not match rank! "
                  "YOU_MADE_A_PROGAMMING_MISTAKE");
    // A plain loop over a materialised coordinate array works for any rank.
    // The previous recursive helper re-passed an over-long explicit template
    // pack and failed to compile for rank >= 3.
    const size_t coords[TRank] = {i0, static_cast<size_t>(is)...};
    for (size_t axis = 0; axis < TRank; ++axis) {
      if (!(coords[axis] < dimensions_[axis])) {
        return false;
      }
    }
    return true;
  }

  /**
   * Return the number of axes.
   * @return number of axes
   */
  constexpr inline size_t rank() const { return TRank; }

  /**
   * Return the dimension for a given axis.
   *
   *   const size_t D = my_nd_array.template dim<1>();
   *
   * @return dimension for given axis
   */
  template <size_t TAxis> constexpr inline size_t dim() const {
    static_assert(TAxis < TRank, "axis < rank failed");
    return dimensions_[TAxis];
  }

  /**
   * Unflatten a flattened index into one coordinate per axis.
   *
   *   auto [i, j, k] = idx.unflatten(flattenedIndex);  // C++17
   *
   * @param flattenedIndex the flattened index to unflatten
   * @return std::tuple of TRank size_t coordinates, axis 0 first
   */
  [[nodiscard]] constexpr inline auto
  unflatten(size_t flattenedIndex) const noexcept {
    return internal::unflatten_helper<TRank, 0, TRank>::call(dimensions_,
                                                             flattenedIndex);
  }

  /**
   * Out-parameter form of unflatten(), writing the coordinates into
   * caller-provided variables.
   *
   *   size_t i, j, k;
   *   idx.unflatten(flattenedIndex, i, j, k);
   *
   * @param flattenedIndex the flattened index to unflatten
   * @param indices exactly TRank lvalues that receive the coordinates,
   *        axis 0 first
   */
  template <class... Ts>
  inline void unflatten(size_t flattenedIndex, Ts &...indices) const noexcept {
    static_assert(sizeof...(Ts) == TRank,
                  "Number of indices does not match rank! "
                  "YOU_MADE_A_PROGAMMING_MISTAKE");
    std::tie(indices...) = unflatten(flattenedIndex);
  }

protected:
  /**
   * Flattened offset of the coordinate (i0, is...).
   */
  template <class... Ts>
  constexpr inline size_t index_(size_t i0, Ts... is) const {
    return internal::position_helper<TRank, TRank, size_t, Ts...>().call(
        dimensions_, i0, is...);
  }
};
/**
 * Create an index object.
 *
 * Maps a coordinate (one value per axis) to its flattened offset. As the
 * example on operator() shows, the last axis varies fastest.
 *
 *   auto idx = NdIndex<4>(B, H, W, C);
 *   size_t offset = idx(b, h, w, c);
 *
 * @tparam TRank number of axes; the constructor takes exactly TRank extents.
 */
template <size_t TRank> struct NdIndex : public BaseNdIndex<TRank> {
public:
  /**
   * @param i0, is extent of every axis, axis 0 first (exactly TRank values).
   */
  template <class... Ts>
  explicit constexpr inline NdIndex(size_t i0, Ts... is) noexcept
      : BaseNdIndex<TRank>(i0, is...) {
    static_assert(size_t(1) + sizeof...(Ts) == TRank,
                  "Number of dimensions does not match rank! "
                  "YOU_MADE_A_PROGAMMING_MISTAKE");
  }

  /**
   * Get flattened index for a given position.
   *
   *   auto idx = NdIndex<4>(10, 20, 30, 40);
   *   size_t actual = idx(1, 2, 3, 4);
   *   size_t expected = 1 * (20 * 30 * 40) + 2 * (30 * 40) + 3 * (40) + 4;
   *
   * constexpr so offsets can be computed at compile time, matching the
   * constexpr index_() it forwards to. No range check is performed; use
   * valid() for that.
   */
  template <class... Ts>
  [[nodiscard]] constexpr inline size_t operator()(size_t i0, Ts... is) const {
    static_assert(size_t(1) + sizeof...(Ts) == TRank,
                  "Number of dimensions does not match rank! "
                  "YOU_MADE_A_PROGAMMING_MISTAKE");
    return this->index_(i0, is...);
  }

  /**
   * Get dimension for a given axis.
   *
   *   auto idx = NdIndex<4>(10, 20, 30, 40);
   *   size_t actual = idx[1];  // is 20
   *
   * No bounds check: `i0` must be smaller than TRank.
   */
  [[nodiscard]] constexpr inline size_t operator[](size_t i0) const {
    return BaseNdIndex<TRank>::dimensions_[i0];
  }
};
//////////////
#endif // INDEX_H_
#ifndef INDEX_H_
#define INDEX_H_
#include <cstddef>
namespace internal {

/**
 * Product of a contiguous run of extents in `dimensions_`.
 *
 * - TSkip == 0: multiplies dimensions_[TPos] .. dimensions_[TPos + TRemaining - 1].
 * - TSkip  > 0: the run starts at TPos + TSkip and extends to the last axis
 *               (TRemaining is not consulted in this case).
 * An empty run yields 1.
 */
template <size_t TRank, size_t TSkip, size_t TPos, size_t TRemaining>
struct pitch_helper {
  constexpr size_t call(const size_t dimensions_[TRank]) const {
    constexpr size_t first = (TSkip == 0) ? TPos : TPos + TSkip;
    constexpr size_t count = (TSkip == 0) ? TRemaining : TRank - first;
    size_t product = 1;
    for (size_t offset = 0; offset < count; ++offset) {
      product *= dimensions_[first + offset];
    }
    return product;
  }
};

/**
 * Flattened offset of the coordinate (v, is...).
 *
 * `v` is the coordinate on axis TRank - TRemaining. Its stride is the product
 * of all extents after that axis. The last coordinate is added unscaled.
 */
template <size_t TRank, size_t TRemaining, class T, class... Ts>
struct position_helper {
  constexpr size_t call(const size_t dimensions_[TRank], T v, Ts... is) const {
    if constexpr (sizeof...(Ts) == 0) {
      return v;
    } else {
      const size_t stride =
          pitch_helper<TRank, TRank - TRemaining + 1, 0, TRank>().call(
              dimensions_);
      const size_t tail =
          position_helper<TRank, TRemaining - 1, Ts...>().call(dimensions_,
                                                               is...);
      return v * stride + tail;
    }
  }
};

/**
 * Split a flattened offset into per-axis coordinates, starting at axis TPos.
 *
 * `index` receives the coordinate of axis TPos. `indices` receive the
 * coordinates of the remaining axes, in order.
 */
template <size_t TRank, size_t TPos, size_t TRemaining>
struct unflatten_helper {
  template <class... Ts>
  static constexpr void call(const size_t dimensions_[TRank],
                             size_t flattenedIndex, size_t &index,
                             Ts &...indices) noexcept {
    if constexpr (TRemaining == 1) {
      // Last axis: whatever is left over is its coordinate.
      index = flattenedIndex;
    } else {
      const size_t stride =
          pitch_helper<TRank, 1, TPos, TRank - 1>().call(dimensions_);
      index = flattenedIndex / stride;
      unflatten_helper<TRank, TPos + 1, TRemaining - 1>::call(
          dimensions_, flattenedIndex % stride, indices...);
    }
  }
};

}  // namespace internal
template <size_t TRank> struct BaseNdIndex {
protected:
  // Extent of every axis; axis 0 is the one with the largest stride.
  size_t dimensions_[TRank];

public:
  /**
   * Store the extent of every axis, axis 0 first.
   *
   * Trailing extents may be any integral type. They are converted to size_t
   * explicitly because brace-initialising the size_t array directly from,
   * e.g., plain `int` arguments is a narrowing error.
   */
  template <class... Ts>
  explicit constexpr inline BaseNdIndex(size_t i0, Ts... is) noexcept
      : dimensions_{i0, static_cast<size_t>(is)...} {}

  /**
   * Check whether given coordinate is in range.
   *
   * @return true iff every coordinate is strictly smaller than the extent of
   *         its axis. Coordinates are compared as size_t, so negative signed
   *         values count as out of range.
   */
  template <class... Ts>
  constexpr inline bool valid(size_t i0, Ts... is) const {
    static_assert(size_t(1) + sizeof...(Ts) == TRank,
                  "Number of dimensions does not match rank! "
                  "YOU_MADE_A_PROGAMMING_MISTAKE");
    // A plain loop over a materialised coordinate array works for any rank.
    // The previous recursive helper re-passed an over-long explicit template
    // pack and failed to compile for rank >= 3.
    const size_t coords[TRank] = {i0, static_cast<size_t>(is)...};
    for (size_t axis = 0; axis < TRank; ++axis) {
      if (!(coords[axis] < dimensions_[axis])) {
        return false;
      }
    }
    return true;
  }

  /**
   * Return the number of axes.
   * @return number of axes
   */
  constexpr inline size_t rank() const { return TRank; }

  /**
   * Return the dimension for a given axis.
   *
   *   const size_t D = my_nd_array.template dim<1>();
   *
   * @return dimension for given axis
   */
  template <size_t TAxis> constexpr inline size_t dim() const {
    static_assert(TAxis < TRank, "axis < rank failed");
    return dimensions_[TAxis];
  }

  /**
   * Unflatten a flattened index and retrieve the corresponding
   * indices for each dimension.
   *
   *   size_t i, j, k;
   *   idx.unflatten(flattenedIndex, i, j, k);
   *
   * @param flattenedIndex the flattened index to unflatten
   * @param indices exactly TRank size_t lvalues that receive the
   *        coordinates, axis 0 first
   */
  template <class... Ts>
  constexpr inline void unflatten(size_t flattenedIndex,
                                  Ts &...indices) const noexcept {
    static_assert(sizeof...(Ts) == TRank,
                  "Number of indices does not match rank! "
                  "YOU_MADE_A_PROGAMMING_MISTAKE");
    internal::unflatten_helper<TRank, 0, TRank>::call(
        dimensions_, flattenedIndex, indices...);
  }

protected:
  /**
   * Flattened offset of the coordinate (i0, is...).
   */
  template <class... Ts>
  constexpr inline size_t index_(size_t i0, Ts... is) const {
    return internal::position_helper<TRank, TRank, size_t, Ts...>().call(
        dimensions_, i0, is...);
  }
};
/**
 * Create an index object.
 *
 * Maps a coordinate (one value per axis) to its flattened offset. As the
 * example on operator() shows, the last axis varies fastest.
 *
 *   auto idx = NdIndex<4>(B, H, W, C);
 *   size_t offset = idx(b, h, w, c);
 *
 * @tparam TRank number of axes; the constructor takes exactly TRank extents.
 */
template <size_t TRank> struct NdIndex : public BaseNdIndex<TRank> {
public:
  /**
   * @param i0, is extent of every axis, axis 0 first (exactly TRank values).
   */
  template <class... Ts>
  explicit constexpr inline NdIndex(size_t i0, Ts... is) noexcept
      : BaseNdIndex<TRank>(i0, is...) {
    static_assert(size_t(1) + sizeof...(Ts) == TRank,
                  "Number of dimensions does not match rank! "
                  "YOU_MADE_A_PROGAMMING_MISTAKE");
  }

  /**
   * Get flattened index for a given position.
   *
   *   auto idx = NdIndex<4>(10, 20, 30, 40);
   *   size_t actual = idx(1, 2, 3, 4);
   *   size_t expected = 1 * (20 * 30 * 40) + 2 * (30 * 40) + 3 * (40) + 4;
   *
   * constexpr so offsets can be computed at compile time, matching the
   * constexpr index_() it forwards to. No range check is performed; use
   * valid() for that.
   */
  template <class... Ts>
  [[nodiscard]] constexpr inline size_t operator()(size_t i0, Ts... is) const {
    static_assert(size_t(1) + sizeof...(Ts) == TRank,
                  "Number of dimensions does not match rank! "
                  "YOU_MADE_A_PROGAMMING_MISTAKE");
    return this->index_(i0, is...);
  }

  /**
   * Get dimension for a given axis.
   *
   *   auto idx = NdIndex<4>(10, 20, 30, 40);
   *   size_t actual = idx[1];  // is 20
   *
   * No bounds check: `i0` must be smaller than TRank.
   */
  [[nodiscard]] constexpr inline size_t operator[](size_t i0) const {
    return BaseNdIndex<TRank>::dimensions_[i0];
  }
};
//////////////
#endif // INDEX_H_
#include "index.h"
#include <cassert>
#include <iostream>
int main() {
  // Extent of each of the four axes.
  const size_t dim_b = 100;
  const size_t dim_h = 200;
  const size_t dim_w = 300;
  const size_t dim_c = 400;
  // An in-range coordinate that is flattened and then recovered below.
  const size_t b = 10;
  const size_t h = 101;
  const size_t w = 13;
  const size_t c = 87;

  const auto index = NdIndex<4>(dim_b, dim_h, dim_w, dim_c);
  const size_t flattened_index = index(b, h, w, c);

  // unflatten() returns a std::tuple; C++17 structured bindings unpack it.
  const auto [b_, h_, w_, c_] = index.unflatten(flattened_index);
  assert(b == b_);
  assert(h == h_);
  assert(w == w_);
  assert(c == c_);

  // assert() is compiled out under NDEBUG; keep a real exit status so the
  // round-trip check still means something in release builds.
  const bool round_trip_ok = (b == b_) && (h == h_) && (w == w_) && (c == c_);
  if (!round_trip_ok) {
    std::cerr << "unflatten(" << flattened_index << ") did not recover ("
              << b << ", " << h << ", " << w << ", " << c << ")\n";
    return 1;
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment