Created
November 23, 2017 06:43
-
-
Save DrPizza/39f90bd2324f0100f487eb6e27870d29 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Based on WG21 paper P0009r4 (the mdspan proposal): http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0009r4.html
#include <utility> | |
#include <array> | |
#include <algorithm> | |
#include <functional> | |
#include <type_traits> | |
#include <memory> | |
#include <exception> | |
#include <stdexcept> | |
#include <cstddef> | |
#include <numeric> | |
#include <iostream> | |
#include <gsl/gsl> | |
namespace md { | |
template<typename T, ptrdiff_t E = gsl::dynamic_extent> | |
using span = gsl::span<T, E>; | |
constexpr ptrdiff_t dynamic_extent = gsl::dynamic_extent; | |
namespace { | |
template<ptrdiff_t... Ns> | |
constexpr size_t count_dynamics() { | |
size_t count = 0; | |
using swallow = size_t[]; | |
static_cast<void>(swallow{ (count += (Ns == dynamic_extent))... }); | |
return count; | |
} | |
// Returns the zero-based position of the first type in Types for which
// Pred<Type>::value is true, or ~size_t(0) when no type matches.
//
// The search is a short-circuiting fold over ||: every non-matching type
// bumps `index`, and the first match stops the evaluation. Unlike a local
// `bool matches[]` table this is also well-formed for an empty Types pack
// (the fold is then simply `false`).
template<template<typename...> typename Pred, typename... Types>
constexpr std::size_t find_first_type_of() {
    std::size_t index = 0;
    const bool found = ((Pred<Types>::value ? true : (static_cast<void>(++index), false)) || ...);
    return found ? index : ~static_cast<std::size_t>(0);
}

// Variable-template shorthand for find_first_type_of<Pred, Types...>().
template<template<typename...> typename Pred, typename... Types>
constexpr std::size_t find_first_type_of_v = find_first_type_of<Pred, Types...>();
// select_nth<Index, Pack...>::type names the Index-th type of Pack.
// The primary template is only declared, never defined, so running off the
// end of the pack fails to compile.
template<std::size_t Index, class... Pack>
struct select_nth;

// Base case: index 0 selects the head of the pack.
template<class Head, class... Tail>
struct select_nth<0, Head, Tail...> {
    using type = Head;
};

// Recursive case: drop the head and look for Index - 1 in the tail.
template<std::size_t Index, class Head, class... Tail>
struct select_nth<Index, Head, Tail...> : select_nth<Index - 1, Tail...> {
    static_assert(Index <= sizeof...(Tail), "bad index");
};

// Shorthand for typename select_nth<Index, Pack...>::type.
template<size_t Index, class... Pack>
using select_nth_t = typename select_nth<Index, Pack...>::type;
template<template<typename...> typename Predicate, typename... Ts> | |
struct first_matching_type { | |
using type = select_nth_t<find_first_type_of_v<Predicate, Ts...>, Ts...>; | |
}; | |
template<template<typename...> typename Predicate, typename... Ts> | |
using first_matching_type_t = typename first_matching_type<Predicate, Ts...>::type; | |
} | |
// constexpr-friendly generalised inner product.
//
// Walks [first1, end1) in lock-step with the range starting at first2 and
// folds val = f1(val, f2(*first1, *first2)) over every pair. Returns the
// final accumulator; for an empty range that is val unchanged.
template<typename It1, typename It2, typename T, typename Fn1, typename Fn2>
constexpr T inner_product(It1 first1, It1 end1, It2 first2, T val, Fn1 f1, Fn2 f2) {
    while(first1 != end1) {
        val = f1(val, f2(*first1, *first2));
        ++first1;
        ++first2;
    }
    return val;
}
// constexpr-friendly left fold.
//
// Applies val = f(val, *it) for every element of [first, end) in order and
// returns the final accumulator; an empty range returns val unchanged.
template<typename It, typename T, typename Fn>
constexpr T accumulate(It first, It end, T val, Fn f) {
    while(first != end) {
        val = f(val, *first);
        ++first;
    }
    return val;
}
template<ptrdiff_t... Extents> | |
struct extents { | |
using index_type = ptrdiff_t; | |
using array_type = index_type[]; | |
using my_type = extents<Extents...>; | |
constexpr extents() noexcept : values{ Extents... } { | |
} | |
constexpr extents(const extents&) = default; | |
constexpr extents(extents&&) = default; | |
~extents() = default; | |
extents& operator=(const extents&) noexcept = default; | |
extents& operator=(extents&&) noexcept = default; | |
static constexpr size_t rank() noexcept { | |
return sizeof...(Extents); | |
} | |
static constexpr size_t rank_dynamic() noexcept { | |
return count_dynamics<Extents...>(); | |
} | |
static constexpr index_type static_extent(size_t e) noexcept { | |
if(e >= rank()) { | |
return 1; | |
} else { | |
return array_type{ Extents... }[e]; | |
} | |
} | |
constexpr index_type extent(size_t e) const noexcept { | |
if(e >= rank()) { | |
return 1; | |
} else { | |
return values[e]; | |
} | |
} | |
constexpr index_type size() const noexcept { | |
return md::accumulate(values.cbegin(), values.cend(), static_cast<index_type>(1), std::multiplies<>()); | |
} | |
template<typename IndexType, size_t N> | |
constexpr extents(const std::array<IndexType, N>& dynamic_extents) noexcept : values{ Extents... } { | |
static_assert(N == rank_dynamic(), "wrong number of dynamic extents provided"); | |
for(size_t i = 0, j = 0; i < rank(); ++i) { | |
if(values[i] == dynamic_extent) { | |
values[i] = dynamic_extents[j++]; | |
} | |
} | |
} | |
template<typename... IndexType> | |
constexpr extents(IndexType... DynamicExtents) noexcept : extents(std::array<index_type, sizeof...(IndexType)> { DynamicExtents... }) { | |
} | |
private: | |
std::array<index_type, my_type::rank()> values; | |
}; | |
// Everything a type must expose to be treated as an extents description:
// an index_type, static rank()/rank_dynamic()/static_extent(size_t), and
// extent(size_t)/size() callable on an object.
template<typename Candidate>
using extent_requirements = std::void_t<
    typename Candidate::index_type,
    decltype(Candidate::rank()),
    decltype(Candidate::rank_dynamic()),
    decltype(Candidate::static_extent(std::declval<size_t>())),
    decltype(std::declval<Candidate>().extent(std::declval<size_t>())),
    decltype(std::declval<Candidate>().size())
>;

// Detection trait: false unless every requirement above is well-formed.
template<typename E, typename = void>
struct is_extent : std::false_type {
};

template<typename E>
struct is_extent<E, extent_requirements<E>> : std::true_type {
};

// Shorthand for is_extent<E>::value.
template<typename E>
constexpr bool is_extent_v = is_extent<E>::value;
namespace { | |
template<typename Dimensions> | |
struct mapping_base { | |
using index_type = ptrdiff_t; | |
using array_type = index_type[]; | |
using extent_type = Dimensions; | |
constexpr mapping_base() noexcept : strides{} {} | |
constexpr mapping_base(const mapping_base&) noexcept = default; | |
constexpr mapping_base(mapping_base&&) noexcept = default; | |
~mapping_base() noexcept = default; | |
mapping_base& operator=(const mapping_base&) noexcept = default; | |
mapping_base& operator=(mapping_base&&) noexcept = default; | |
template<typename IndexType, size_t N> | |
constexpr mapping_base(const std::array<IndexType, N>& dynamic_extents) noexcept : dimensions(dynamic_extents), strides{} { | |
} | |
template<typename... IndexType> | |
constexpr mapping_base(IndexType... DynamicExtents) noexcept : mapping_base(std::array<index_type, sizeof...(DynamicExtents)>{ DynamicExtents... }) { | |
} | |
static constexpr size_t rank() noexcept { | |
return extent_type::rank(); | |
} | |
static constexpr size_t rank_dynamic() noexcept { | |
return extent_type::rank_dynamic(); | |
} | |
static constexpr index_type static_extent(size_t i) noexcept { | |
return extent_type::static_extent(i); | |
} | |
constexpr index_type extent(size_t i) const noexcept { | |
return dimensions.extent(i); | |
} | |
constexpr index_type span_size() const noexcept { | |
return dimensions.size(); | |
} | |
static constexpr bool is_always_unique = true; | |
static constexpr bool is_always_contiguous = true; | |
static constexpr bool is_always_strided = true; | |
constexpr bool is_unique() const noexcept { | |
return true; | |
} | |
constexpr bool is_contiguous() const noexcept { | |
return true; | |
} | |
constexpr bool is_strided() noexcept { | |
return true; | |
} | |
constexpr index_type stride(size_t r) const noexcept { | |
if(r >= rank()) { | |
return 0; | |
} | |
return strides[r]; | |
} | |
template<typename IndexType, size_t N> | |
constexpr index_type operator()(const std::array<IndexType, N>& idxes) const noexcept { | |
static_assert(N == rank(), "wrong number of indices passed"); | |
return md::inner_product(strides.begin(), strides.end(), idxes.begin(), static_cast<index_type>(0), std::plus<>(), std::multiplies<>()); | |
} | |
template<typename... IndexType > | |
constexpr index_type operator()(IndexType... indices) const noexcept { | |
return operator()(std::array<index_type, sizeof...(IndexType)>{ static_cast<index_type>(indices)... }); | |
} | |
protected: | |
extent_type dimensions; | |
std::array<size_t, extent_type::rank()> strides; | |
}; | |
} | |
struct layout_right { | |
template<typename Dimensions> | |
struct mapping : mapping_base<Dimensions> { | |
using base_type = mapping_base<Dimensions>; | |
using index_type = typename base_type::index_type; | |
constexpr mapping(const mapping&) noexcept = default; | |
constexpr mapping(mapping&&) noexcept = default; | |
~mapping() noexcept = default; | |
mapping& operator=(const mapping&) noexcept = default; | |
mapping& operator=(mapping&&) noexcept = default; | |
template<typename IndexType, size_t N> | |
constexpr mapping(const std::array<IndexType, N>& dynamic_extents) noexcept : base_type(dynamic_extents) { | |
size_t stride = 1; | |
for(size_t i = 0; i < base_type::rank(); ++i) { | |
size_t j = base_type::rank() - 1 - i; | |
this->strides[j] = stride; | |
stride *= this->dimensions.extent(j); | |
} | |
} | |
template<typename... IndexType> | |
constexpr mapping(IndexType... DynamicExtents) noexcept : mapping(std::array<index_type, sizeof...(DynamicExtents)>{ DynamicExtents... }) { | |
} | |
constexpr mapping() noexcept : mapping(std::array<index_type, 0>{}) { | |
} | |
}; | |
}; | |
struct layout_left { | |
template<typename Dimensions> | |
struct mapping { | |
using base_type = mapping_base<Dimensions>; | |
using index_type = typename base_type::index_type; | |
constexpr mapping(const mapping&) noexcept = default; | |
constexpr mapping(mapping&&) noexcept = default; | |
~mapping() noexcept = default; | |
mapping& operator=(const mapping&) noexcept = default; | |
mapping& operator=(mapping&&) noexcept = default; | |
template<typename IndexType, size_t N> | |
constexpr mapping(const std::array<IndexType, N>& dynamic_extents) noexcept : base_type(dynamic_extents) { | |
size_t stride = 1; | |
for(size_t i = 0; i < base_type::rank(); ++i) { | |
this->strides[i] = stride; | |
stride *= this->dimensions.extent(i); | |
} | |
} | |
template<typename... IndexType> | |
constexpr mapping(IndexType... DynamicExtents) noexcept : mapping(std::array<index_type, sizeof...(DynamicExtents)>{ DynamicExtents... }) { | |
} | |
constexpr mapping() noexcept : mapping(std::array<index_type, 0>{}) { | |
} | |
}; | |
}; | |
// Placeholder for a layout with user-specified strides; the mapping is an
// empty stub and is not yet usable.
template<size_t... Strides>
struct layout_strided {
    template<typename Dimensions>
    struct mapping {
        // TODO: this is rather ill-defined. For padded mdarrays the stride
        // should be in *bytes*, not *elements*, but the other layouts operate
        // in *elements*, not *bytes*.
    };
};
// Everything a type must expose to be treated as a layout mapping: the
// extents-style queries, span_size(), stride(size_t), the three is_always_*
// constants and the matching is_*() member functions.
template<typename Candidate>
using mapping_requirements = std::void_t<
    typename Candidate::index_type,
    decltype(Candidate::rank()),
    decltype(Candidate::rank_dynamic()),
    decltype(Candidate::static_extent(std::declval<size_t>())),
    decltype(std::declval<Candidate>().extent(std::declval<size_t>())),
    decltype(std::declval<Candidate>().span_size()),
    decltype(std::declval<Candidate>().stride(std::declval<size_t>())),
    decltype(Candidate::is_always_unique),
    decltype(Candidate::is_always_contiguous),
    decltype(Candidate::is_always_strided),
    decltype(std::declval<Candidate>().is_unique()),
    decltype(std::declval<Candidate>().is_contiguous()),
    decltype(std::declval<Candidate>().is_strided())
>;

// Detection trait: false unless every requirement above is well-formed.
template<typename M, typename = void>
struct is_mapping : std::false_type {
};

template<typename M>
struct is_mapping<M, mapping_requirements<M>> : std::true_type {
};

// Shorthand for is_mapping<M>::value.
template<typename M>
constexpr bool is_mapping_v = is_mapping<M>::value;
template<typename L, typename = void> | |
struct is_layout : std::false_type { | |
}; | |
template<typename L> | |
struct is_layout<L, std::void_t< | |
typename L::template mapping<extents<1>> | |
> > : std::true_type { | |
}; | |
template<typename L> | |
constexpr bool is_layout_v = is_layout<L>::value; | |
//template<typename T, typename Dimensions, typename Layout = layout_right> | |
template<typename T, typename... Properties> | |
struct mdspan { | |
using extents_type = first_matching_type_t<is_extent, Properties...>; | |
using layout_type = first_matching_type_t<is_layout, Properties..., layout_right>; | |
//using extents_type = Dimensions; | |
//using layout_type = Layout; | |
using mapping_type = typename layout_type::template mapping<extents_type>; | |
static_assert(is_extent_v<extents_type>, "Dimensions are ill-formed"); | |
static_assert(is_layout_v<layout_type>, "Layout is ill-formed"); | |
static_assert(is_mapping_v<mapping_type>, "Mapping is ill-formed"); | |
using element_type = typename std::remove_all_extents_t<T>; | |
using value_type = typename std::remove_cv_t<element_type>; | |
using index_type = ptrdiff_t; | |
using difference_type = ptrdiff_t; | |
using pointer = element_type*; | |
using reference = element_type&; | |
constexpr mdspan() noexcept = default; | |
constexpr mdspan(mdspan&&) noexcept = default; | |
constexpr mdspan(mdspan const&) noexcept = default; | |
~mdspan() noexcept = default; | |
mdspan& operator=(mdspan&&) noexcept = default; | |
mdspan& operator=(mdspan const&) noexcept = default; | |
//template <typename R, typename RDimensions, typename RLayout> | |
//constexpr mdspan(mdspan<R, RDimensions, RLayout> const&) noexcept; | |
//template <typename R, typename RDimensions, typename RLayout> | |
//mdspan& operator=(mdspan<R, RDimensions, RLayout> const&) noexcept; | |
//constexpr mdspan(nullptr_t) noexcept; | |
template<typename... IndexType> | |
explicit constexpr mdspan(pointer elts, IndexType... DynamicExtents) noexcept : elements(elts), mapping(DynamicExtents...) { | |
} | |
template<typename... IndexType> | |
explicit constexpr mdspan(span<element_type> elts, IndexType... DynamicExtents) noexcept : elements(elts.data()), mapping(DynamicExtents...) { | |
} | |
template<typename IndexType, size_t N> | |
explicit constexpr mdspan(pointer elts, const std::array<IndexType, N>& dynamic_extents) noexcept : elements(elts), mapping(dynamic_extents) { | |
} | |
template<typename IndexType, size_t N> | |
explicit constexpr mdspan(span<element_type> elts, const std::array<IndexType, N>& dynamic_extents) noexcept : elements(elts.data()), mapping(dynamic_extents) { | |
} | |
constexpr reference operator[](index_type idx) const noexcept { | |
return elements[mapping(idx)]; | |
} | |
template<typename IndexType, size_t N> | |
constexpr reference operator[](const std::array<IndexType, N>& indices) { | |
return this->operator()(indices); | |
} | |
template<typename... IndexType> | |
constexpr reference operator()(IndexType... indices) const noexcept { | |
return this->operator()(std::array<index_type, sizeof...(IndexType)> { static_cast<index_type>(indices)... }); | |
} | |
template<typename IndexType, size_t N> | |
constexpr reference operator()(const std::array<IndexType, N>& indices) const noexcept { | |
return elements[mapping(indices)]; | |
} | |
static constexpr int rank() noexcept { | |
return mapping_type::rank(); | |
} | |
static constexpr int rank_dynamic() noexcept { | |
return mapping_type::rank_dynamic(); | |
} | |
static constexpr index_type static_extent(size_t e) noexcept { | |
return mapping_type::static_extent(e); | |
} | |
constexpr index_type extent(size_t e) const noexcept { | |
return mapping.extent(e); | |
} | |
constexpr index_type size() const noexcept { | |
return mapping.size(); | |
} | |
constexpr md::span<element_type> span() const noexcept { | |
return md::span<element_type>(elements, size()); | |
} | |
template<typename... IndexType> | |
static constexpr index_type required_span_size(IndexType... DynamicExtents) { | |
return required_span_size(std::array<index_type, sizeof...(IndexType)>{ DynamicExtents...}); | |
} | |
template<typename IndexType, size_t N> | |
static constexpr index_type required_span_size(const std::array<IndexType, N>& dynamic_extents) { | |
return mapping_type(dynamic_extents).span_size(); | |
} | |
static constexpr bool is_always_unique = mapping_type::is_always_unique; | |
static constexpr bool is_always_contiguous = mapping_type::is_always_contiguous; | |
static constexpr bool is_always_strided = mapping_type::is_always_strided; | |
constexpr bool is_unique() const { | |
return mapping.is_unique(); | |
} | |
constexpr bool is_contiguous() const { | |
return mapping.is_contiguous(); | |
} | |
constexpr bool is_strided() const { | |
return mapping.is_strided(); | |
} | |
constexpr index_type stride(size_t r) const { | |
return mapping.stride(r); | |
} | |
private: | |
pointer elements; | |
mapping_type mapping; | |
}; | |
} | |
int main() | |
{ | |
using extent_type = md::extents<-1, -1, -1, -1>; | |
using span_type = md::mdspan<double, extent_type, md::layout_right>; | |
std::unique_ptr<double[]> raw_array_1 = std::make_unique<double[]>(span_type::required_span_size(4, 2, 5, 4)); | |
std::unique_ptr<double[]> raw_array_2 = std::make_unique<double[]>(span_type::required_span_size(4, 2, 5, 4)); | |
span_type mda_1(raw_array_1.get(), 4, 2, 5, 4); | |
span_type mda_2(raw_array_2.get(), 4, 2, 5, 4); | |
for(int i = 0; i < 4; ++i) { | |
for(int j = 0; j < 2; ++j) { | |
for(int k = 0; k < 5; ++k) { | |
for(int l = 0; l < 4; ++l) { | |
mda_1(i, j, k, l) = i + j + k + l; | |
mda_2(i, j, k, l) = i * j * k * l; | |
} | |
} | |
} | |
} | |
std::unique_ptr<double[]> raw_array_3 = std::make_unique<double[]>(span_type::required_span_size(4, 2, 5, 4)); | |
span_type mda_3(raw_array_3.get(), 4, 2, 5, 4); | |
for(int i = 0; i < 4; ++i) { | |
for(int j = 0; j < 2; ++j) { | |
for(int k = 0; k < 5; ++k) { | |
for(int l = 0; l < 4; ++l) { | |
mda_3(i, j, k, l) = mda_1(i, j, k, l) * mda_2(i, j, k, l); | |
} | |
} | |
} | |
} | |
for(int i = 0; i < 4; ++i) { | |
for(int j = 0; j < 2; ++j) { | |
for(int k = 0; k < 5; ++k) { | |
for(int l = 0; l < 4; ++l) { | |
std::cout << i << " " << j << " " << k << " " << l << ": " << mda_3(i, j, k, l) << std::endl; | |
} | |
} | |
} | |
} | |
for(int i = 0; i < 4; ++i) { | |
for(int j = 0; j < 2; ++j) { | |
for(int k = 0; k < 5; ++k) { | |
for(int l = 0; l < 4; ++l) { | |
mda_3(i, j, k, l) = i * j * k * l; | |
} | |
} | |
} | |
} | |
double raw_array[4][2][5][4]; | |
span_type mda_4(&raw_array[0][0][0][0], 4, 2, 5, 4); | |
for(int i = 0; i < 4; ++i) { | |
for(int j = 0; j < 2; ++j) { | |
for(int k = 0; k < 5; ++k) { | |
for(int l = 0; l < 4; ++l) { | |
raw_array[i][j][k][l] = i * j * k * l; | |
} | |
} | |
} | |
} | |
for(int i = 0; i < 4; ++i) { | |
for(int j = 0; j < 2; ++j) { | |
for(int k = 0; k < 5; ++k) { | |
for(int l = 0; l < 4; ++l) { | |
std::cout << i << " " << j << " " << k << " " << l << ": " << &raw_array[i][j][k][l] << " " << &mda_4(i, j, k, l) << std::endl; | |
} | |
} | |
} | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment