Last active
June 5, 2020 14:11
-
-
Save amitsingh19975/6f41289179944a7f74f776158c22bae8 to your computer and use it in GitHub Desktop.
dynamic_optimize
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#if !defined(AMT_EVAL_HPP) | |
#define AMT_EVAL_HPP | |
#include "exp.hpp" | |
#include <tuple> | |
#include <boost/numeric/ublas/tensor.hpp> | |
#include <boost/fusion/include/for_each.hpp> | |
#include <boost/fusion/tuple/tuple.hpp> | |
namespace amt{ | |
template<typename R, typename T1, typename T2> | |
inline auto apply_op(boost::numeric::ublas::basic_tensor<R>& res | |
, boost::numeric::ublas::basic_tensor<T1> const& a | |
, boost::numeric::ublas::basic_tensor<T2> const& b | |
, Operator op | |
) | |
{ | |
switch(op){ | |
case Operator::ADD:{ | |
res = a + b; | |
return; | |
} | |
case Operator::MUL:{ | |
res = a * b; | |
return; | |
} | |
case Operator::DIV:{ | |
res = a / b; | |
return; | |
} | |
case Operator::SUB:{ | |
res = a - b; | |
return; | |
} | |
default:{ | |
return; | |
} | |
} | |
} | |
template<typename T, typename... Args> | |
inline auto& eval(boost::numeric::ublas::basic_tensor<T>& res, boost::fusion::tuple<Args const&...> const& var, std::shared_ptr<expr_tree> const& e){ | |
namespace fu = boost::fusion; | |
auto postfix = e->flatten(); | |
std::vector<sym> st; | |
for(auto const& el : postfix){ | |
if( !el.is_op() ){ | |
st.push_back(el); | |
}else{ | |
sym a; | |
sym b; | |
if( !st.empty() ){ | |
a = std::move( st.back() ); | |
st.pop_back(); | |
} | |
if( !st.empty() ){ | |
b = std::move( st.back() ); | |
st.pop_back(); | |
} | |
if( a.empty() || b.empty() ){ | |
throw std::runtime_error("amt::eval(tuple,expr_tree): invalid expr_tree"); | |
} | |
bool op_completed{false}; | |
fu::for_each(var,[&a,&b,&res, op = el.op(),&var, &op_completed](auto const& e){ | |
using tensor_type1 = std::decay_t<decltype(e)>; | |
if( !op_completed ){ | |
if( a.var(e) ){ | |
auto& t1 = *reinterpret_cast<tensor_type1 const*>(a.var()); | |
fu::for_each(var,[&b,&t1,&res,&op, &op_completed](auto const& e){ | |
using tensor_type2 = std::decay_t<decltype(e)>; | |
if( b.var(e) ){ | |
auto& t2 = *reinterpret_cast<tensor_type2 const*>(b.var()); | |
apply_op(res, t1, t2, op); | |
op_completed = true; | |
} | |
}); | |
}else if( b.var(e) ){ | |
auto& t1 = *reinterpret_cast<tensor_type1 const*>(b.var()); | |
fu::for_each(var,[&a,&t1,&res,&op, &op_completed](auto const& e){ | |
using tensor_type2 = std::decay_t<decltype(e)>; | |
if( a.var(e) ){ | |
auto& t2 = *reinterpret_cast<tensor_type2 const*>(a.var()); | |
apply_op(res, t2, t1, op); | |
op_completed = true; | |
} | |
}); | |
} | |
} | |
}); | |
st.push_back(res); | |
} | |
} | |
return res; | |
} | |
} // namespace amt | |
#endif // AMT_EVAL_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#if !defined(AMT_EXP_HPP) | |
#define AMT_EXP_HPP | |
#include <memory> | |
#include "symbols.hpp" | |
#include <type_traits> | |
namespace amt{ | |
struct expr_tree{ | |
using value_type = sym; | |
using left_expr = std::shared_ptr< expr_tree >; | |
using right_expr = std::shared_ptr< expr_tree >; | |
private: | |
struct node{ | |
node( sym s, left_expr l = nullptr, right_expr r = nullptr ) | |
: m_data(std::move(s)) | |
, m_left(std::move(l)) | |
, m_right(std::move(r)) | |
{} | |
inline constexpr auto const& left() const noexcept{ return m_left; } | |
inline constexpr auto& left() noexcept{ return m_left; } | |
inline constexpr auto const& right() const noexcept{ return m_right; } | |
inline constexpr auto& right() noexcept{ return m_right; } | |
inline constexpr auto const& data() const noexcept{ return m_data; } | |
inline constexpr auto& data() noexcept{ return m_data; } | |
private: | |
value_type m_data; | |
left_expr m_left{nullptr}; | |
right_expr m_right{nullptr}; | |
}; | |
public: | |
using node_type = std::shared_ptr< node >; | |
expr_tree( sym s, left_expr l = nullptr, right_expr r = nullptr ) | |
: m_node( std::make_shared<node>( std::move(s),std::move(l),std::move(r) ) ) | |
{} | |
expr_tree() = default; | |
inline value_type const& data() const{ | |
if( m_node == nullptr ){ | |
throw std::runtime_error("amt::expr_tree::data() : derefencing nullptr"); | |
} | |
return m_node->data(); | |
} | |
inline value_type& data(){ | |
if( m_node == nullptr ){ | |
throw std::runtime_error("amt::expr_tree::data() : derefencing nullptr"); | |
} | |
return m_node->data(); | |
} | |
inline left_expr const& left() const{ | |
if( m_node == nullptr ){ | |
throw std::runtime_error("amt::expr_tree::left() : derefencing nullptr"); | |
} | |
return m_node->left(); | |
} | |
inline left_expr& left(){ | |
if( m_node == nullptr ){ | |
throw std::runtime_error("amt::expr_tree::left() : derefencing nullptr"); | |
} | |
return m_node->left(); | |
} | |
inline right_expr const& right() const{ | |
if( m_node == nullptr ){ | |
throw std::runtime_error("amt::expr_tree::right() : derefencing nullptr"); | |
} | |
return m_node->right(); | |
} | |
inline right_expr& right(){ | |
if( m_node == nullptr ){ | |
throw std::runtime_error("amt::expr_tree::right() : derefencing nullptr"); | |
} | |
return m_node->right(); | |
} | |
inline auto terminal() const noexcept{ | |
return m_node != nullptr && (m_node->left() == nullptr && m_node->right() == nullptr); | |
} | |
operator bool(){ | |
return m_node != nullptr; | |
} | |
inline auto operator==(nullptr_t) const noexcept{ | |
return m_node == nullptr; | |
} | |
inline auto operator!=(nullptr_t) const noexcept{ | |
return m_node != nullptr; | |
} | |
friend auto& operator<<(std::ostream& os, expr_tree const& e){ | |
expr_tree::print_inoder(os,&e); | |
return os; | |
} | |
friend auto& operator<<(std::ostream& os, left_expr const& e){ | |
expr_tree::print_inoder(os,e.get()); | |
return os; | |
} | |
std::vector<sym> flatten() const{ | |
std::vector<sym> temp; | |
flat(temp,this); | |
return temp; | |
} | |
private: | |
static void print_inoder(std::ostream& os, expr_tree const* e){ | |
if( e == nullptr ) return; | |
if( e->left() != nullptr ){ | |
print_inoder(os,e->left().get()); | |
} | |
auto val = e->data(); | |
os<<val; | |
if( e->right() != nullptr ){ | |
print_inoder(os,e->right().get()); | |
} | |
} | |
void flat(std::vector<sym>& vec, expr_tree const* e) const{ | |
if( e == nullptr ) return; | |
if( e->left() != nullptr ){ | |
flat(vec,e->left().get()); | |
} | |
if( e->right() != nullptr ){ | |
flat(vec,e->right().get()); | |
} | |
vec.push_back(e->data()); | |
} | |
private: | |
node_type m_node{nullptr}; | |
}; | |
template<typename T> | |
struct is_expr_list : std::is_same<T, expr_tree>{}; | |
template<typename T> | |
inline constexpr auto const is_expr_list_v = is_expr_list<T>::value; | |
} // namespace amt | |
#endif // AMT_EXP_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <boost/core/demangle.hpp> | |
#include "eval.hpp" | |
#include "rules.hpp" | |
namespace ub = boost::numeric::ublas; | |
template<typename T> | |
std::string type( T const& t ){ | |
return boost::core::demangle(typeid(t).name()); | |
} | |
int main(){ | |
using type = ub::dynamic_tensor<float>; | |
ub::dynamic_tensor<float> a{ ub::dynamic_extents<>{3,3},2.f }; | |
ub::dynamic_tensor<float> b{ ub::dynamic_extents<>{3,3},3.f }; | |
ub::dynamic_tensor<float> c{ ub::dynamic_extents<>{3,3},1.f }; | |
ub::dynamic_tensor<float> res{ ub::dynamic_extents<>{3,3},1.f }; | |
// (a * b) | |
auto left_exp = std::make_shared<amt::expr_tree>( | |
amt::sym(amt::Operator::MUL), | |
std::make_shared<amt::expr_tree>(amt::sym(a)), | |
std::make_shared<amt::expr_tree>(amt::sym(b)) | |
); | |
// (a * c) | |
auto right_exp = std::make_shared<amt::expr_tree>( | |
amt::sym(amt::Operator::MUL), | |
std::make_shared<amt::expr_tree>(amt::sym(a)), | |
std::make_shared<amt::expr_tree>(amt::sym(c)) | |
); | |
// ( a * b ) + ( a + c ) | |
auto temp = std::make_shared<amt::expr_tree>( amt::sym(amt::Operator::ADD), std::move(left_exp), std::move(right_exp)); | |
// a * ( b + c ) | |
amt::detail::rule<0>{}(temp); | |
amt::eval(res,boost::fusion::tuple<type const&,type const&,type const&,type const&> (a, b, c, res),temp); | |
ub::dynamic_tensor<float> test = (a * b) + (a * c); | |
std::cout<<res<<'\n'; | |
std::cout<<test<<'\n'; | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#if !defined(AMT_RULE_HPP) | |
#define AMT_RULE_HPP | |
#include "symbols.hpp" | |
#include "exp.hpp" | |
namespace amt::detail{ | |
struct rule_base{ | |
constexpr rule_base() noexcept = default; | |
virtual std::shared_ptr< expr_tree >& operator()(std::shared_ptr< expr_tree >& li) const = 0; | |
virtual ~rule_base(){}; | |
}; | |
} // namespace amt::detail
namespace amt::detail{ | |
template<std::ptrdiff_t> | |
struct rule; | |
template<> | |
struct rule<0> : rule_base{ | |
// From : ( a * b ) + ( a * c ) | |
// To : a * ( b + c ) | |
virtual std::shared_ptr< expr_tree >& operator()(std::shared_ptr< expr_tree >& li) const override { | |
if ( li->terminal() ){ | |
return li; | |
}else{ | |
auto& left_expr = li->left(); | |
auto& right_expr = li->right(); | |
auto const& val = li->data(); | |
if ( check(left_expr, right_expr) && val.is_add() ){ | |
auto& a = left_expr->left()->data(); | |
auto& b = left_expr->right()->data(); | |
auto& c = right_expr->right()->data(); | |
auto temp = std::make_shared<expr_tree>( | |
sym(Operator::MUL), | |
std::make_shared<expr_tree>(std::move(a)), | |
std::make_shared<expr_tree>( | |
Operator::ADD, | |
std::make_shared<expr_tree>(std::move(b)), | |
std::make_shared<expr_tree>(std::move(c)) | |
) | |
); | |
li = std::move(temp); | |
} | |
} | |
return li; | |
} | |
private: | |
inline constexpr bool check( std::shared_ptr< expr_tree > const& l, std::shared_ptr< expr_tree > const& r ) const noexcept{ | |
if( l->terminal() || r->terminal() ){ | |
return false; | |
}else{ | |
return l->left()->data() == r->left()->data() && | |
l->data() == r->data() && l->data().is_mul(); | |
} | |
} | |
}; | |
} // namespace amt::detail
#endif // AMT_RULE_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#if !defined(AMT_SYMBOLS_HPP) | |
#define AMT_SYMBOLS_HPP | |
#include <cstddef> | |
#include <type_traits> | |
#include <boost/numeric/ublas/tensor.hpp> | |
#include <unordered_map> | |
namespace amt{ | |
/// Binary operators a symbol can denote; `None` means "not an operator".
enum class Operator{
    ADD,
    SUB,
    MUL,
    DIV,
    None
};

/// Kind of value a `sym` holds; `None` marks a default-constructed symbol.
enum class SymType{
    OP,
    NUM,
    VAR,
    None
};

/**
 * Type-erased symbol of an expression: an operator, an integer constant, or
 * a variable identified by the address of the referenced object.
 *
 * A variable symbol stores only the address; it does not own or copy the
 * object, so the object has to outlive every symbol that refers to it.
 */
struct sym{
    /// Creates an empty symbol (`empty()` is true).
    constexpr sym() noexcept = default;

    /// Creates a variable symbol referring to `t` (any non-integral type).
    /// std::addressof is used so the stored address agrees with `var(T const&)`
    /// even for types that overload unary `&`.
    template<typename T,
        typename = std::enable_if_t< !std::is_integral_v<T>, int >
    >
    constexpr sym(T const& t) noexcept
        : m_var( static_cast<void const*>( std::addressof(t) ) )
        , m_type(SymType::VAR)
    {}

    /// Creates a numeric symbol holding `val`.
    constexpr sym(std::ptrdiff_t const& val) noexcept
        : m_int( val )
        , m_type(SymType::NUM)
    {}

    /// Creates an operator symbol.
    constexpr sym(Operator op) noexcept
        : m_op( op )
        , m_type(SymType::OP)
    {}

    constexpr bool empty() const noexcept{ return m_type == SymType::None; }

    constexpr bool is_num() const noexcept{ return m_type == SymType::NUM; }
    constexpr bool is_var() const noexcept{ return m_type == SymType::VAR; }
    constexpr bool is_op()  const noexcept{ return m_type == SymType::OP; }

    // The operator predicates look at the stored operator only; it stays
    // Operator::None for symbols that were not built from an operator.
    constexpr bool is_add() const noexcept{ return m_op == Operator::ADD; }
    constexpr bool is_mul() const noexcept{ return m_op == Operator::MUL; }
    constexpr bool is_sub() const noexcept{ return m_op == Operator::SUB; }
    constexpr bool is_div() const noexcept{ return m_op == Operator::DIV; }

    /// Stored operator (Operator::None unless this is an operator symbol).
    constexpr Operator op() const noexcept{ return m_op; }

    /// Stored integer (0 unless this is a numeric symbol).
    constexpr std::ptrdiff_t num() const noexcept{ return m_int; }

    /// Stored address (nullptr unless this is a variable symbol).
    constexpr void const* var() const noexcept{ return m_var; }

    /// True when this symbol stores the address of `v`.
    template<typename T>
    constexpr bool var(T const& v) const noexcept{
        return m_var == std::addressof(v);
    }

    /// Two symbols are equal when all four stored fields are equal.
    constexpr bool operator==(sym const& other) const noexcept{
        return ( m_type == other.m_type )
            && ( m_op == other.m_op )
            && ( m_int == other.m_int )
            && ( m_var == other.m_var );
    }

    constexpr bool operator!=(sym const& other) const noexcept{
        return !(*this == other);
    }

    /// Prints "var ( <address> ) ", "<number> " or the operator character
    /// followed by a space; an empty symbol prints nothing.
    friend std::ostream& operator<<(std::ostream& os, sym const& s){
        switch( s.m_type ){
            case SymType::VAR:
                os<<"var ( "<<s.m_var<<" ) ";
                break;
            case SymType::NUM:
                os<<s.m_int<<" ";
                break;
            case SymType::OP:
                switch( s.m_op ){
                    case Operator::ADD: os<<"+ "; break;
                    case Operator::MUL: os<<"* "; break;
                    case Operator::SUB: os<<"- "; break;
                    case Operator::DIV: os<<"/ "; break;
                    default: break;
                }
                break;
            default:
                break;
        }
        return os;
    }

private:
    void const*    m_var{nullptr};
    std::ptrdiff_t m_int{0};
    Operator       m_op{Operator::None};
    SymType        m_type{SymType::None};
};
} // namespace amt
#endif // AMT_SYMBOLS_HPP |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment