Skip to content

Instantly share code, notes, and snippets.

@DrPizza
Created November 23, 2017 06:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DrPizza/39f90bd2324f0100f487eb6e27870d29 to your computer and use it in GitHub Desktop.
Save DrPizza/39f90bd2324f0100f487eb6e27870d29 to your computer and use it in GitHub Desktop.
// Sketch of the multidimensional array view `mdspan` proposed in WG21 P0009r4:
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0009r4.html
#include <algorithm>
#include <array>
#include <cstddef>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>

#include <gsl/gsl>
namespace md {
// md:: re-exports of the GSL span machinery so the rest of this namespace can
// name them without the gsl:: prefix. NOTE(review): main() spells dynamic
// extents as -1, which presumes gsl::dynamic_extent == -1 in the GSL version
// this was written against — confirm when upgrading GSL.
template<typename T, ptrdiff_t Extent = gsl::dynamic_extent>
using span = gsl::span<T, Extent>;

constexpr ptrdiff_t dynamic_extent = gsl::dynamic_extent;
namespace {
/// Number of entries in Ns equal to dynamic_extent.
///
/// Implemented as a C++17 fold so an empty pack (rank-0 extents) yields 0;
/// the previous "swallow" trick built a zero-length array, which is ill-formed.
template<ptrdiff_t... Ns>
constexpr size_t count_dynamics() {
    return (size_t{ 0 } + ... + static_cast<size_t>(Ns == dynamic_extent));
}
/// Index of the first type T in Types for which Pred<T>::value is true, or
/// ~size_t(0) when no type matches (including when Types is empty).
template<template<typename...> typename Pred, typename... Types>
constexpr std::size_t find_first_type_of() {
    // std::array may have zero length (a built-in array may not), so an empty
    // Types pack no longer makes this declaration ill-formed.
    constexpr std::array<bool, sizeof...(Types)> matches{ Pred<Types>::value... };
    for(std::size_t i = 0; i < matches.size(); ++i) {
        if(matches[i]) {
            return i;
        }
    }
    return ~static_cast<std::size_t>(0);
}
/// Variable-template shorthand for find_first_type_of<Pred, Types...>().
template<template<typename...> typename Pred, typename... Types>
constexpr std::size_t find_first_type_of_v = find_first_type_of<Pred, Types...>();
/// select_nth<I, Types...>::type is the I-th (zero-based) type in Types.
///
/// The bounds check lives in the primary template, ahead of any lookup, so an
/// out-of-range I reports "bad index" first. Previously the check sat in a
/// recursive specialization whose base class failed to instantiate before the
/// static_assert could fire.
template<std::size_t I, typename... Types>
struct select_nth {
    static_assert(I < sizeof...(Types), "bad index");
    using type = std::tuple_element_t<I, std::tuple<Types...>>;
};
template<std::size_t I, typename... Ts>
using select_nth_t = typename select_nth<I, Ts...>::type;
template<template<typename...> typename Predicate, typename... Ts>
struct first_matching_type {
using type = select_nth_t<find_first_type_of_v<Predicate, Ts...>, Ts...>;
};
template<template<typename...> typename Predicate, typename... Ts>
using first_matching_type_t = typename first_matching_type<Predicate, Ts...>::type;
}
/// constexpr generalized inner product (std::inner_product is not constexpr
/// before C++20). Walks [first1, end1) in lock-step with first2, folding
/// f2(*first1, *first2) into `val` with f1 from left to right, and returns the
/// final accumulator.
template<typename InputIt1, typename InputIt2, typename Acc, typename Reduce, typename Combine>
constexpr Acc inner_product(InputIt1 first1, InputIt1 end1, InputIt2 first2, Acc val, Reduce f1, Combine f2) {
    while(first1 != end1) {
        val = f1(val, f2(*first1, *first2));
        // Two separate statements: never depends on a (possibly overloaded) operator,.
        ++first1;
        ++first2;
    }
    return val;
}
/// constexpr left fold (std::accumulate is not constexpr before C++20):
/// returns f(...f(f(val, *first), *(first + 1))..., *(end - 1)), or `val`
/// unchanged for an empty range.
template<typename InputIt, typename Acc, typename BinaryOp>
constexpr Acc accumulate(InputIt first, InputIt end, Acc val, BinaryOp f) {
    while(first != end) {
        val = f(val, *first);
        ++first;
    }
    return val;
}
/// Extents of a multidimensional index space (after P0009r4 `extents`).
///
/// Each template argument is either a static extent or dynamic_extent. The
/// dynamic ones are supplied at construction, in the order they appear.
template<ptrdiff_t... Extents>
struct extents {
    using index_type = ptrdiff_t;
    using array_type = index_type[];
    using my_type = extents<Extents...>;

    /// Stores the template arguments as-is. Dynamic dimensions keep the
    /// dynamic_extent sentinel until a value-taking constructor overwrites them.
    constexpr extents() noexcept : values{ Extents... } {
    }
    constexpr extents(const extents&) = default;
    constexpr extents(extents&&) = default;
    ~extents() = default;
    extents& operator=(const extents&) noexcept = default;
    extents& operator=(extents&&) noexcept = default;

    /// Number of dimensions.
    static constexpr size_t rank() noexcept {
        return sizeof...(Extents);
    }
    /// Number of dimensions whose extent is given at run time.
    static constexpr size_t rank_dynamic() noexcept {
        return count_dynamics<Extents...>();
    }
    /// Compile-time extent of dimension e (dynamic_extent for a dynamic
    /// dimension). Returns 1 for any e >= rank().
    static constexpr index_type static_extent(size_t e) noexcept {
        // std::array (unlike a built-in array) may be empty, so this stays
        // well-formed for rank-0 extents.
        constexpr std::array<index_type, sizeof...(Extents)> static_values{ Extents... };
        return e < rank() ? static_values[e] : 1;
    }
    /// Run-time extent of dimension e. Returns 1 for any e >= rank().
    constexpr index_type extent(size_t e) const noexcept {
        if(e >= rank()) {
            return 1;
        } else {
            return values[e];
        }
    }
    /// Product of all extents (1 for rank 0).
    constexpr index_type size() const noexcept {
        return md::accumulate(values.cbegin(), values.cend(), static_cast<index_type>(1), std::multiplies<>());
    }
    /// Takes exactly rank_dynamic() run-time extents and assigns them to the
    /// dynamic dimensions from left to right.
    template<typename IndexType, size_t N>
    constexpr extents(const std::array<IndexType, N>& dynamic_extents) noexcept : values{ Extents... } {
        static_assert(N == rank_dynamic(), "wrong number of dynamic extents provided");
        for(size_t i = 0, j = 0; i < rank(); ++i) {
            if(values[i] == dynamic_extent) {
                values[i] = dynamic_extents[j++];
            }
        }
    }
    /// Variadic form of the array constructor, e.g. extents<3, dynamic_extent>(5).
    /// It only participates when every argument converts to index_type.
    /// Otherwise copying a non-const extents lvalue would pick this template
    /// over the copy constructor and fail to compile. Arguments are cast
    /// explicitly so wider or unsigned types do not trip brace narrowing.
    template<typename... IndexType,
             std::enable_if_t<(std::is_convertible<IndexType, index_type>::value && ...), int> = 0>
    constexpr extents(IndexType... DynamicExtents) noexcept
        : extents(std::array<index_type, sizeof...(IndexType)>{ { static_cast<index_type>(DynamicExtents)... } }) {
    }
private:
    // sizeof... rather than rank(): the class is still incomplete here.
    std::array<index_type, sizeof...(Extents)> values;
};
/// Trait: true when Candidate exposes the extents interface used here. That
/// interface is a nested index_type, static rank(), rank_dynamic() and
/// static_extent(size_t), plus member extent(size_t) and size(). Only the
/// presence and well-formedness of these members is checked.
template<typename Candidate, typename = void>
struct is_extent : std::false_type {};

template<typename Candidate>
struct is_extent<Candidate,
                 std::void_t<typename Candidate::index_type,
                             // static shape queries
                             decltype(Candidate::rank()),
                             decltype(Candidate::rank_dynamic()),
                             decltype(Candidate::static_extent(std::declval<size_t>())),
                             // per-object queries
                             decltype(std::declval<Candidate>().extent(std::declval<size_t>())),
                             decltype(std::declval<Candidate>().size())>>
    : std::true_type {};

template<typename Candidate>
constexpr bool is_extent_v = is_extent<Candidate>::value;
namespace {
/// State and behaviour shared by the contiguous layouts (layout_right and
/// layout_left): the extents plus one stride per dimension. Derived mappings
/// fill `strides` in their constructors; this base only zero-initializes it.
template<typename Dimensions>
struct mapping_base {
    using index_type = ptrdiff_t;
    using array_type = index_type[];
    using extent_type = Dimensions;

    constexpr mapping_base() noexcept : strides{} {}
    constexpr mapping_base(const mapping_base&) noexcept = default;
    constexpr mapping_base(mapping_base&&) noexcept = default;
    ~mapping_base() noexcept = default;
    mapping_base& operator=(const mapping_base&) noexcept = default;
    mapping_base& operator=(mapping_base&&) noexcept = default;

    /// Builds the extents from the run-time (dynamic) extents; strides start at 0.
    template<typename IndexType, size_t N>
    constexpr mapping_base(const std::array<IndexType, N>& dynamic_extents) noexcept : dimensions(dynamic_extents), strides{} {
    }
    /// Variadic form of the array constructor. It is constrained to index-like
    /// arguments so it never hijacks copy construction from a non-const lvalue.
    /// Explicit casts avoid brace-narrowing errors.
    template<typename... IndexType,
             std::enable_if_t<(std::is_convertible<IndexType, index_type>::value && ...), int> = 0>
    constexpr mapping_base(IndexType... DynamicExtents) noexcept
        : mapping_base(std::array<index_type, sizeof...(DynamicExtents)>{ { static_cast<index_type>(DynamicExtents)... } }) {
    }

    static constexpr size_t rank() noexcept {
        return extent_type::rank();
    }
    static constexpr size_t rank_dynamic() noexcept {
        return extent_type::rank_dynamic();
    }
    static constexpr index_type static_extent(size_t i) noexcept {
        return extent_type::static_extent(i);
    }
    constexpr index_type extent(size_t i) const noexcept {
        return dimensions.extent(i);
    }
    /// Elements the mapped range spans: the product of the extents, which is
    /// exact for these contiguous layouts.
    constexpr index_type span_size() const noexcept {
        return dimensions.size();
    }

    static constexpr bool is_always_unique = true;
    static constexpr bool is_always_contiguous = true;
    static constexpr bool is_always_strided = true;
    constexpr bool is_unique() const noexcept {
        return true;
    }
    constexpr bool is_contiguous() const noexcept {
        return true;
    }
    // const (previously missing): mdspan::is_strided() const calls this on a const mapping.
    constexpr bool is_strided() const noexcept {
        return true;
    }
    /// Stride of dimension r in elements; 0 for r >= rank().
    constexpr index_type stride(size_t r) const noexcept {
        if(r >= rank()) {
            return 0;
        }
        return strides[r];
    }

    /// Linear offset of a multi-index: sum over r of strides[r] * idxes[r].
    template<typename IndexType, size_t N>
    constexpr index_type operator()(const std::array<IndexType, N>& idxes) const noexcept {
        static_assert(N == rank(), "wrong number of indices passed");
        return md::inner_product(strides.begin(), strides.end(), idxes.begin(), static_cast<index_type>(0), std::plus<>(), std::multiplies<>());
    }
    template<typename... IndexType >
    constexpr index_type operator()(IndexType... indices) const noexcept {
        return operator()(std::array<index_type, sizeof...(IndexType)>{ static_cast<index_type>(indices)... });
    }
protected:
    extent_type dimensions;
    // Signed like index_type, so offset arithmetic never mixes signedness.
    std::array<index_type, extent_type::rank()> strides;
};
}
/// Row-major layout: the last index has stride 1 and varies fastest.
struct layout_right {
    template<typename Dimensions>
    struct mapping : mapping_base<Dimensions> {
        using base_type = mapping_base<Dimensions>;
        using index_type = typename base_type::index_type;
        constexpr mapping(const mapping&) noexcept = default;
        constexpr mapping(mapping&&) noexcept = default;
        ~mapping() noexcept = default;
        mapping& operator=(const mapping&) noexcept = default;
        mapping& operator=(mapping&&) noexcept = default;

        /// Builds the extents, then derives row-major strides.
        template<typename IndexType, size_t N>
        constexpr mapping(const std::array<IndexType, N>& dynamic_extents) noexcept : base_type(dynamic_extents) {
            // Walk dimensions from last to first. The running product of the
            // extents already visited is the stride of the current dimension.
            index_type stride = 1;
            for(size_t i = base_type::rank(); i-- > 0;) {
                this->strides[i] = stride;
                stride *= this->dimensions.extent(i);
            }
        }
        /// Variadic form of the array constructor. It is constrained so that
        /// copying a non-const mapping lvalue selects the copy constructor.
        template<typename... IndexType,
                 std::enable_if_t<(std::is_convertible<IndexType, index_type>::value && ...), int> = 0>
        constexpr mapping(IndexType... DynamicExtents) noexcept
            : mapping(std::array<index_type, sizeof...(DynamicExtents)>{ { static_cast<index_type>(DynamicExtents)... } }) {
        }
        /// Only valid for fully static extents (the extents constructor
        /// static_asserts that no dynamic extents are missing).
        constexpr mapping() noexcept : mapping(std::array<index_type, 0>{}) {
        }
    };
};
struct layout_left {
template<typename Dimensions>
struct mapping {
using base_type = mapping_base<Dimensions>;
using index_type = typename base_type::index_type;
constexpr mapping(const mapping&) noexcept = default;
constexpr mapping(mapping&&) noexcept = default;
~mapping() noexcept = default;
mapping& operator=(const mapping&) noexcept = default;
mapping& operator=(mapping&&) noexcept = default;
template<typename IndexType, size_t N>
constexpr mapping(const std::array<IndexType, N>& dynamic_extents) noexcept : base_type(dynamic_extents) {
size_t stride = 1;
for(size_t i = 0; i < base_type::rank(); ++i) {
this->strides[i] = stride;
stride *= this->dimensions.extent(i);
}
}
template<typename... IndexType>
constexpr mapping(IndexType... DynamicExtents) noexcept : mapping(std::array<index_type, sizeof...(DynamicExtents)>{ DynamicExtents... }) {
}
constexpr mapping() noexcept : mapping(std::array<index_type, 0>{}) {
}
};
};
/// Placeholder for a strided layout policy. Its nested mapping is an empty
/// stub, so mdspan's is_mapping static_assert rejects it if used.
template<size_t... StrideValues>
struct layout_strided {
    template<typename ExtentsType>
    struct mapping {
        // TODO: not implemented. Open question: padded mdarrays want strides
        // measured in *bytes*, while layout_left / layout_right measure them
        // in *elements*.
    };
};
/// Trait: true when Mapping exposes the layout-mapping interface that mdspan
/// forwards to. That covers the shape queries, span_size(), stride(size_t),
/// and the is_always_* / is_* uniqueness, contiguity and stridedness flags.
/// Only the presence and well-formedness of these members is checked.
template<typename Mapping, typename = void>
struct is_mapping : std::false_type {};

template<typename Mapping>
struct is_mapping<Mapping,
                  std::void_t<typename Mapping::index_type,
                              // static shape queries
                              decltype(Mapping::rank()),
                              decltype(Mapping::rank_dynamic()),
                              decltype(Mapping::static_extent(std::declval<size_t>())),
                              // per-object geometry
                              decltype(std::declval<Mapping>().extent(std::declval<size_t>())),
                              decltype(std::declval<Mapping>().span_size()),
                              decltype(std::declval<Mapping>().stride(std::declval<size_t>())),
                              // compile-time layout properties
                              decltype(Mapping::is_always_unique),
                              decltype(Mapping::is_always_contiguous),
                              decltype(Mapping::is_always_strided),
                              // run-time layout properties
                              decltype(std::declval<Mapping>().is_unique()),
                              decltype(std::declval<Mapping>().is_contiguous()),
                              decltype(std::declval<Mapping>().is_strided())>>
    : std::true_type {};

template<typename Mapping>
constexpr bool is_mapping_v = is_mapping<Mapping>::value;
/// Trait: true when Layout is a layout policy, i.e. it has a nested class
/// template `mapping` that accepts an extents type. The probe only requires
/// that `Layout::mapping<extents<1>>` names a type. It does not check that
/// mapping's interface; is_mapping does that.
template<typename Layout, typename = void>
struct is_layout : std::false_type {};

template<typename Layout>
struct is_layout<Layout, std::void_t<typename Layout::template mapping<extents<1>>>>
    : std::true_type {};

template<typename Layout>
constexpr bool is_layout_v = is_layout<Layout>::value;
//template<typename T, typename Dimensions, typename Layout = layout_right>
template<typename T, typename... Properties>
struct mdspan {
using extents_type = first_matching_type_t<is_extent, Properties...>;
using layout_type = first_matching_type_t<is_layout, Properties..., layout_right>;
//using extents_type = Dimensions;
//using layout_type = Layout;
using mapping_type = typename layout_type::template mapping<extents_type>;
static_assert(is_extent_v<extents_type>, "Dimensions are ill-formed");
static_assert(is_layout_v<layout_type>, "Layout is ill-formed");
static_assert(is_mapping_v<mapping_type>, "Mapping is ill-formed");
using element_type = typename std::remove_all_extents_t<T>;
using value_type = typename std::remove_cv_t<element_type>;
using index_type = ptrdiff_t;
using difference_type = ptrdiff_t;
using pointer = element_type*;
using reference = element_type&;
constexpr mdspan() noexcept = default;
constexpr mdspan(mdspan&&) noexcept = default;
constexpr mdspan(mdspan const&) noexcept = default;
~mdspan() noexcept = default;
mdspan& operator=(mdspan&&) noexcept = default;
mdspan& operator=(mdspan const&) noexcept = default;
//template <typename R, typename RDimensions, typename RLayout>
//constexpr mdspan(mdspan<R, RDimensions, RLayout> const&) noexcept;
//template <typename R, typename RDimensions, typename RLayout>
//mdspan& operator=(mdspan<R, RDimensions, RLayout> const&) noexcept;
//constexpr mdspan(nullptr_t) noexcept;
template<typename... IndexType>
explicit constexpr mdspan(pointer elts, IndexType... DynamicExtents) noexcept : elements(elts), mapping(DynamicExtents...) {
}
template<typename... IndexType>
explicit constexpr mdspan(span<element_type> elts, IndexType... DynamicExtents) noexcept : elements(elts.data()), mapping(DynamicExtents...) {
}
template<typename IndexType, size_t N>
explicit constexpr mdspan(pointer elts, const std::array<IndexType, N>& dynamic_extents) noexcept : elements(elts), mapping(dynamic_extents) {
}
template<typename IndexType, size_t N>
explicit constexpr mdspan(span<element_type> elts, const std::array<IndexType, N>& dynamic_extents) noexcept : elements(elts.data()), mapping(dynamic_extents) {
}
constexpr reference operator[](index_type idx) const noexcept {
return elements[mapping(idx)];
}
template<typename IndexType, size_t N>
constexpr reference operator[](const std::array<IndexType, N>& indices) {
return this->operator()(indices);
}
template<typename... IndexType>
constexpr reference operator()(IndexType... indices) const noexcept {
return this->operator()(std::array<index_type, sizeof...(IndexType)> { static_cast<index_type>(indices)... });
}
template<typename IndexType, size_t N>
constexpr reference operator()(const std::array<IndexType, N>& indices) const noexcept {
return elements[mapping(indices)];
}
static constexpr int rank() noexcept {
return mapping_type::rank();
}
static constexpr int rank_dynamic() noexcept {
return mapping_type::rank_dynamic();
}
static constexpr index_type static_extent(size_t e) noexcept {
return mapping_type::static_extent(e);
}
constexpr index_type extent(size_t e) const noexcept {
return mapping.extent(e);
}
constexpr index_type size() const noexcept {
return mapping.size();
}
constexpr md::span<element_type> span() const noexcept {
return md::span<element_type>(elements, size());
}
template<typename... IndexType>
static constexpr index_type required_span_size(IndexType... DynamicExtents) {
return required_span_size(std::array<index_type, sizeof...(IndexType)>{ DynamicExtents...});
}
template<typename IndexType, size_t N>
static constexpr index_type required_span_size(const std::array<IndexType, N>& dynamic_extents) {
return mapping_type(dynamic_extents).span_size();
}
static constexpr bool is_always_unique = mapping_type::is_always_unique;
static constexpr bool is_always_contiguous = mapping_type::is_always_contiguous;
static constexpr bool is_always_strided = mapping_type::is_always_strided;
constexpr bool is_unique() const {
return mapping.is_unique();
}
constexpr bool is_contiguous() const {
return mapping.is_contiguous();
}
constexpr bool is_strided() const {
return mapping.is_strided();
}
constexpr index_type stride(size_t r) const {
return mapping.stride(r);
}
private:
pointer elements;
mapping_type mapping;
};
}
int main()
{
    using extent_type = md::extents<-1, -1, -1, -1>;
    using span_type = md::mdspan<double, extent_type, md::layout_right>;

    // Extents of every 4-D array in this demo, in (i, j, k, l) order.
    constexpr int n0 = 4;
    constexpr int n1 = 2;
    constexpr int n2 = 5;
    constexpr int n3 = 4;

    // Calls body(i, j, k, l) for every index tuple, with l varying fastest.
    const auto for_each_index = [&](auto&& body) {
        for(int i = 0; i < n0; ++i) {
            for(int j = 0; j < n1; ++j) {
                for(int k = 0; k < n2; ++k) {
                    for(int l = 0; l < n3; ++l) {
                        body(i, j, k, l);
                    }
                }
            }
        }
    };

    const auto element_count = span_type::required_span_size(n0, n1, n2, n3);
    std::unique_ptr<double[]> raw_array_1 = std::make_unique<double[]>(element_count);
    std::unique_ptr<double[]> raw_array_2 = std::make_unique<double[]>(element_count);
    span_type mda_1(raw_array_1.get(), n0, n1, n2, n3);
    span_type mda_2(raw_array_2.get(), n0, n1, n2, n3);
    for_each_index([&](int i, int j, int k, int l) {
        mda_1(i, j, k, l) = i + j + k + l;
        mda_2(i, j, k, l) = i * j * k * l;
    });

    std::unique_ptr<double[]> raw_array_3 = std::make_unique<double[]>(element_count);
    span_type mda_3(raw_array_3.get(), n0, n1, n2, n3);
    for_each_index([&](int i, int j, int k, int l) {
        mda_3(i, j, k, l) = mda_1(i, j, k, l) * mda_2(i, j, k, l);
    });
    // '\n' rather than std::endl: no flush per line; std::cout flushes at exit.
    for_each_index([&](int i, int j, int k, int l) {
        std::cout << i << " " << j << " " << k << " " << l << ": " << mda_3(i, j, k, l) << '\n';
    });
    // NOTE(review): mda_3 is not read after this overwrite.
    for_each_index([&](int i, int j, int k, int l) {
        mda_3(i, j, k, l) = i * j * k * l;
    });

    // Check that a layout_right mdspan over a built-in 4-D array addresses the
    // same elements as native indexing. NOTE(review): walking the whole array
    // through &raw_array[0][0][0][0] is not sanctioned by the strict
    // pointer-arithmetic rules, although it works on mainstream compilers.
    double raw_array[n0][n1][n2][n3];
    span_type mda_4(&raw_array[0][0][0][0], n0, n1, n2, n3);
    for_each_index([&](int i, int j, int k, int l) {
        raw_array[i][j][k][l] = i * j * k * l;
    });
    for_each_index([&](int i, int j, int k, int l) {
        std::cout << i << " " << j << " " << k << " " << l << ": " << &raw_array[i][j][k][l] << " " << &mda_4(i, j, k, l) << '\n';
    });
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment