Created
July 24, 2019 12:36
-
-
Save coder3101/a35b630b90c7843b5f0a3eed4d29fb24 to your computer and use it in GitHub Desktop.
Distributive optimization at tensor-level
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
struct apply_distributive_law { | |
constexpr apply_distributive_law() = default; | |
template <class Expr1, class Expr2> | |
constexpr decltype(auto) | |
operator()(boost::yap::expr_tag<boost::yap::expr_kind::plus>, Expr1 &&e1, | |
Expr2 &&e2) { | |
// We check if the left and right operands of + is a multiply expression. | |
// This returns true if expression is of form of multiply of ublas operands. | |
if constexpr (is_multiply_operand<std::remove_reference_t<Expr1>>::value && | |
is_multiply_operand<std::remove_reference_t<Expr2>>::value) { | |
auto &operand1 = boost::yap::value(boost::yap::left(e1)); | |
auto &operand2 = boost::yap::value(boost::yap::right(e1)); | |
auto &operand3 = boost::yap::value(boost::yap::left(e2)); | |
auto &operand4 = boost::yap::value(boost::yap::right(e2)); | |
using Op_1_t = std::remove_reference_t<decltype(operand1)>; | |
using Op_2_t = std::remove_reference_t<decltype(operand2)>; | |
using Op_3_t = std::remove_reference_t<decltype(operand3)>; | |
using Op_4_t = std::remove_reference_t<decltype(operand4)>; | |
// If they are of pattern A*B + C*D, We extract take references to operands and their types. | |
// We then check if all have same type, it is important. To return a valid uniform data-type we | |
// must make sure that all operands hold same value_type. | |
if constexpr (std::is_same_v<Op_1_t, Op_2_t> && | |
std::is_same_v<Op_2_t, Op_3_t> && | |
std::is_same_v<Op_3_t, Op_4_t>) { | |
// We now build check which operands are matching by address comparision. | |
bool eq_1_3 = std::addressof(operand1) == std::addressof(operand3); | |
bool eq_1_4 = std::addressof(operand1) == std::addressof(operand4); | |
bool eq_2_3 = std::addressof(operand2) == std::addressof(operand3); | |
bool eq_2_4 = std::addressof(operand2) == std::addressof(operand4); | |
// We set a flag usable = true, if we were able to optimize the expression. | |
// Optimization means at any time if match and we can indeed change expression. | |
usable = eq_1_3 || eq_1_4 || eq_2_3 || eq_2_4; | |
// We now build appropriate expression for all three case and return them. | |
if (eq_1_3) { | |
auto inner_op = boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::plus>(operand2, operand4); | |
return boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::multiplies>(operand1, std::move(inner_op)); | |
} | |
if (eq_1_4) { | |
auto inner_op = boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::plus>(operand2, operand3); | |
return boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::multiplies>(operand1, std::move(inner_op)); | |
} | |
if (eq_2_3) { | |
auto inner_op = boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::plus>(operand1, operand4); | |
return boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::multiplies>(operand2, std::move(inner_op)); | |
} | |
if (eq_2_4) { | |
auto inner_op = boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::plus>(operand1, operand3); | |
return boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::multiplies>(operand2, std::move(inner_op)); | |
} | |
// If none matched, we still return a expression that is of same type as optimized expression | |
// but in this case, usable will be false. | |
// This is just to suppress the warning of non-void return end. | |
// If this is returned the usable flag will be false and user should not | |
// use this returned expression. | |
auto inner_op = boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::plus>(operand2, operand4); | |
return boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::multiplies>(operand1, std::move(inner_op)); | |
} else { | |
// If all four operands do not have same value_type, we return the same expression as we got. | |
// Setting usable true, means that user can use this returned expression. | |
usable = true; | |
return boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::plus>(std::forward<Expr1>(e1), | |
std::forward<Expr2>(e2)); | |
} | |
} else { | |
// If expression is not of pattern A*B+C*D, we recurse to the left and right and if both sides are usable flag is set, | |
// We return the new expression. | |
#ifndef BOOST_UBLAS_NO_RECURSIVE_OPTIMIZATION | |
std::remove_reference_t<decltype(*this)> a, b; | |
auto xa = boost::yap::transform( | |
boost::yap::as_expr<boost::numeric::ublas::detail::tensor_expression>( | |
std::forward<Expr1>(e1)), | |
a); | |
auto xb = boost::yap::transform( | |
boost::yap::as_expr<boost::numeric::ublas::detail::tensor_expression>( | |
std::forward<Expr2>(e2)), | |
b); | |
usable = a.usable && b.usable; | |
return boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::plus>(xa, xb); | |
#else | |
usable = true; | |
return boost::yap::make_expression< | |
boost::numeric::ublas::detail::tensor_expression, | |
boost::yap::expr_kind::plus>(std::forward<Expr1>(e1), | |
std::forward<Expr2>(e2)); | |
#endif | |
} | |
} | |
bool usable = false; | |
}; | |
} | |
// For demonstration purpose only. | |
int main(){ | |
tensor<int> A{}; | |
auto B = A; | |
auto C = A; | |
auto expr = A * B + A * C; | |
auto xform = apply_distributive_law{}; | |
auto new_expr = boost::yap::transform(expr, xform); | |
if(xform.usable) | |
// optimization was a success and the new expr is usable. | |
else | |
// Use, expr optimization has failed. Since operands are not moved, the original expression is still usable. | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment