Created
June 30, 2023 03:54
-
-
Save onihusube/c57b4f8d3d60274578930300cf3d9a38 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <concepts> | |
#include <mdspan> | |
#include <ranges> | |
#include <cassert> | |
using namespace std::experimental; | |
template<std::unsigned_integral auto D> | |
struct layout_right_interleaved { | |
template <class Extents> | |
class mapping { | |
[[no_unique_address]] | |
Extents m_extent; | |
public: | |
using extents_type = Extents; | |
using index_type = typename extents_type::index_type; | |
using rank_type = typename extents_type::rank_type; | |
using layout_type = layout_right_interleaved<D>; | |
private: | |
static constexpr auto calc_static_stride() -> std::array<index_type, Extents::rank()> { | |
// 動的エクステントの場合は空 | |
if constexpr (Extents::rank_dynamic() != 0) { | |
return {}; | |
} | |
// 全て静的なら、各次元のストライドを求める | |
std::array<index_type, Extents::rank()> stride; | |
stride[0] = D; | |
for (auto i = 1u; i < Extents::rank(); ++i) { | |
stride[i] = stride[i - 1] * Extents::static_extent(i - 1); | |
} | |
return stride; | |
} | |
static constexpr std::array<index_type, Extents::rank()> static_stride = calc_static_stride(); | |
public: | |
constexpr mapping(const extents_type& ex) | |
: m_extent{ex} | |
{} | |
mapping(const mapping &) = default; | |
mapping &operator=(const mapping &) & = default; | |
friend bool operator==(mapping, mapping) = default; | |
constexpr auto extents() const noexcept -> const extents_type& { | |
return m_extent; | |
} | |
constexpr auto required_span_size() const -> index_type { | |
// 丸投げ | |
return layout_stride::mapping<Extents>(m_extent).required_span_size(); | |
} | |
constexpr auto stride(rank_type r) const -> index_type | |
requires (extents_type::rank() != 0) | |
{ | |
assert(r < extents_type::rank()); | |
if constexpr (Extents::rank_dynamic() != 0) { | |
index_type stride = D; | |
for (auto i : std::views::iota(0u, r)) { | |
stride *= m_extent.extent(i); | |
} | |
return stride; | |
} else { | |
return static_stride[r]; | |
} | |
} | |
template<typename... Indices> | |
requires (sizeof...(Indices) == extents_type::rank()) and | |
(std::is_nothrow_convertible_v<Indices, index_type> && ...) | |
constexpr auto operator()(Indices... indices) const -> index_type { | |
// 行列次元数 | |
// extent()が0indexなので最大値は-1 | |
constexpr std::unsigned_integral auto N = extents_type::rank() - 1; | |
static_assert(0u < N); | |
// インデックス配列 | |
// indicesは先頭が最大次元、末尾が1次元 | |
const std::array<index_type, extents_type::rank()> idx_array = {static_cast<index_type>(indices)...}; | |
if constexpr (Extents::rank_dynamic() != 0) { | |
// 動的エクステントを含む場合の計算 | |
index_type idx = idx_array[0]; | |
for (auto m = N - 1; const auto in : idx_array | std::views::drop(1)) { | |
idx *= m_extent.extent(m); | |
idx += in; | |
--m; | |
} | |
return D * idx; | |
} else { | |
// 静的エクステントを含む場合の計算 | |
index_type idx = 0; | |
for (index_type i = 0; const auto st : static_stride | std::views::reverse) { | |
idx += st * idx_array[i]; | |
++i; | |
} | |
return idx; | |
} | |
} | |
static constexpr bool is_unique() noexcept { | |
return true; | |
} | |
static constexpr bool is_exhaustive() noexcept { | |
return D == 1; | |
} | |
static constexpr bool is_strided() noexcept { | |
return true; | |
} | |
static constexpr bool is_always_unique() noexcept { | |
return true; | |
} | |
static constexpr bool is_always_exhaustive() noexcept { | |
return D == 1; | |
} | |
static constexpr bool is_always_strided() noexcept { | |
return true; | |
} | |
}; | |
}; | |
/// Smoke test: prints a few offsets, the strides and the object size of a
/// 3x3 layout_right_interleaved mapping with interleave step 3.
void test() {
  // Named so it does not shadow the enclosing function `test`.
  using mapping_t = layout_right_interleaved<3u>::mapping<extents<std::size_t, 3, 3>>;
  const mapping_t m{extents<std::size_t, 3, 3>{}};
  std::cout << "(0, 0) -> " << m(0, 0) << '\n';
  std::cout << "(0, 1) -> " << m(0, 1) << '\n';
  std::cout << "(1, 0) -> " << m(1, 0) << '\n';
  std::cout << "(1, 1) -> " << m(1, 1) << '\n';
  std::cout << "(2, 2) -> " << m(2, 2) << '\n';
  std::cout << "stride(0) = " << m.stride(0) << '\n';
  std::cout << "stride(1) = " << m.stride(1) << '\n';
  // Static extents are stored via [[no_unique_address]], so this shows
  // whether the mapping object is empty on this compiler.
  std::cout << "size : " << sizeof(m) << '\n';
}
/// mdspan over element type T with extents E, laid out row-major with
/// interleave step D (see layout_right_interleaved).
template <typename T, std::unsigned_integral auto D, typename E>
using interleaved_mat = mdspan<T, E, layout_right_interleaved<D>>;

/// Fully static 3x3 specialization of interleaved_mat.
template <typename T, std::unsigned_integral auto D>
using interleaved_mat33 = interleaved_mat<T, D, extents<std::size_t, 3, 3>>;

/// 3x3 view using the standard layout_stride; the strides are supplied
/// through the mapping at construction time.
template <typename T>
using stride_interleaved_mat33 = mdspan<T, extents<std::size_t, 3, 3>, layout_stride>;
/// Prints a rank-2 mdspan row by row: elements separated (and followed) by a
/// space, one row per line, then a blank line. Works for any extents.
template <typename T, typename E, typename L>
void print_mat(mdspan<T, E, L> mat) {
  // Rank is a compile-time property; reject misuse at compile time.
  static_assert(E::rank() == 2, "print_mat expects a rank-2 mdspan");
  using index_t = typename E::index_type;
  for (index_t y = 0; y < mat.extent(0); ++y) {
    for (index_t x = 0; x < mat.extent(1); ++x) {
      std::cout << mat[y, x] << ' ';
    }
    std::cout << '\n';
  }
  std::cout << '\n';
}
int main() {
  // Three 3x3 int matrices interleaved element-by-element in one buffer.
  // Each value reads as <matrix><row><col> (1-based): 111, 211, 311 are
  // element (1,1) of matrices 1, 2 and 3, followed by element (1,2) of each,
  // and so on.
  int storage[] = {
    111, 211, 311, 112, 212, 312, 113, 213, 313,
    121, 221, 321, 122, 222, 322, 123, 223, 323,
    131, 231, 331, 132, 232, 332, 133, 233, 333
  };

  // Equivalent views built with the standard layout_stride (kept for reference):
  /*using mapping = stride_interleaved_mat33<int>::mapping_type;
  // strides of each dimension for this arrangement
  constexpr std::array<std::size_t, 2> stride = {9, 3};
  stride_interleaved_mat33 A{storage, mapping{{}, stride}};
  stride_interleaved_mat33 B{storage + 1, mapping{{}, stride}};
  stride_interleaved_mat33 C{storage + 2, mapping{{}, stride}};*/
  // Equivalent views using the fully static interleaved_mat33 alias:
  //interleaved_mat33<int, 3u> A{storage};
  //interleaved_mat33<int, 3u> B{storage + 1};
  //interleaved_mat33<int, 3u> C{storage + 2};

  using static_ext = extents<std::size_t, 3, 3>;
  using mixed_ext = extents<std::size_t, 3, dynamic_extent>;
  using dynamic_ext = dextents<std::size_t, 2>;

  // Matrix k (0-based) begins at storage + k; with interleave step 3 each view
  // skips over the elements of the other two matrices.
  interleaved_mat<int, 3u, static_ext> A{storage};
  interleaved_mat<int, 3u, mixed_ext> B{storage + 1, mixed_ext{3}};
  interleaved_mat<int, 3u, dynamic_ext> C{storage + 2, dynamic_ext{3, 3}};

  print_mat(A);
  print_mat(B);
  print_mat(C);

  std::cout << "----\n";
  test();
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment