Last active
June 3, 2019 14:36
-
-
Save amitsingh19975/ea6e978f8ca1d44d58aecedc13a52d87 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com | |
// | |
// Distributed under the Boost Software License, Version 1.0. (See | |
// accompanying file LICENSE_1_0.txt or copy at | |
// http://www.boost.org/LICENSE_1_0.txt) | |
// | |
// The authors gratefully acknowledge the support of | |
// Fraunhofer IOSB, Ettlingen, Germany | |
// | |
#ifndef BOOST_NUMERIC_UBLAS_TENSOR_STATIC_EXTENTS_HPP | |
#define BOOST_NUMERIC_UBLAS_TENSOR_STATIC_EXTENTS_HPP | |
#include <algorithm>
#include <array>
#include <cstddef>
#include <initializer_list>
#include <type_traits>
#include <vector>

#include <boost/numeric/ublas/tensor/extents.hpp>
#include <boost/numeric/ublas/tensor/extents_helper.hpp>
namespace boost::numeric::ublas { | |
/** @brief Template class for storing tensor extents for compile time. | |
* | |
* @code static_extents<4,1,2,3,4> t @endcode | |
* @tparam R rank of extents of type ptrdiff_t | |
* @tparam E parameter pack of extents | |
* | |
*/ | |
template <ptrdiff_t R, ptrdiff_t... E> | |
struct static_extents | |
: detail::basic_extents_impl<0, detail::make_basic_shape_t<R, E...>> { | |
using base_type = ptrdiff_t; | |
using value_type = base_type; | |
using const_reference = base_type const &; | |
using reference = base_type &; | |
using size_type = base_type; | |
using const_pointer = base_type *const; | |
/** @brief Forward declaration of template class static_stride for making | |
* friend class. | |
* | |
* @tparam Extents type of this class | |
* @tparam Layout | |
* | |
*/ | |
template <class Extents, class Layout> friend struct static_strides; | |
//@returns the rank of static_extents | |
static constexpr auto size() noexcept { return impl::Rank; } | |
//@returns the rank of static_extents | |
static constexpr auto rank() noexcept { return impl::Rank; } | |
//@returns the dynamic rank of static_extents | |
static constexpr auto dyanmic_rank() noexcept { return impl::DynamicRank; } | |
/** | |
* @param k pos of extent | |
* @returns the element at given pos | |
*/ | |
constexpr auto at(size_type k) const noexcept { return impl::at(k); } | |
/** | |
* @param k pos of extent | |
* @returns Returns the number of elements a tensor holds with this from k | |
* position ownwards | |
*/ | |
constexpr auto product(size_type k) const noexcept { | |
return impl::product(k); | |
} | |
/** | |
* @param k pos of extent | |
* @Returns the number of elements a tensor holds with this | |
*/ | |
constexpr auto product() const noexcept { return impl::product(); } | |
// default constructor | |
constexpr static_extents() = default; | |
// default copy constructor | |
constexpr static_extents(static_extents const &other) = default; | |
// default assign constructor | |
constexpr static_extents &operator=(static_extents const &other) = default; | |
/** @brief assigns the extents to dynamic extents using parameter pack | |
* | |
* @code static_extents<2> e( 2,3 ); @endcode | |
* | |
* @tparam IndexType | |
* | |
* @param DynamicExtents parameter pack of extents | |
* | |
* @note number of extents should be equal to dynamic rank | |
*/ | |
template <class... IndexType> | |
constexpr static_extents(IndexType... DynamicExtents) | |
: impl(DynamicExtents...) {} | |
/** @brief assigns the extents to dynamic extents using initializer_list | |
* | |
* @code static_extents<2> e = { 2, 3}; @endcode | |
* | |
* @tparam IndexType | |
* | |
* @param li of type initializer_list which constains the extents | |
* | |
* @note number of extents should be equal to dynamic rank | |
*/ | |
template <class IndexType> | |
constexpr static_extents(std::initializer_list<IndexType> li) | |
: impl(li.begin(), li.end(), detail::iterator_tag{}) {} | |
/** @brief assigns the extents to dynamic extents using iterator | |
* | |
* @code static_extents<2> e( a.begin(), a.end() ); @endcode | |
* | |
* @tparam I of type input iterator and valur_type should be integral | |
* | |
* @param begin start of iterator | |
* | |
* @param end end of iterator | |
* | |
* @note number of extents should be equal to dynamic rank | |
* | |
*/ | |
template <class I> | |
constexpr static_extents(I begin, I end) | |
: impl(begin, end, detail::iterator_tag_t<I>{}) {} | |
/** @brief Returns true if this has a scalar shape | |
* | |
* @returns true if (1,1,[1,...,1]) | |
*/ | |
constexpr bool is_scalar() const noexcept { | |
constexpr auto arr = to_array(); | |
return size() != 0 && std::all_of(arr.begin(), arr.end(), | |
[](auto const &a) { return a == 1; }); | |
} | |
/** @brief Returns true if this has a vector shape | |
* | |
* @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1 | |
*/ | |
constexpr bool is_vector() const noexcept { | |
if (size() == 0) { | |
return false; | |
} | |
if (size() == 1) { | |
return at(0) > 1; | |
} | |
auto arr = to_array(); | |
auto greater_one = [](auto const &a) { return a > 1; }; | |
auto equal_one = [](auto const &a) { return a == 1; }; | |
return std::any_of(arr.begin(), arr.begin() + 2, greater_one) && | |
std::any_of(arr.begin(), arr.begin() + 2, equal_one) && | |
std::all_of(arr.begin() + 2, arr.end(), equal_one); | |
} | |
/** @brief Returns true if this has a matrix shape | |
* | |
* @returns true if (m,n,[1,...,1]) with m > 1 and n > 1 | |
*/ | |
constexpr bool is_matrix() const noexcept { | |
if (size() < 2) { | |
return false; | |
} | |
auto arr = to_array(); | |
auto greater_one = [](auto const &a) { return a > 1; }; | |
auto equal_one = [](auto const &a) { return a == 1; }; | |
return std::all_of(arr.begin(), arr.begin() + 2, greater_one) && | |
std::all_of(arr.begin() + 2, arr.end(), equal_one); | |
} | |
/** @brief Returns true if this is has a tensor shape | |
* | |
* @returns true if !empty() && !is_scalar() && !is_vector() && !is_matrix() | |
*/ | |
constexpr bool is_tensor() const noexcept { | |
if (size() < 3) { | |
return false; | |
} | |
auto arr = to_array(); | |
auto greater_one = [](auto const &a) { return a > 1; }; | |
return std::any_of(arr.begin() + 2, arr.end(), greater_one); | |
} | |
/** @brief Returns the std::vector containing extents */ | |
auto to_vector() const noexcept { | |
std::vector<base_type> temp(R); | |
for (auto i = 0u; i < temp.size(); i++) { | |
temp[i] = at(i); | |
} | |
return temp; | |
} | |
/** @brief Returns the std::array containing extents */ | |
constexpr auto to_array() const noexcept { | |
std::array<base_type, R> temp; | |
for (auto i = 0u; i < temp.size(); i++) { | |
temp[i] = at(i); | |
} | |
return temp; | |
} | |
/** @brief Checks if extents is empty or not | |
* | |
* @returns true if rank is 0 else false | |
* | |
*/ | |
constexpr auto empty() const noexcept { return size() == 0; } | |
/** @brief Returns true if size > 1 and all elements > 0 */ | |
constexpr bool valid() const noexcept { | |
auto arr = to_array(); | |
return size() > 1 && | |
std::none_of(arr.begin(), arr.end(), | |
[](auto const &a) { return a == ptrdiff_t(0); }); | |
} | |
/** @brief Returns true if both extents are equal else false */ | |
template <ptrdiff_t rhs_dims, ptrdiff_t... rhs> | |
constexpr auto operator==(static_extents<rhs_dims, rhs...> const &other) const | |
noexcept { | |
if (size() != other.size()) { | |
return false; | |
} | |
for (auto i = 0u; i < size(); i++) { | |
if (other.at(i) != at(i)) | |
return false; | |
} | |
return true; | |
} | |
/** @brief Returns false if both extents are equal else true */ | |
template <ptrdiff_t rhs_dims, ptrdiff_t... rhs> | |
constexpr auto operator!=(static_extents<rhs_dims, rhs...> const &other) const | |
noexcept { | |
return !(*this == other); | |
} | |
/** @brief Eliminates singleton dimensions when size > 2 | |
* | |
* squeeze { 1,1} -> { 1,1} | |
* squeeze { 2,1} -> { 2,1} | |
* squeeze { 1,2} -> { 1,2} | |
* | |
* squeeze {1,2,3} -> { 2,3} | |
* squeeze {2,1,3} -> { 2,3} | |
* squeeze {1,3,1} -> { 1,3} | |
* | |
*/ | |
auto squeeze() const noexcept { | |
auto arr = to_vector(); | |
basic_extents<size_t> e{arr.cbegin(), arr.cend()}; | |
return e.squeeze(); | |
} | |
~static_extents() = default; | |
private: | |
using impl = | |
detail::basic_extents_impl<0, detail::make_basic_shape_t<R, E...>>; | |
}; | |
namespace detail { | |
/** @brief Checks if the extents is dynamic or static | |
* | |
* @tparam E of type basic_extents or static_extents | |
* | |
*/ | |
template <class E> | |
struct is_dynamic_extents : std::integral_constant<bool, false> {}; | |
/** @brief Partial Specialization of is_dynamic_extents with basic_extens | |
* | |
* @tparam T of any integer type | |
* | |
*/ | |
template <class T> | |
struct is_dynamic_extents<basic_extents<T>> | |
: std::integral_constant<bool, true> {}; | |
/** @brief Partial Specialization of is_dynamic_extents with static_extents | |
* | |
* @tparam R rank of ptrdiff_t | |
* | |
* @tparam E parameter pack of extents | |
* | |
*/ | |
template <ptrdiff_t R, ptrdiff_t... E> | |
struct is_dynamic_extents<static_extents<R, E...>> { | |
static constexpr bool value = | |
static_extents<R, E...>::dyanmic_rank() != 0 && | |
is_dynamic_basic_shape_v<make_basic_shape_t<R, E...>>; | |
}; | |
} // namespace detail | |
namespace framework { | |
/** @brief type alias of basic_extents or static_extents depending on Rank | |
* | |
* @tparam R rank of extents | |
* | |
* @tparam E contains the extents as a parameter pack | |
* | |
*/ | |
template <ptrdiff_t R, ptrdiff_t... E> | |
using shape_t = | |
std::conditional_t<(R < 0), basic_extents<size_t>, static_extents<R, E...>>; | |
} // namespace framework | |
} // namespace boost::numeric::ublas | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment