Skip to content

Instantly share code, notes, and snippets.

@amitsingh19975
Last active June 3, 2019 14:36
Show Gist options
  • Save amitsingh19975/ea6e978f8ca1d44d58aecedc13a52d87 to your computer and use it in GitHub Desktop.
Save amitsingh19975/ea6e978f8ca1d44d58aecedc13a52d87 to your computer and use it in GitHub Desktop.
//
// Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
//
// Distributed under the Boost Software License, Version 1.0. (See
// accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)
//
// The authors gratefully acknowledge the support of
// Fraunhofer IOSB, Ettlingen, Germany
//
#ifndef BOOST_NUMERIC_UBLAS_TENSOR_STATIC_EXTENTS_HPP
#define BOOST_NUMERIC_UBLAS_TENSOR_STATIC_EXTENTS_HPP
#include <algorithm>
#include <array>
#include <boost/numeric/ublas/tensor/extents.hpp>
#include <boost/numeric/ublas/tensor/extents_helper.hpp>
#include <cstddef>
#include <initializer_list>
#include <type_traits>
#include <vector>
namespace boost::numeric::ublas {
/** @brief Template class for storing tensor extents at compile time.
 *
 * @code static_extents<4,1,2,3,4> t @endcode
 *
 * Rank, element access and products are forwarded to the base class
 * detail::basic_extents_impl; this class adds constructors, shape
 * predicates, conversions and comparisons on top of it.
 *
 * @tparam R rank of extents of type ptrdiff_t
 * @tparam E parameter pack of extents
 *
 */
template <ptrdiff_t R, ptrdiff_t... E>
struct static_extents
    : detail::basic_extents_impl<0, detail::make_basic_shape_t<R, E...>> {
  using base_type = ptrdiff_t;
  using value_type = base_type;
  using const_reference = base_type const &;
  using reference = base_type &;
  using size_type = base_type;
  // NOTE(review): this is a const pointer to a mutable base_type, not a
  // pointer to const base_type. Kept as-is so existing users of the alias
  // are not affected.
  using const_pointer = base_type *const;

  /** @brief Forward declaration of template class static_strides for making
   * it a friend class.
   *
   * @tparam Extents type of this class
   * @tparam Layout
   *
   */
  template <class Extents, class Layout> friend struct static_strides;

  //@returns the rank of static_extents
  static constexpr auto size() noexcept { return impl::Rank; }

  //@returns the rank of static_extents
  static constexpr auto rank() noexcept { return impl::Rank; }

  //@returns the dynamic rank of static_extents
  static constexpr auto dynamic_rank() noexcept { return impl::DynamicRank; }

  //@returns the dynamic rank of static_extents
  // Misspelled predecessor of dynamic_rank(); kept so existing callers
  // continue to compile.
  static constexpr auto dyanmic_rank() noexcept { return impl::DynamicRank; }

  /**
   * @param k pos of extent
   * @returns the element at given pos
   */
  constexpr auto at(size_type k) const noexcept { return impl::at(k); }

  /**
   * @param k pos of extent
   * @returns the number of elements a tensor holds with this from position k
   * onwards
   */
  constexpr auto product(size_type k) const noexcept {
    return impl::product(k);
  }

  /**
   * @returns the number of elements a tensor holds with this
   */
  constexpr auto product() const noexcept { return impl::product(); }

  // default constructor
  constexpr static_extents() = default;

  // default copy constructor
  constexpr static_extents(static_extents const &other) = default;

  // default copy assignment operator
  constexpr static_extents &operator=(static_extents const &other) = default;

  /** @brief assigns the extents to dynamic extents using parameter pack
   *
   * @code static_extents<2> e( 2,3 ); @endcode
   *
   * @tparam IndexType
   *
   * @param DynamicExtents parameter pack of extents
   *
   * @note number of extents should be equal to dynamic rank
   */
  template <class... IndexType>
  constexpr static_extents(IndexType... DynamicExtents)
      : impl(DynamicExtents...) {}

  /** @brief assigns the extents to dynamic extents using initializer_list
   *
   * @code static_extents<2> e = { 2, 3}; @endcode
   *
   * @tparam IndexType
   *
   * @param li of type initializer_list which contains the extents
   *
   * @note number of extents should be equal to dynamic rank
   */
  template <class IndexType>
  constexpr static_extents(std::initializer_list<IndexType> li)
      : impl(li.begin(), li.end(), detail::iterator_tag{}) {}

  /** @brief assigns the extents to dynamic extents using iterator
   *
   * @code static_extents<2> e( a.begin(), a.end() ); @endcode
   *
   * @tparam I of type input iterator and value_type should be integral
   *
   * @param begin start of iterator
   *
   * @param end end of iterator
   *
   * @note number of extents should be equal to dynamic rank
   *
   */
  template <class I>
  constexpr static_extents(I begin, I end)
      : impl(begin, end, detail::iterator_tag_t<I>{}) {}

  /** @brief Returns true if this has a scalar shape
   *
   * @returns true if (1,1,[1,...,1])
   */
  constexpr bool is_scalar() const noexcept {
    // A plain (non-constexpr) local is required here: to_array() reads *this,
    // which is not a constant expression inside a member function body.
    return size() != 0 && trailing_extents_are_one(to_array(), 0);
  }

  /** @brief Returns true if this has a vector shape
   *
   * @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1
   */
  constexpr bool is_vector() const noexcept {
    if (size() == 0) {
      return false;
    }
    if (size() == 1) {
      return at(0) > 1;
    }
    auto const extents = to_array();
    // Exactly one of the two leading extents must be > 1, the other == 1.
    bool const any_greater_one = extents[0] > 1 || extents[1] > 1;
    bool const any_equal_one = extents[0] == 1 || extents[1] == 1;
    return any_greater_one && any_equal_one &&
           trailing_extents_are_one(extents, 2);
  }

  /** @brief Returns true if this has a matrix shape
   *
   * @returns true if (m,n,[1,...,1]) with m > 1 and n > 1
   */
  constexpr bool is_matrix() const noexcept {
    if (size() < 2) {
      return false;
    }
    auto const extents = to_array();
    return extents[0] > 1 && extents[1] > 1 &&
           trailing_extents_are_one(extents, 2);
  }

  /** @brief Returns true if this has a tensor shape
   *
   * @returns true if the rank is at least 3 and at least one extent after
   * the first two is greater than 1
   */
  constexpr bool is_tensor() const noexcept {
    if (size() < 3) {
      return false;
    }
    auto const extents = to_array();
    for (std::size_t i = 2; i < extents.size(); ++i) {
      if (extents[i] > 1) {
        return true;
      }
    }
    return false;
  }

  /** @brief Returns the std::vector containing extents
   *
   * @note The vector allocates; because this function is noexcept, an
   * allocation failure terminates the program instead of throwing.
   */
  std::vector<base_type> to_vector() const noexcept {
    std::vector<base_type> extents(R);
    for (std::size_t i = 0; i < extents.size(); ++i) {
      extents[i] = at(static_cast<size_type>(i));
    }
    return extents;
  }

  /** @brief Returns the std::array containing extents */
  constexpr std::array<base_type, R> to_array() const noexcept {
    // Value-initialised: a constexpr function must not declare an
    // uninitialised variable in C++17.
    std::array<base_type, R> extents{};
    for (std::size_t i = 0; i < extents.size(); ++i) {
      extents[i] = at(static_cast<size_type>(i));
    }
    return extents;
  }

  /** @brief Checks if extents is empty or not
   *
   * @returns true if rank is 0 else false
   *
   */
  constexpr auto empty() const noexcept { return size() == 0; }

  /** @brief Returns true if size > 1 and all elements > 0 */
  constexpr bool valid() const noexcept {
    if (!(size() > 1)) {
      return false;
    }
    for (auto const extent : to_array()) {
      // Extents are signed here, so negative values must be rejected too,
      // not only zero.
      if (extent <= 0) {
        return false;
      }
    }
    return true;
  }

  /** @brief Returns true if both extents are equal else false */
  template <ptrdiff_t rhs_dims, ptrdiff_t... rhs>
  constexpr auto
  operator==(static_extents<rhs_dims, rhs...> const &other) const noexcept {
    if (size() != other.size()) {
      return false;
    }
    auto const count = static_cast<size_type>(size());
    for (size_type i = 0; i < count; ++i) {
      if (other.at(i) != at(i)) {
        return false;
      }
    }
    return true;
  }

  /** @brief Returns false if both extents are equal else true */
  template <ptrdiff_t rhs_dims, ptrdiff_t... rhs>
  constexpr auto
  operator!=(static_extents<rhs_dims, rhs...> const &other) const noexcept {
    return !(*this == other);
  }

  /** @brief Eliminates singleton dimensions when size > 2
   *
   * squeeze {  1,1} -> {  1,1}
   * squeeze {  2,1} -> {  2,1}
   * squeeze {  1,2} -> {  1,2}
   *
   * squeeze {1,2,3} -> {  2,3}
   * squeeze {2,1,3} -> {  2,3}
   * squeeze {1,3,1} -> {  1,3}
   *
   * The extents are copied into a basic_extents<size_t>, whose squeeze()
   * result is returned.
   *
   * @note This allocates; because the function is noexcept, an allocation
   * failure terminates the program instead of throwing.
   */
  auto squeeze() const noexcept {
    auto arr = to_vector();
    basic_extents<size_t> e{arr.cbegin(), arr.cend()};
    return e.squeeze();
  }

  ~static_extents() = default;

private:
  using impl =
      detail::basic_extents_impl<0, detail::make_basic_shape_t<R, E...>>;

  /** @brief Returns true if every extent at position >= first equals 1.
   *
   * Hand-written loop instead of std::all_of so that the shape predicates
   * remain usable in constant expressions under C++17.
   */
  static constexpr bool
  trailing_extents_are_one(std::array<base_type, R> const &extents,
                           std::size_t first) noexcept {
    for (std::size_t i = first; i < extents.size(); ++i) {
      if (extents[i] != 1) {
        return false;
      }
    }
    return true;
  }
};
namespace detail {
/** @brief Trait telling whether an extents type is dynamic or static
 *
 * Primary template: any type without a dedicated specialization is
 * reported as not dynamic.
 *
 * @tparam E of type basic_extents or static_extents
 *
 */
template <class E> struct is_dynamic_extents : std::false_type {};
/** @brief Partial specialization of is_dynamic_extents for basic_extents
 *
 * Every basic_extents<T> is reported as dynamic.
 *
 * @tparam T of any integer type
 *
 */
template <class T>
struct is_dynamic_extents<basic_extents<T>> : std::true_type {};
/** @brief Partial specialization of is_dynamic_extents for static_extents
 *
 * The value is true when the extents report a non-zero dynamic rank and
 * is_dynamic_basic_shape_v holds for the corresponding basic shape.
 *
 * Derives from std::integral_constant, like the primary template and the
 * basic_extents specialization, so that the nested `type`, `value_type` and
 * the conversion to bool are available for every specialization.
 *
 * @tparam R rank of ptrdiff_t
 *
 * @tparam E parameter pack of extents
 *
 */
template <ptrdiff_t R, ptrdiff_t... E>
struct is_dynamic_extents<static_extents<R, E...>>
    : std::integral_constant<
          bool, static_extents<R, E...>::dyanmic_rank() != 0 &&
                    is_dynamic_basic_shape_v<make_basic_shape_t<R, E...>>> {};
} // namespace detail
namespace framework {
/** @brief type alias of basic_extents or static_extents depending on Rank
*
* @tparam R rank of extents
*
* @tparam E contains the extents as a parameter pack
*
*/
template <ptrdiff_t R, ptrdiff_t... E>
using shape_t =
std::conditional_t<(R < 0), basic_extents<size_t>, static_extents<R, E...>>;
} // namespace framework
} // namespace boost::numeric::ublas
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment