Skip to content

Instantly share code, notes, and snippets.

@coder3101
Created July 24, 2019 12:36
Show Gist options
  • Save coder3101/a35b630b90c7843b5f0a3eed4d29fb24 to your computer and use it in GitHub Desktop.
Save coder3101/a35b630b90c7843b5f0a3eed4d29fb24 to your computer and use it in GitHub Desktop.
Distributive optimization at tensor-level
/// Yap transform that factors a sum of two products sharing an operand:
/// `A*B + A*C` -> `A*(B+C)` (and the three other placements of the shared
/// operand). Operands are considered shared only when they are the very same
/// object (address comparison).
///
/// Usage: `auto r = boost::yap::transform(expr, xform);` then check
/// `xform.usable`. When it is false, `r` must be discarded; operands of the
/// original expression are not moved from, so `expr` itself stays valid.
struct apply_distributive_law {
  constexpr apply_distributive_law() = default;

  /// Invoked by Yap for every `plus` node it visits.
  ///
  /// @param e1 left operand of the `+`
  /// @param e2 right operand of the `+`
  /// @return the factored expression, or a rebuilt `plus` expression; whether
  ///         the result may be evaluated is folded into `usable`.
  template <class Expr1, class Expr2>
  constexpr decltype(auto)
  operator()(boost::yap::expr_tag<boost::yap::expr_kind::plus>, Expr1 &&e1,
             Expr2 &&e2) {
    // Pattern A*B + C*D: both operands of + are products of ublas operands.
    if constexpr (is_multiply_operand<std::remove_reference_t<Expr1>>::value &&
                  is_multiply_operand<std::remove_reference_t<Expr2>>::value) {
      auto &operand1 = boost::yap::value(boost::yap::left(e1));
      auto &operand2 = boost::yap::value(boost::yap::right(e1));
      auto &operand3 = boost::yap::value(boost::yap::left(e2));
      auto &operand4 = boost::yap::value(boost::yap::right(e2));
      using Op_1_t = std::remove_reference_t<decltype(operand1)>;
      using Op_2_t = std::remove_reference_t<decltype(operand2)>;
      using Op_3_t = std::remove_reference_t<decltype(operand3)>;
      using Op_4_t = std::remove_reference_t<decltype(operand4)>;
      // All four operands must have the same type: every runtime branch below
      // must return one uniform expression type (and value_type).
      if constexpr (std::is_same_v<Op_1_t, Op_2_t> &&
                    std::is_same_v<Op_2_t, Op_3_t> &&
                    std::is_same_v<Op_3_t, Op_4_t>) {
        // Operands match only if they are the same object (address comparison).
        const bool eq_1_3 = std::addressof(operand1) == std::addressof(operand3);
        const bool eq_1_4 = std::addressof(operand1) == std::addressof(operand4);
        const bool eq_2_3 = std::addressof(operand2) == std::addressof(operand3);
        const bool eq_2_4 = std::addressof(operand2) == std::addressof(operand4);
        // The rewrite is valid only if some operand is shared.
        record(eq_1_3 || eq_1_4 || eq_2_3 || eq_2_4);
        if (eq_1_3) {
          return factor_out(operand1, operand2, operand4);
        }
        if (eq_1_4) {
          return factor_out(operand1, operand2, operand3);
        }
        if (eq_2_3) {
          return factor_out(operand2, operand1, operand4);
        }
        if (eq_2_4) {
          return factor_out(operand2, operand1, operand3);
        }
        // Nothing matched. The return type must still agree with the branches
        // above, so build a placeholder of that type. It is NOT equivalent to
        // the input; record(false) above marks it as unusable.
        return factor_out(operand1, operand2, operand4);
      } else {
        // Operand types differ: return the input expression unchanged, which
        // is always safe to evaluate.
        record(true);
        return boost::yap::make_expression<
            boost::numeric::ublas::detail::tensor_expression,
            boost::yap::expr_kind::plus>(std::forward<Expr1>(e1),
                                         std::forward<Expr2>(e2));
      }
    } else {
      // Not of the form A*B + C*D.
#ifndef BOOST_UBLAS_NO_RECURSIVE_OPTIMIZATION
      // Optimize each side with its own fresh transform and combine verdicts.
      apply_distributive_law left_xform{};
      apply_distributive_law right_xform{};
      auto xa = boost::yap::transform(
          boost::yap::as_expr<boost::numeric::ublas::detail::tensor_expression>(
              std::forward<Expr1>(e1)),
          left_xform);
      auto xb = boost::yap::transform(
          boost::yap::as_expr<boost::numeric::ublas::detail::tensor_expression>(
              std::forward<Expr2>(e2)),
          right_xform);
      record(left_xform.result_is_usable() && right_xform.result_is_usable());
      // make_expression captures lvalue expression arguments by reference;
      // xa and xb are locals, so they must be moved in, or the result dangles.
      return boost::yap::make_expression<
          boost::numeric::ublas::detail::tensor_expression,
          boost::yap::expr_kind::plus>(std::move(xa), std::move(xb));
#else
      record(true);
      return boost::yap::make_expression<
          boost::numeric::ublas::detail::tensor_expression,
          boost::yap::expr_kind::plus>(std::forward<Expr1>(e1),
                                       std::forward<Expr2>(e2));
#endif
    }
  }

  /// Accumulated verdict over every call made through this object. It is true
  /// only if each expression produced so far is safe to evaluate. It stays
  /// false if operator() was never invoked.
  bool usable = false;

  /// Set once operator() has run at least once. This lets `usable` be
  /// AND-accumulated across the several `plus` nodes one transform may visit.
  bool invoked = false;

private:
  // Folds one call's verdict into `usable`. The first call initializes it and
  // later calls can only clear it, so an unsafe rewrite is never masked by a
  // later successful one.
  constexpr void record(bool ok) noexcept {
    usable = invoked ? (usable && ok) : ok;
    invoked = true;
  }

  // A sub-transform that never fired left its subtree unchanged, which is
  // always safe to evaluate.
  constexpr bool result_is_usable() const noexcept {
    return !invoked || usable;
  }

  // Builds `common * (x + y)` as a tensor_expression. All arguments are
  // lvalues, so they are wrapped in terminals exactly as in the direct calls;
  // inner_op is a local, so it is moved in.
  template <class Common, class X, class Y>
  static constexpr auto factor_out(Common &common, X &x, Y &y) {
    auto inner_op = boost::yap::make_expression<
        boost::numeric::ublas::detail::tensor_expression,
        boost::yap::expr_kind::plus>(x, y);
    return boost::yap::make_expression<
        boost::numeric::ublas::detail::tensor_expression,
        boost::yap::expr_kind::multiplies>(common, std::move(inner_op));
  }
};
}
// For demonstration purposes only.
int main() {
  tensor<int> A{};
  auto B = A;
  auto C = A;
  auto expr = A * B + A * C;
  auto xform = apply_distributive_law{};
  auto new_expr = boost::yap::transform(expr, xform);
  if (xform.usable) {
    // The rewrite succeeded: new_expr (expected: A * (B + C)) can be
    // evaluated in place of expr.
  } else {
    // The rewrite failed and new_expr must be discarded. The operands were
    // not moved from, so the original expr is still valid to evaluate.
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment