Skip to content

Instantly share code, notes, and snippets.

@amitsingh19975
Last active June 5, 2020 14:11
Show Gist options
  • Save amitsingh19975/6f41289179944a7f74f776158c22bae8 to your computer and use it in GitHub Desktop.
Save amitsingh19975/6f41289179944a7f74f776158c22bae8 to your computer and use it in GitHub Desktop.
dynamic_optimize
#if !defined(AMT_EVAL_HPP)
#define AMT_EVAL_HPP
#include "exp.hpp"
#include <tuple>
#include <boost/numeric/ublas/tensor.hpp>
#include <boost/fusion/include/for_each.hpp>
#include <boost/fusion/tuple/tuple.hpp>
namespace amt{
/// Applies the binary operator `op` to the tensors `a` and `b` and assigns
/// the outcome to `res`.
///
/// @param res destination tensor; overwritten by assignment for ADD, SUB,
///            MUL and DIV
/// @param a   left-hand operand of the tensor expression
/// @param b   right-hand operand of the tensor expression
/// @param op  operator selecting which tensor expression is evaluated
///
/// Any other operator value (e.g. Operator::None) is ignored and `res` is
/// left untouched.
template<typename R, typename T1, typename T2>
inline auto apply_op(boost::numeric::ublas::basic_tensor<R>& res
    , boost::numeric::ublas::basic_tensor<T1> const& a
    , boost::numeric::ublas::basic_tensor<T2> const& b
    , Operator op
)
{
    if( op == Operator::ADD ){
        res = a + b;
    }else if( op == Operator::SUB ){
        res = a - b;
    }else if( op == Operator::MUL ){
        res = a * b;
    }else if( op == Operator::DIV ){
        res = a / b;
    }
    // Every remaining operator value falls through without modifying res.
}
/// Evaluates the expression tree `e` over the tensors listed in `var` and
/// stores the value in `res`.
///
/// The tree is flattened to postfix order and evaluated with an operand
/// stack. For every operator the two topmost symbols are popped; the symbol
/// popped first is the RIGHT operand and the one below it is the LEFT
/// operand, so non-commutative operators (SUB, DIV) are applied as
/// `left op right`.
///
/// Operands are resolved by address: a symbol matches the element of `var`
/// whose address equals the pointer stored in the symbol.
///
/// @param res output tensor; also the only storage used for intermediate
///            results
/// @param var tuple of references to every tensor the tree may refer to
/// @param e   expression tree to evaluate
/// @return reference to `res`
/// @throws std::runtime_error if `e` is null, if an operator does not have
///         two non-empty operands on the stack, or if an operand cannot be
///         found in `var`
///
/// NOTE(review): every intermediate result is written to `res` and pushed as
/// a symbol pointing at `res`. Consequences visible in this function:
///   * `res` must itself be an element of `var`, otherwise an intermediate
///     operand can never be resolved;
///   * a tree in which BOTH children of an operator are operator sub-trees
///     will see both intermediates alias `res` (the second overwrites the
///     first). Such trees are not evaluated correctly.
/// NOTE(review): matching `res` by address assumes the basic_tensor<T>
/// sub-object and the tuple element share the same address - confirm.
template<typename T, typename... Args>
inline auto& eval(boost::numeric::ublas::basic_tensor<T>& res, boost::fusion::tuple<Args const&...> const& var, std::shared_ptr<expr_tree> const& e){
    namespace fu = boost::fusion;

    if( e == nullptr ){
        throw std::runtime_error("amt::eval(tuple,expr_tree): null expr_tree");
    }

    auto const postfix = e->flatten();
    std::vector<sym> st;

    for(auto const& el : postfix){
        if( !el.is_op() ){
            st.push_back(el);
            continue;
        }

        if( st.size() < 2 ){
            throw std::runtime_error("amt::eval(tuple,expr_tree): invalid expr_tree");
        }

        // Postfix order: the top of the stack is the right operand.
        sym const rhs = st.back();
        st.pop_back();
        sym const lhs = st.back();
        st.pop_back();

        if( lhs.empty() || rhs.empty() ){
            throw std::runtime_error("amt::eval(tuple,expr_tree): invalid expr_tree");
        }

        auto const op = el.op();
        bool op_completed{false};

        // Outer pass resolves the left operand, inner pass the right one.
        // The tuple element that matches by address *is* the operand, so it
        // is used directly with its static type; no pointer cast is needed.
        fu::for_each(var,[&lhs,&rhs,&res,&var,&op_completed,op](auto const& left){
            if( op_completed || !lhs.var(left) ){
                return;
            }
            fu::for_each(var,[&rhs,&left,&res,&op_completed,op](auto const& right){
                // Guard against applying the operation twice when the same
                // tensor is listed more than once in the tuple.
                if( op_completed || !rhs.var(right) ){
                    return;
                }
                apply_op(res, left, right, op);
                op_completed = true;
            });
        });

        if( !op_completed ){
            // Previously this case silently kept a stale `res`.
            throw std::runtime_error("amt::eval(tuple,expr_tree): operand not found in tuple");
        }

        st.push_back(res);
    }
    return res;
}
} // namespace amt
#endif // AMT_EVAL_HPP
#if !defined(AMT_EXP_HPP)
#define AMT_EXP_HPP
#include <memory>
#include "symbols.hpp"
#include <type_traits>
namespace amt{
/// Binary expression tree whose nodes carry a `sym`.
///
/// An expr_tree is a thin handle around a shared node holding the symbol and
/// the (possibly null) left and right sub-trees. A default-constructed tree
/// has no node; `data()`, `left()` and `right()` throw std::runtime_error on
/// such a tree, while `terminal()`, `flatten()` and streaming treat it as
/// empty.
struct expr_tree{
    using value_type = sym;
    using left_expr = std::shared_ptr< expr_tree >;
    using right_expr = std::shared_ptr< expr_tree >;
private:
    /// Payload of a tree: the symbol plus both children.
    struct node{
        node( sym s, left_expr l = nullptr, right_expr r = nullptr )
            : m_data(std::move(s))
            , m_left(std::move(l))
            , m_right(std::move(r))
        {}
        inline constexpr auto const& left() const noexcept{ return m_left; }
        inline constexpr auto& left() noexcept{ return m_left; }
        inline constexpr auto const& right() const noexcept{ return m_right; }
        inline constexpr auto& right() noexcept{ return m_right; }
        inline constexpr auto const& data() const noexcept{ return m_data; }
        inline constexpr auto& data() noexcept{ return m_data; }
    private:
        value_type m_data;
        left_expr m_left{nullptr};
        right_expr m_right{nullptr};
    };
public:
    using node_type = std::shared_ptr< node >;

    /// Builds a tree node holding `s` with the given children.
    expr_tree( sym s, left_expr l = nullptr, right_expr r = nullptr )
        : m_node( std::make_shared<node>( std::move(s),std::move(l),std::move(r) ) )
    {}

    /// Builds an empty tree (no node).
    expr_tree() = default;

    /// @return the symbol stored in this node
    /// @throws std::runtime_error if the tree is empty
    inline value_type const& data() const{
        return checked("amt::expr_tree::data() : derefencing nullptr").data();
    }
    inline value_type& data(){
        return checked("amt::expr_tree::data() : derefencing nullptr").data();
    }

    /// @return the left child (may be a null pointer)
    /// @throws std::runtime_error if the tree is empty
    inline left_expr const& left() const{
        return checked("amt::expr_tree::left() : derefencing nullptr").left();
    }
    inline left_expr& left(){
        return checked("amt::expr_tree::left() : derefencing nullptr").left();
    }

    /// @return the right child (may be a null pointer)
    /// @throws std::runtime_error if the tree is empty
    inline right_expr const& right() const{
        return checked("amt::expr_tree::right() : derefencing nullptr").right();
    }
    inline right_expr& right(){
        return checked("amt::expr_tree::right() : derefencing nullptr").right();
    }

    /// @return true when the tree has a node and that node has no children
    inline auto terminal() const noexcept{
        return m_node != nullptr && (m_node->left() == nullptr && m_node->right() == nullptr);
    }

    /// @return true when the tree has a node. Const-qualified so that it can
    /// also be queried on const trees.
    operator bool() const noexcept{
        return m_node != nullptr;
    }

    // std::nullptr_t is spelled out: the unqualified name is only available
    // when a C header happens to inject it into the global namespace.
    inline auto operator==(std::nullptr_t) const noexcept{
        return m_node == nullptr;
    }
    inline auto operator!=(std::nullptr_t) const noexcept{
        return m_node != nullptr;
    }

    /// Streams the symbols of the tree in in-order (left, node, right).
    friend auto& operator<<(std::ostream& os, expr_tree const& e){
        expr_tree::print_inorder(os,&e);
        return os;
    }
    /// Same as above for a shared pointer to a tree; a null pointer prints
    /// nothing.
    friend auto& operator<<(std::ostream& os, left_expr const& e){
        expr_tree::print_inorder(os,e.get());
        return os;
    }

    /// @return the symbols of the tree in post-order (left, right, node),
    /// i.e. the postfix form of the expression. An empty tree yields an
    /// empty vector.
    std::vector<sym> flatten() const{
        std::vector<sym> temp;
        flat(temp,this);
        return temp;
    }
private:
    /// Returns the node or throws std::runtime_error(msg) for an empty tree.
    node& checked( char const* msg ) const{
        if( m_node == nullptr ){
            throw std::runtime_error(msg);
        }
        return *m_node;
    }

    /// In-order print; null pointers and empty trees print nothing.
    static void print_inorder(std::ostream& os, expr_tree const* e){
        if( e == nullptr || e->m_node == nullptr ) return;
        print_inorder(os,e->left().get());
        os<<e->data();
        print_inorder(os,e->right().get());
    }

    /// Post-order collection; null pointers and empty trees add nothing.
    static void flat(std::vector<sym>& vec, expr_tree const* e){
        if( e == nullptr || e->m_node == nullptr ) return;
        flat(vec,e->left().get());
        flat(vec,e->right().get());
        vec.push_back(e->data());
    }
private:
    node_type m_node{nullptr};
};
/// Type trait whose `value` is true only when T is exactly `expr_tree`.
template<typename T>
struct is_expr_list
    : std::is_same<T, expr_tree>
{};

/// Convenience variable template for `is_expr_list<T>::value`.
template<typename T>
inline constexpr bool is_expr_list_v = is_expr_list<T>::value;
} // namespace amt
#endif // AMT_EXP_HPP
#include <iostream>
#include <boost/core/demangle.hpp>
#include "eval.hpp"
#include "rules.hpp"
namespace ub = boost::numeric::ublas;
/// Returns the demangled name of the dynamic type of `t`, as reported by
/// `typeid` and boost::core::demangle.
template<typename T>
std::string type( T const& t ){
    char const* const mangled = typeid(t).name();
    return boost::core::demangle(mangled);
}
int main(){
using type = ub::dynamic_tensor<float>;
ub::dynamic_tensor<float> a{ ub::dynamic_extents<>{3,3},2.f };
ub::dynamic_tensor<float> b{ ub::dynamic_extents<>{3,3},3.f };
ub::dynamic_tensor<float> c{ ub::dynamic_extents<>{3,3},1.f };
ub::dynamic_tensor<float> res{ ub::dynamic_extents<>{3,3},1.f };
// (a * b)
auto left_exp = std::make_shared<amt::expr_tree>(
amt::sym(amt::Operator::MUL),
std::make_shared<amt::expr_tree>(amt::sym(a)),
std::make_shared<amt::expr_tree>(amt::sym(b))
);
// (a * c)
auto right_exp = std::make_shared<amt::expr_tree>(
amt::sym(amt::Operator::MUL),
std::make_shared<amt::expr_tree>(amt::sym(a)),
std::make_shared<amt::expr_tree>(amt::sym(c))
);
// ( a * b ) + ( a + c )
auto temp = std::make_shared<amt::expr_tree>( amt::sym(amt::Operator::ADD), std::move(left_exp), std::move(right_exp));
// a * ( b + c )
amt::detail::rule<0>{}(temp);
amt::eval(res,boost::fusion::tuple<type const&,type const&,type const&,type const&> (a, b, c, res),temp);
ub::dynamic_tensor<float> test = (a * b) + (a * c);
std::cout<<res<<'\n';
std::cout<<test<<'\n';
return 0;
}
#if !defined(AMT_RULE_HPP)
#define AMT_RULE_HPP
#include "symbols.hpp"
#include "exp.hpp"
namespace amt::detail{
/// Abstract interface of a tree-rewriting rule: a rule is a callable that
/// receives a shared pointer to an expression tree by reference and returns
/// a reference to a shared pointer to an expression tree.
struct rule_base{
    constexpr rule_base() noexcept = default;

    virtual ~rule_base() = default;

    /// Applies the rule to `li`.
    virtual std::shared_ptr< expr_tree >& operator()(std::shared_ptr< expr_tree >& li) const = 0;
};
} // namespace amt::detail
namespace amt::detail{
template<std::ptrdiff_t>
struct rule;
template<>
struct rule<0> : rule_base{
// From : ( a * b ) + ( a * c )
// To : a * ( b + c )
virtual std::shared_ptr< expr_tree >& operator()(std::shared_ptr< expr_tree >& li) const override {
if ( li->terminal() ){
return li;
}else{
auto& left_expr = li->left();
auto& right_expr = li->right();
auto const& val = li->data();
if ( check(left_expr, right_expr) && val.is_add() ){
auto& a = left_expr->left()->data();
auto& b = left_expr->right()->data();
auto& c = right_expr->right()->data();
auto temp = std::make_shared<expr_tree>(
sym(Operator::MUL),
std::make_shared<expr_tree>(std::move(a)),
std::make_shared<expr_tree>(
Operator::ADD,
std::make_shared<expr_tree>(std::move(b)),
std::make_shared<expr_tree>(std::move(c))
)
);
li = std::move(temp);
}
}
return li;
}
private:
inline constexpr bool check( std::shared_ptr< expr_tree > const& l, std::shared_ptr< expr_tree > const& r ) const noexcept{
if( l->terminal() || r->terminal() ){
return false;
}else{
return l->left()->data() == r->left()->data() &&
l->data() == r->data() && l->data().is_mul();
}
}
};
} // namespace amt::detail
#endif // AMT_RULE_HPP
#if !defined(AMT_SYMBOLS_HPP)
#define AMT_SYMBOLS_HPP
#include <cstddef>
#include <type_traits>
#include <boost/numeric/ublas/tensor.hpp>
#include <unordered_map>
namespace amt{
/// Binary operators a symbol can denote. `None` is the value used when no
/// operator is set.
enum class Operator{
    ADD  = 0,
    SUB  = 1,
    MUL  = 2,
    DIV  = 3,
    None = 4
};
/// Category of a symbol: operator, integer number, variable, or `None` when
/// no category is set.
enum class SymType{
    OP   = 0,
    NUM  = 1,
    VAR  = 2,
    None = 3
};
struct sym{
constexpr sym() noexcept = default;
template<typename T,
typename = std::enable_if_t< !std::is_integral_v<T>, int >
>
constexpr sym(T const& t) noexcept
: m_var( static_cast<void const*>(&t) )
, m_type(SymType::VAR)
{}
constexpr sym(std::ptrdiff_t const& val) noexcept
: m_int( val )
, m_type(SymType::NUM)
{}
constexpr sym(Operator op) noexcept
: m_op( op )
, m_type(SymType::OP)
{}
inline constexpr auto empty() const noexcept{
return m_type == SymType::None;
}
inline constexpr auto op() const noexcept{
return m_op;
}
inline constexpr auto is_num() const noexcept{
return m_type == SymType::NUM;
}
inline constexpr auto is_var() const noexcept{
return m_type == SymType::VAR;
}
inline constexpr auto is_op() const noexcept{
return m_type == SymType::OP;
}
inline constexpr auto is_add() const noexcept{
return m_op == Operator::ADD;
}
inline constexpr auto is_mul() const noexcept{
return m_op == Operator::MUL;
}
inline constexpr auto is_sub() const noexcept{
return m_op == Operator::SUB;
}
inline constexpr auto is_div() const noexcept{
return m_op == Operator::DIV;
}
inline constexpr auto num() const noexcept{
return m_int;
}
inline constexpr auto var() const noexcept{
return m_var;
}
template<typename T>
inline constexpr auto var(T const& v) const noexcept{
return m_var == std::addressof(v);
}
inline constexpr auto operator==(sym const& other) const noexcept{
return ( m_var == other.m_var ) && ( m_int == other.m_int ) && ( m_op == other.m_op )
&& ( m_type == other.m_type );
}
inline constexpr auto operator!=(sym const& other) const noexcept{
return !(*this == other);
}
friend auto& operator<<(std::ostream& os, sym const& s){
if ( s.is_var() ) os<<"var ( "<<s.m_var<<" ) ";
if ( s.is_num() ) os<<s.num()<<" ";
if ( s.is_op() ) {
if( s.is_add() ) os<<"+ ";
if( s.is_mul() ) os<<"* ";
if( s.is_sub() ) os<<"- ";
if( s.is_div() ) os<<"/ ";
}
return os;
}
private:
void const* m_var{nullptr};
std::ptrdiff_t m_int{0};
Operator m_op{Operator::None};
SymType m_type{SymType::None};
};
} // namespace amt
#endif // AMT_SYMBOLS_HPP
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment