Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save amitsingh19975/244b98c3ff174dc4d6a547e6609716fc to your computer and use it in GitHub Desktop.
Save amitsingh19975/244b98c3ff174dc4d6a547e6609716fc to your computer and use it in GitHub Desktop.
/**
 * @brief Tensor-times-tensor product for operands with static extents; the
 *        modes to contract are selected at compile time.
 *
 * The entries of Ss1 name modes of @p a, the entries of Ss2 name the modes of
 * @p b they are paired with. Mode numbers are 1-based: they are matched against
 * the lists 1..pa / 1..pb and are decremented before being used as extent
 * indices. Every precondition is a static_assert and the extents of the result
 * are computed purely on the type level. The data pointers, extents and strides
 * of the three tensors are finally handed to ttt(), which is expected to carry
 * out the contraction.
 *
 * @tparam q             number of contracted mode pairs
 * @tparam TensorEngine1 engine of the left-hand side tensor
 * @tparam TensorEngine2 engine of the right-hand side tensor
 * @tparam Ss1           1-based modes of @p a that are contracted (q entries)
 * @tparam Ss2           1-based modes of @p b that are contracted (q entries)
 *
 * @param a left-hand side tensor (static extents, order pa > q)
 * @param b right-hand side tensor (static extents, order pb > q)
 *
 * @return a tensor_core whose static extents are the r = pa - q remaining
 *         extents of @p a followed by the s = pb - q remaining extents of
 *         @p b, padded with 1 up to at least two entries. Value type, storage
 *         (rebound to the new extents) and layout family are taken from @p a.
 */
template <std::size_t q, typename TensorEngine1, typename TensorEngine2, std::size_t... Ss1, std::size_t... Ss2>
inline decltype(auto) prod(tensor_core< TensorEngine1 > const &a, tensor_core< TensorEngine2 > const &b,
                           std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>)
{
    using tensor_type    = tensor_core< TensorEngine1 >;
    using extents_type_1 = typename tensor_type::extents_type;
    using extents_type_2 = typename tensor_core< TensorEngine2 >::extents_type;
    using value_type     = typename tensor_type::value_type;
    using layout_type    = typename tensor_type::layout_type;
    using size_type      = typename extents_type_1::size_type;
    using array_type     = typename tensor_type::array_type;

    // The contraction modes as type lists of integral constants.
    using phia = boost::mp11::mp_list_c<std::size_t, Ss1...>;
    using phib = boost::mp11::mp_list_c<std::size_t, Ss2...>;

    // Disabled storage check, kept for reference:
    // static_assert(
    // std::is_same_v<
    // typename tensor_core<TensorEngine1>::resizable_tag,
    // typename tensor_core<TensorEngine2>::resizable_tag
    // > &&
    // std::is_same_v<
    // typename tensor_core<TensorEngine1>::resizable_tag,
    // storage_static_container_tag
    // >,
    // "error in boost::numeric::ublas::prod(tensor_core const&, tensor_core const&, "
    // "std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>: "
    // "Both tensor storage should have the same type of storage and both should be static storage"
    // );

    static_assert(
        is_static< typename tensor_core< TensorEngine1 >::extents_type >::value &&
        is_static< typename tensor_core< TensorEngine2 >::extents_type >::value,
        "error in boost::numeric::ublas::prod(tensor_core const&, tensor_core const&, "
        "std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>: "
        "Both tensor should have static extents"
    );

    constexpr auto const pa = extents_type_1::size_;   // order of a
    constexpr auto const pb = extents_type_2::size_;   // order of b
    constexpr auto const sz_phia = sizeof...(Ss1);
    constexpr auto const sz_phib = sizeof...(Ss2);

    static_assert(
        pa != 0ul,
        "error in ublas::prod: order of left-hand side tensor must be greater than 0."
    );
    static_assert(
        pb != 0ul,
        "error in ublas::prod: order of right-hand side tensor must be greater than 0."
    );

    // NOTE(review): the comparisons below are strict (order > q) although the
    // messages read "cannot be greater than"; confirm whether a full
    // contraction (q == order) is meant to be accepted.
    static_assert(
        pa > q,
        "error in ublas::prod: number of contraction dimensions cannot be greater than the order of the left-hand side tensor."
    );
    static_assert(
        pb > q,
        "error in ublas::prod: number of contraction dimensions cannot be greater than the order of the right-hand side tensor."
    );

    // Both tuples must pair up one-to-one. Previously only Ss1 was compared
    // with q, so a longer Ss2 slipped through and produced wrong extents.
    static_assert(
        sz_phia == sz_phib,
        "error in ublas::prod: permutation tuples must have the same length."
    );
    static_assert(
        q == sz_phia,
        "error in ublas::prod: number of contraction dimensions must be equal to the length of the permutation tuples."
    );
    static_assert(
        pa > sz_phia,
        "error in ublas::prod: permutation tuple for the left-hand side tensor cannot be greater than the corresponding order."
    );
    static_assert(
        pb > sz_phib,
        "error in ublas::prod: permutation tuple for the right-hand side tensor cannot be greater than the corresponding order."
    );

    auto const& na = a.extents();
    auto const& nb = b.extents();

    // Each pair of contracted modes must have matching extents. The modes are
    // 1-based, hence the "- 1" when they are turned into extent indices.
    detail::static_for<q>([&](auto iter){
        using namespace boost::mp11;
        using lext = std::decay_t<decltype(na)>;
        using rext = std::decay_t<decltype(nb)>;
        using iter_type = decltype(iter);
        using lph = phia;
        using rph = phib;
        static_assert(
            lext::at( mp_at<lph,iter_type>::value - 1 ) == rext::at( mp_at<rph,iter_type>::value - 1 ),
            "error in ublas::prod: permutations of the extents are not correct."
        );
    });

    constexpr auto const r = pa - q;   // modes of a that survive the contraction
    constexpr auto const s = pb - q;   // modes of b that survive the contraction

    using one_type    = boost::mp11::mp_list_c<std::size_t, 1>;
    using phia1_type  = boost::mp11::mp_pop_front< boost::mp11::mp_iota_c<pa + 1> >;   // 1, 2, ..., pa
    using phib1_type  = boost::mp11::mp_pop_front< boost::mp11::mp_iota_c<pb + 1> >;   // 1, 2, ..., pb
    // Result extents start as all ones; at least two entries are kept.
    using old_nc_type = boost::mp11::mp_repeat_c< one_type ,std::max(r + s, size_type(2)) >;

    // Starting from 1..pa, take every contracted mode out of the list and put
    // it at the back, in the order given by Ss1. If removing a mode leaves the
    // list identical to the initial 1..pa nothing is appended.
    auto phia1 = detail::static_for< sz_phia >([&](auto iter, auto prev){
        using prev_type = std::decay_t< decltype(prev) >;
        using iter_type = decltype(iter);
        using phia_at_type = boost::mp11::mp_at<phia,iter_type>;
        using temp_phia1 = boost::mp11::mp_remove<prev_type,phia_at_type>;
        if constexpr( std::is_same_v<temp_phia1, phia1_type> ){
            return temp_phia1{};
        }else{
            return boost::mp11::mp_append<
                temp_phia1,
                boost::mp11::mp_list<phia_at_type>
            >{};
        }
    }, phia1_type{});

    using transformed_phia = std::decay_t< decltype(phia1) >;

    static_assert( boost::mp11::mp_size<transformed_phia>::value == pa,
        "error in boost::numeric::ublas::prod(tensor_core const&, tensor_core const&, "
        "std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>: "
        "phia after transform should be equal to the extents of lhs tensor"
    );

    // Same reordering for the right-hand side.
    auto phib1 = detail::static_for< sz_phib >([&](auto iter, auto prev){
        using prev_type = std::decay_t< decltype(prev) >;
        using iter_type = decltype(iter);
        using phib_at_type = boost::mp11::mp_at<phib,iter_type>;
        using temp_phib1 = boost::mp11::mp_remove<prev_type,phib_at_type>;
        if constexpr( std::is_same_v<temp_phib1, phib1_type> ){
            return temp_phib1{};
        }else{
            return boost::mp11::mp_append<
                temp_phib1,
                boost::mp11::mp_list<phib_at_type>
            >{};
        }
    }, phib1_type{});

    using transformed_phib = std::decay_t< decltype(phib1) >;

    static_assert( boost::mp11::mp_size<transformed_phib>::value == pb,
        "error in boost::numeric::ublas::prod(tensor_core const&, tensor_core const&, "
        "std::integer_sequence<std::size_t, Ss1...>, std::integer_sequence<std::size_t, Ss2...>: "
        "phib after transform should be equal to the extents of rhs tensor"
    );

    // Entries [0, r) of the result extents: extents of a at its first r
    // reordered modes. mp_replace_at stores its third argument verbatim, so the
    // replacement has to be the integral constant itself; wrapping it in a list
    // would nest lists and seq_to_static_extents could no longer match.
    auto nc_part1 = detail::static_for<r>([&](auto iter, auto prev){
        using prev_type = std::decay_t< decltype(prev) >;
        using iter_type = decltype(iter);
        constexpr auto const phia_at = boost::mp11::mp_at<transformed_phia,iter_type>::value - 1;
        constexpr auto const na_at = extents_type_1::at(phia_at);
        using new_value = std::integral_constant<std::size_t,na_at>;
        using temp_nc = boost::mp11::mp_replace_at< prev_type, iter_type, new_value >;
        return temp_nc{};
    }, old_nc_type{});

    // Entries [r, r + s): extents of b at its first s reordered modes.
    auto nc = detail::static_for<s>([&](auto iter, auto prev){
        using prev_type = std::decay_t< decltype(prev) >;
        using iter_type = decltype(iter);
        constexpr auto const phib_at = boost::mp11::mp_at<transformed_phib,iter_type>::value - 1;
        constexpr auto const nb_at = extents_type_2::at(phib_at);
        using new_value = std::integral_constant<std::size_t,nb_at>;
        using temp_nc = boost::mp11::mp_replace_at_c< prev_type, iter_type::value + r, new_value >;
        return temp_nc{};
    }, std::decay_t< decltype(nc_part1) > {} );

    using c_extents_type = detail::seq_to_static_extents_t< std::decay_t<decltype(nc)> >;

    // The result keeps a's layout family and storage kind, rebound to the new extents.
    using t_engine = tensor_engine<
        c_extents_type,
        std::conditional_t<
            std::is_same_v< layout_type, first_order >,
            layout::first_order<c_extents_type>,
            layout::last_order<c_extents_type>
        >,
        rebind_storage_t<c_extents_type,array_type,value_type>
    >;

    auto c = tensor_core<t_engine>( c_extents_type{}, value_type{} );

    // The reordered mode lists are turned into static-extents objects only to
    // obtain them as arrays (via data()) for ttt().
    auto new_phia = detail::seq_to_static_extents_t< transformed_phia >{};
    auto new_phib = detail::seq_to_static_extents_t< transformed_phib >{};

    ttt(pa, pb, q,
        new_phia.data(), new_phib.data(),
        c.data(), c.extents().data(), c.strides().data(),
        a.data(), a.extents().data(), a.strides().data(),
        b.data(), b.extents().data(), b.strides().data());

    return c;
}
}
//
// Copyright (c) 2018-2020, Cem Bassoy, cem.bassoy@gmail.com
// Copyright (c) 2019-2020, Amit Singh, amitsingh19975@gmail.com
//
// Distributed under the Boost Software License, Version 1.0. (See
// accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)
//
// The authors gratefully acknowledge the support of
// Google
//
#ifndef _BOOST_NUMERIC_UBLAS_DETAIL_UTILITY_HPP_
#define _BOOST_NUMERIC_UBLAS_DETAIL_UTILITY_HPP_
#include <type_traits>
#include <utility>
#include <boost/mp11/detail/mp_list.hpp>
#include <boost/numeric/ublas/tensor/traits/type_traits_extents.hpp>
namespace boost::numeric::ublas::detail {
/**
 * @brief Recursive building block of static_for: visits the indices
 *        I, I + 1, ..., MaxIter - 1, each passed to the callable as a
 *        std::integral_constant<std::size_t, index>.
 *
 * The struct itself is an integral constant carrying MaxIter.
 *
 * @tparam MaxIter one past the last index that is visited
 * @tparam I       index handled by this instantiation
 */
template<std::size_t MaxIter,std::size_t I = 0>
struct static_for_impl
    : std::integral_constant<std::size_t, MaxIter>
{
    /**
     * @brief Calls fn(index) for every remaining index; returns nothing.
     * @param fn callable accepting a std::integral_constant index
     */
    template<typename UnaryFn>
    constexpr auto operator()(UnaryFn fn) const{
        if constexpr( I < MaxIter ){
            using next_step = static_for_impl<MaxIter, I + 1>;
            std::integral_constant<std::size_t, I> index{};
            fn(index);
            // the callable is handed on to the instantiation for the next index
            next_step{}(std::move(fn));
        }
    }

    /**
     * @brief Left fold: ret = fn(index, ret) for every remaining index.
     * @param fn  callable accepting (index, accumulated value)
     * @param ret accumulated value so far
     * @return the value accumulated after the last index, or ret unchanged
     *         when no index is left
     */
    template<typename BinaryFn, typename T>
    constexpr auto operator()(BinaryFn fn, T ret) const{
        if constexpr( I < MaxIter ){
            using next_step = static_for_impl<MaxIter, I + 1>;
            std::integral_constant<std::size_t, I> index{};
            // evaluated before fn is moved on to the next step
            auto folded = fn(index, ret);
            return next_step{}(std::move(fn), std::move(folded));
        }else{
            return ret;
        }
    }
};
template<std::size_t MaxIter, typename UnaryFn>
constexpr auto static_for(UnaryFn fn){
static_for_impl<MaxIter,0ul>{}(std::move(fn));
}
template<std::size_t MaxIter, typename T, typename BinaryFn>
constexpr auto static_for(BinaryFn fn, T ret){
return std::decay_t< decltype( static_for_impl<MaxIter,0ul>{}(std::move(fn), std::move(ret)) ) > {};
}
/// Maps a compile-time sequence of integers to the basic_static_extents type
/// holding the same values. The primary template is only declared; just the
/// two sequence forms specialised below are supported.
template<typename Sequence>
struct seq_to_static_extents;

/// Form 1: a boost::mp11::mp_list made of std::integral_constant<ValueType, ...>.
template<typename ValueType, ValueType... Extents>
struct seq_to_static_extents<
    boost::mp11::mp_list< std::integral_constant<ValueType, Extents>... >
>
{
    using type = basic_static_extents<ValueType, Extents...>;
};

/// Form 2: a std::integer_sequence<ValueType, ...>.
template<typename ValueType, ValueType... Extents>
struct seq_to_static_extents<
    std::integer_sequence<ValueType, Extents...>
>
{
    using type = basic_static_extents<ValueType, Extents...>;
};

/// Shorthand for typename seq_to_static_extents<Sequence>::type.
template<typename Sequence>
using seq_to_static_extents_t = typename seq_to_static_extents<Sequence>::type;
} // namespace boost::numeric::ublas::detail
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment