Created
June 15, 2020 09:36
-
-
Save amitsingh19975/244b98c3ff174dc4d6a547e6609716fc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
template <std::size_t q, typename TensorEngine1, typename TensorEngine2, std::size_t... Ss1, std::size_t... Ss2> | |
inline decltype(auto) prod(tensor_core< TensorEngine1 > const &a, tensor_core< TensorEngine2 > const &b, | |
std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>) | |
{ | |
using tensor_type = tensor_core< TensorEngine1 >; | |
using extents_type_1 = typename tensor_type::extents_type; | |
using extents_type_2 = typename tensor_core< TensorEngine2 >::extents_type; | |
using value_type = typename tensor_type::value_type; | |
using layout_type = typename tensor_type::layout_type; | |
using size_type = typename extents_type_1::size_type; | |
using array_type = typename tensor_type::array_type; | |
using phia = boost::mp11::mp_list_c<std::size_t, Ss1...>; | |
using phib = boost::mp11::mp_list_c<std::size_t, Ss2...>; | |
// static_assert( | |
// std::is_same_v< | |
// typename tensor_core<TensorEngine1>::resizable_tag, | |
// typename tensor_core<TensorEngine2>::resizable_tag | |
// > && | |
// std::is_same_v< | |
// typename tensor_core<TensorEngine1>::resizable_tag, | |
// storage_static_container_tag | |
// >, | |
// "error in boost::numeric::ublas::prod(tensor_core const&, tensor_core const&, " | |
// "std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>: " | |
// "Both tensor storage should have the same type of storage and both should be static storage" | |
// ); | |
static_assert( | |
is_static< typename tensor_core< TensorEngine1 >::extents_type >::value && | |
is_static< typename tensor_core< TensorEngine2 >::extents_type >::value, | |
"error in boost::numeric::ublas::prod(tensor_core const&, tensor_core const&, " | |
"std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>: " | |
"Both tensor should have static extents" | |
); | |
constexpr auto const pa = extents_type_1::size_; | |
constexpr auto const pb = extents_type_2::size_; | |
constexpr auto const sz_phia = sizeof...(Ss1); | |
constexpr auto const sz_phib = sizeof...(Ss2); | |
static_assert( | |
pa != 0ul, | |
"error in ublas::prod: order of left-hand side tensor must be greater than 0." | |
); | |
static_assert( | |
pb != 0ul, | |
"error in ublas::prod: order of right-hand side tensor must be greater than 0." | |
); | |
static_assert( | |
pa > q, | |
"error in ublas::prod: number of contraction dimensions cannot be greater than the order of the left-hand side tensor." | |
); | |
static_assert( | |
pb > q, | |
"error in ublas::prod: number of contraction dimensions cannot be greater than the order of the right-hand side tensor." | |
); | |
static_assert( | |
q == sz_phia, | |
"error in ublas::prod: permutation tuples must have the same length." | |
); | |
static_assert( | |
pa > sz_phia, | |
"error in ublas::prod: permutation tuple for the left-hand side tensor cannot be greater than the corresponding order." | |
); | |
static_assert( | |
pb > sz_phib, | |
"error in ublas::prod: permutation tuple for the right-hand side tensor cannot be greater than the corresponding order." | |
); | |
auto const& na = a.extents(); | |
auto const& nb = b.extents(); | |
detail::static_for<q>([&](auto iter){ | |
using namespace boost::mp11; | |
using lext = std::decay_t<decltype(na)>; | |
using rext = std::decay_t<decltype(nb)>; | |
using iter_type = decltype(iter); | |
using lph = phia; | |
using rph = phib; | |
static_assert( | |
lext::at( mp_at<lph,iter_type>::value ) == rext::at( mp_at<rph,iter_type>::value ), | |
"error in ublas::prod: permutations of the extents are not correct." | |
); | |
}); | |
constexpr auto const r = pa - q; | |
constexpr auto const s = pb - q; | |
using one_type = boost::mp11::mp_list_c<std::size_t, 1>; | |
using phia1_type = boost::mp11::mp_pop_front< boost::mp11::mp_iota_c<pa + 1> >; | |
using phib1_type = boost::mp11::mp_pop_front< boost::mp11::mp_iota_c<pb + 1> >; | |
using old_nc_type = boost::mp11::mp_repeat_c< one_type ,std::max(r + s, size_type(2)) >; | |
auto phia1 = detail::static_for< sz_phia >([&](auto iter, auto prev){ | |
using prev_type = std::decay_t< decltype(prev) >; | |
using iter_type = decltype(iter); | |
using phia_at_type = boost::mp11::mp_at<phia,iter_type>; | |
using temp_phia1 = boost::mp11::mp_remove<prev_type,phia_at_type>; | |
if constexpr( std::is_same_v<temp_phia1, phia1_type> ){ | |
return temp_phia1{}; | |
}else{ | |
return boost::mp11::mp_append< | |
temp_phia1, | |
boost::mp11::mp_list<phia_at_type> | |
>{}; | |
} | |
}, phia1_type{}); | |
using transformed_phia = std::decay_t< decltype(phia1) >; | |
static_assert( boost::mp11::mp_size<transformed_phia>::value == pa, | |
"error in boost::numeric::ublas::prod(tensor_core const&, tensor_core const&, " | |
"std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>: " | |
"phia after transform should be equal to the extents of lhs tensor" | |
); | |
auto phib1 = detail::static_for< sz_phib >([&](auto iter, auto prev){ | |
using prev_type = std::decay_t< decltype(prev) >; | |
using iter_type = decltype(iter); | |
using phib_at_type = boost::mp11::mp_at<phib,iter_type>; | |
using temp_phib1 = boost::mp11::mp_remove<prev_type,phib_at_type>; | |
if constexpr( std::is_same_v<temp_phib1, phib1_type> ){ | |
return temp_phib1{}; | |
}else{ | |
return boost::mp11::mp_append< | |
temp_phib1, | |
boost::mp11::mp_list<phib_at_type> | |
>{}; | |
} | |
}, phib1_type{}); | |
using transformed_phib = std::decay_t< decltype(phib1) >; | |
static_assert( boost::mp11::mp_size<transformed_phib>::value == pb, | |
"error in boost::numeric::ublas::prod(tensor_core const&, tensor_core const&, " | |
"std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>: " | |
"phib after transform should be equal to the extents of rhs tensor" | |
); | |
auto nc_part1 = detail::static_for<r>([&](auto iter, auto prev){ | |
using prev_type = std::decay_t< decltype(prev) >; | |
using iter_type = decltype(iter); | |
constexpr auto const phia_at = boost::mp11::mp_at<transformed_phia,iter_type>::value - 1; | |
constexpr auto const na_at = extents_type_1::at(phia_at); | |
using new_value = boost::mp11::mp_list_c<std::size_t,na_at>; | |
using temp_nc = boost::mp11::mp_replace_at< prev_type, iter_type, new_value >; | |
return temp_nc{}; | |
}, old_nc_type{}); | |
auto nc = detail::static_for<s>([&](auto iter, auto prev){ | |
using prev_type = std::decay_t< decltype(prev) >; | |
using iter_type = decltype(iter); | |
constexpr auto const phib_at = boost::mp11::mp_at<transformed_phib,iter_type>::value - 1; | |
constexpr auto const nb_at = extents_type_2::at(phib_at); | |
using new_value = boost::mp11::mp_list_c<std::size_t,nb_at>; | |
using temp_nc = boost::mp11::mp_replace_at_c< prev_type, iter_type::value + r, new_value >; | |
return temp_nc{}; | |
}, std::decay_t< decltype(nc_part1) > {} ); | |
using c_extents_type = detail::seq_to_static_extents_t< std::decay_t<decltype(nc)> >; | |
using t_engine = tensor_engine< | |
c_extents_type, | |
std::conditional_t< | |
std::is_same_v< layout_type, first_order >, | |
layout::first_order<c_extents_type>, | |
layout::last_order<c_extents_type> | |
>, | |
rebind_storage_t<c_extents_type,array_type,value_type> | |
>; | |
auto c = tensor_core<t_engine>( c_extents_type{}, value_type{} ); | |
auto new_phia = detail::seq_to_static_extents_t< transformed_phia >{}; | |
auto new_phib = detail::seq_to_static_extents_t< transformed_phib >{}; | |
ttt(pa, pb, q, | |
new_phia.data(), new_phib.data(), | |
c.data(), c.extents().data(), c.strides().data(), | |
a.data(), a.extents().data(), a.strides().data(), | |
b.data(), b.extents().data(), b.strides().data()); | |
return c; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// Copyright (c) 2018-2020, Cem Bassoy, cem.bassoy@gmail.com | |
// Copyright (c) 2019-2020, Amit Singh, amitsingh19975@gmail.com | |
// | |
// Distributed under the Boost Software License, Version 1.0. (See | |
// accompanying file LICENSE_1_0.txt or copy at | |
// http://www.boost.org/LICENSE_1_0.txt) | |
// | |
// The authors gratefully acknowledge the support of | |
// | |
#ifndef _BOOST_NUMERIC_UBLAS_DETAIL_UTILITY_HPP_ | |
#define _BOOST_NUMERIC_UBLAS_DETAIL_UTILITY_HPP_ | |
#include <type_traits> | |
#include <utility> | |
#include <boost/mp11/detail/mp_list.hpp> | |
#include <boost/numeric/ublas/tensor/traits/type_traits_extents.hpp> | |
namespace boost::numeric::ublas::detail { | |
/// @brief Recursive driver behind static_for; handles the indices I, I+1, ..., MaxIter-1.
///
/// Each index is handed to the callable as std::integral_constant<std::size_t, I>, so the
/// callable can use it as a compile-time constant. The base class exposes MaxIter as `value`.
template<std::size_t MaxIter,std::size_t I = 0>
struct static_for_impl
    : std::integral_constant<std::size_t, MaxIter>
{
    /// Calls fn(integral_constant<size_t, I>) for every remaining index, in increasing order.
    template<typename UnaryFn>
    constexpr auto operator()(UnaryFn fn) const{
        if constexpr( MaxIter <= I ){
            return;
        }else{
            std::integral_constant<std::size_t, I> info{};
            fn(info);
            static_for_impl<MaxIter,I + 1>{}(std::move(fn));
        }
    }

    /// Left fold over the remaining indices: ret = fn(integral_constant<size_t, I>, ret).
    /// The accumulator's type may change at every step; the final accumulator is returned.
    template<typename BinaryFn, typename T>
    constexpr auto operator()(BinaryFn fn, T ret) const{
        if constexpr( MaxIter <= I ){
            return ret;
        }else{
            std::integral_constant<std::size_t, I> info{};
            auto n_ret = fn(info,ret);
            return static_for_impl<MaxIter,I + 1>{}(std::move(fn),std::move(n_ret));
        }
    }
};

/// @brief Compile-time loop: calls fn(std::integral_constant<std::size_t, i>{}) for
///        i = 0, 1, ..., MaxIter-1.
template<std::size_t MaxIter, typename UnaryFn>
constexpr auto static_for(UnaryFn fn){
    static_for_impl<MaxIter,0ul>{}(std::move(fn));
}

/// @brief Compile-time left fold: starting from @p ret, applies
///        ret = fn(std::integral_constant<std::size_t, i>{}, ret) for i = 0, ..., MaxIter-1
///        and returns the final value (@p ret itself when MaxIter == 0).
template<std::size_t MaxIter, typename T, typename BinaryFn>
constexpr auto static_for(BinaryFn fn, T ret){
    // Evaluate the fold rather than only naming its decltype: otherwise fn never runs
    // and a default-constructed value of the result type is returned instead of the result.
    return static_for_impl<MaxIter,0ul>{}(std::move(fn), std::move(ret));
}
template<typename T> | |
struct seq_to_static_extents; | |
template<typename T, T... Ns> | |
struct seq_to_static_extents<boost::mp11::mp_list< std::integral_constant<T,Ns>... >> | |
{ | |
using type = basic_static_extents<T,Ns...>; | |
}; | |
template<typename T, T... Ns> | |
struct seq_to_static_extents< std::integer_sequence< T, Ns... > > | |
{ | |
using type = basic_static_extents<T,Ns...>; | |
}; | |
template<typename T> | |
using seq_to_static_extents_t = typename seq_to_static_extents<T>::type; | |
} // namespace boost::numeric::ublas::detail | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment