Skip to content

Instantly share code, notes, and snippets.

@onihusube
Created June 30, 2023 03:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save onihusube/c57b4f8d3d60274578930300cf3d9a38 to your computer and use it in GitHub Desktop.
Save onihusube/c57b4f8d3d60274578930300cf3d9a38 to your computer and use it in GitHub Desktop.
#include <iostream>
#include <concepts>
#include <mdspan>
#include <ranges>
#include <cassert>
using namespace std::experimental;
/// Layout policy for mdspan: row-major (layout_right) element order in which
/// neighbouring elements along the rightmost dimension are D storage slots
/// apart instead of 1.
///
/// Viewing one buffer at base offsets 0, 1, ..., D-1 therefore yields D
/// independent arrays whose elements are interleaved in memory.
///
/// For extents e_0, ..., e_{N-1} the offset of (i_0, ..., i_{N-1}) is
///   D * (((i_0 * e_1 + i_1) * e_2 + i_2) ... * e_{N-1} + i_{N-1})
/// so stride(r) == D * e_{r+1} * ... * e_{N-1}, exactly std::layout_right's
/// strides scaled by D.
///
/// @tparam D interleave factor; must be non-zero (D == 0 would map every
///           index to offset 0).
template<std::unsigned_integral auto D>
struct layout_right_interleaved {
  static_assert(D != 0, "layout_right_interleaved: interleave factor D must be non-zero");

  template <class Extents>
  class mapping {
    [[no_unique_address]]
    Extents m_extent;

  public:
    using extents_type = Extents;
    using index_type = typename extents_type::index_type;
    using rank_type = typename extents_type::rank_type;
    using layout_type = layout_right_interleaved<D>;

  private:
    // D converted once to the mapping's index type.
    static constexpr index_type interleave = static_cast<index_type>(D);

  public:
    constexpr mapping(const extents_type& ex)
      : m_extent{ex}
    {}
    mapping(const mapping &) = default;
    mapping &operator=(const mapping &) & = default;

    friend bool operator==(mapping, mapping) = default;

    constexpr auto extents() const noexcept -> const extents_type& {
      return m_extent;
    }

    /// Number of elements the underlying buffer must provide: one past the
    /// largest offset this mapping can produce, or 0 if any extent is 0.
    /// (For rank 0 this is 1, since the single element sits at offset 0.)
    constexpr auto required_span_size() const -> index_type {
      // Largest row-major linear index (all indices at extent-1), built
      // Horner-style; equals product(extents) - 1 when no extent is 0.
      index_type last = 0;
      for (rank_type r = 0; r < extents_type::rank(); ++r) {
        const index_type e = m_extent.extent(r);
        if (e == 0) {
          return 0;
        }
        last = last * e + (e - 1);
      }
      return interleave * last + 1;
    }

    /// Distance in storage between neighbouring elements along dimension r:
    /// D times the product of all extents to the right of r, so that
    /// stride(r) == (*this)(e_r) - (*this)(0, ..., 0) for the unit index e_r.
    constexpr auto stride(rank_type r) const -> index_type
      requires (extents_type::rank() != 0)
    {
      assert(r < extents_type::rank());
      index_type s = interleave;
      for (rank_type k = r + 1; k < extents_type::rank(); ++k) {
        s *= m_extent.extent(k);
      }
      return s;
    }

    /// Maps a multidimensional index (outermost dimension first, as for
    /// std::layout_right) to a storage offset. Works for any rank, including
    /// 0 and 1.
    template<typename... Indices>
      requires (sizeof...(Indices) == extents_type::rank()) and
               (std::is_nothrow_convertible_v<Indices, index_type> && ...)
    constexpr auto operator()(Indices... indices) const -> index_type {
      const std::array<index_type, extents_type::rank()> idx = {static_cast<index_type>(indices)...};
      // Row-major linear index via Horner's scheme, then spread by D.
      index_type linear = 0;
      for (rank_type r = 0; r < extents_type::rank(); ++r) {
        linear = linear * m_extent.extent(r) + idx[r];
      }
      return interleave * linear;
    }

    static constexpr bool is_unique() noexcept {
      return true;
    }
    // With D > 1 the gaps between elements belong to the other interleaved
    // arrays, so only D == 1 covers the whole span.
    static constexpr bool is_exhaustive() noexcept {
      return D == 1;
    }
    static constexpr bool is_strided() noexcept {
      return true;
    }
    static constexpr bool is_always_unique() noexcept {
      return true;
    }
    static constexpr bool is_always_exhaustive() noexcept {
      return D == 1;
    }
    static constexpr bool is_always_strided() noexcept {
      return true;
    }
  };
};
void test() {
using test = layout_right_interleaved<3u>::mapping<extents<std::size_t, 3, 3>>;
test m{extents<std::size_t, 3, 3>{}};
std::cout << m.stride(0) << '\n';
std::cout << m.stride(1) << '\n';
std::cout << "(0, 0) -> " << m(0, 0) << '\n';
std::cout << "(0, 1) -> " << m(0, 1) << '\n';
std::cout << "(1, 0) -> " << m(1, 0) << '\n';
std::cout << "(1, 1) -> " << m(1, 1) << '\n';
std::cout << "(2, 2) -> " << m(2, 2) << '\n';
std::cout << "stride(0) = " << m.stride(0) << '\n';
std::cout << "stride(1) = " << m.stride(1) << '\n';
std::cout << "size : " << sizeof(m) << '\n';
}
template <typename T>
using stride_interleaved_mat33 = mdspan<T, extents<std::size_t, 3, 3>, layout_stride>;
template <typename T, std::unsigned_integral auto D>
using interleaved_mat33 = mdspan<T, extents<std::size_t, 3, 3>, layout_right_interleaved<D>>;
template <typename T, std::unsigned_integral auto D, typename E>
using interleaved_mat = mdspan<T, E, layout_right_interleaved<D>>;
template <typename T, typename E, typename L>
void print_mat(mdspan<T, E, L> mat33) {
assert(E::rank() == 2);
assert(mat33.extent(0) == 3);
assert(mat33.extent(1) == 3);
for (int y = 0; y < 3; ++y) {
for (int x = 0; x < 3; ++x) {
std::cout << mat33[y, x] << ' ';
}
std::cout << '\n';
}
std::cout << '\n';
}
int main() {
int storage[] = {
111, 211, 311, 112, 212, 312, 113, 213, 313,
121, 221, 321, 122, 222, 322, 123, 223, 323,
131, 231, 331, 132, 232, 332, 133, 233, 333
};
/*using mapping = stride_interleaved_mat33<int>::mapping_type;
// この場合の各次元のストライド
constexpr std::array<std::size_t, 2> stride = {9, 3};
stride_interleaved_mat33 A{storage, mapping{{}, stride}};
stride_interleaved_mat33 B{storage + 1, mapping{{}, stride}};
stride_interleaved_mat33 C{storage + 2, mapping{{}, stride}};*/
//interleaved_mat33<int, 3u> A{storage};
//interleaved_mat33<int, 3u> B{storage + 1};
//interleaved_mat33<int, 3u> C{storage + 2};
interleaved_mat<int, 3u, extents<std::size_t, 3, 3>> A{storage};
interleaved_mat<int, 3u, extents<std::size_t, 3, dynamic_extent>> B{storage + 1, extents<std::size_t, 3, dynamic_extent>{3}};
interleaved_mat<int, 3u, dextents<std::size_t, 2>> C{storage + 2, dextents<std::size_t, 2>{3, 3}};
print_mat(A);
print_mat(B);
print_mat(C);
std::cout << "----\n";
test();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment