Skip to content

Instantly share code, notes, and snippets.

@jamornsriwasansak
Created June 21, 2021 21:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamornsriwasansak/6fa8e783ec910bd0f755975663565213 to your computer and use it in GitHub Desktop.
Save jamornsriwasansak/6fa8e783ec910bd0f755975663565213 to your computer and use it in GitHub Desktop.
#pragma once
#include <cassert>
#include <limits>
#include <memory>
#include <stdexcept>
#include <utility>
// Tag enumeration for expression-node kinds; presumably one tag per node type
// (constant, variable, and the four binary arithmetic operations) — confirm
// against users of Op. Enumerators keep their implicit values 0..5.
enum class Op { cnst, var, add, sub, mul, div };
// rename shared_ptr to ptr
template<typename ValT>
using ptr = std::shared_ptr<ValT>;

// rename make_shared to alloc
// Constructs a ValT in a single shared allocation and returns shared ownership of it.
// The arguments are perfect-forwarded: rvalues (e.g. temporary ptr<> operands)
// are moved rather than copied, and move-only constructor arguments work.
template<typename ValT, class... Args>
inline ptr<ValT>
alloc(Args&&... args)
{
    return std::make_shared<ValT>(std::forward<Args>(args)...);
}
/// Abstract node of a differentiable expression tree.
///
/// eval()           - value of the expression.
/// diffEval(i)      - derivative of the expression with respect to the variable
///                    whose index is i.
/// getVarIndex()    - index of a variable node; the default (non-variable nodes)
///                    is std::numeric_limits<int>::min().
/// set(val), get()  - access the stored value of a variable node. On nodes that
///                    do not override them, set() is a no-op and get() throws
///                    std::logic_error.
template<typename ValT>
struct Expr
{
    // Polymorphic base: a virtual destructor makes deleting a node through an
    // Expr<ValT>* (or unique_ptr<Expr<ValT>>) well-defined.
    virtual ~Expr() = default;

    virtual ValT eval() const = 0;
    virtual ValT diffEval(const int var_index) const = 0;
    virtual int getVarIndex() const { return std::numeric_limits<int>::min(); }

    // Deliberate no-op: only variable nodes have a value to overwrite.
    virtual void set(const ValT& /*val*/) {}

    // Non-variable nodes have no storage to return a reference to. Previously
    // this fell off the end of a value-returning function (undefined behavior);
    // report the misuse explicitly instead.
    virtual ValT& get()
    {
        throw std::logic_error("Expr::get() called on an expression that is not a variable");
    }
};
/// Leaf node holding a fixed value; its derivative with respect to any
/// variable is zero.
template<typename ValT>
struct ConstantExpr : public Expr<ValT>
{
    ValT m_const_val;
    ConstantExpr(const ValT& val) : m_const_val(val) {}
    ValT eval() const override { return m_const_val; }
    // A ValT-typed zero instead of the float literal 0.0f, so non-float ValT
    // (int, user-defined scalars) get a correctly typed result without an
    // implicit float conversion.
    ValT diffEval(const int /*var_index*/) const override { return ValT(0); }
};
/// Leaf node for the variable with compile-time index VarIndex.
/// Its derivative with respect to variable i is 1 when i == VarIndex, else 0.
template<typename ValT, int VarIndex>
struct VariableExpr : public Expr<ValT>
{
    // Value-initialized: a default-constructed variable previously held an
    // indeterminate value for arithmetic ValT, making eval() undefined behavior.
    // It now starts at ValT{} (zero for arithmetic types) until set() is called.
    ValT m_val{};
    VariableExpr() {}
    VariableExpr(const ValT& val) : m_val(val) {}
    ValT eval() const override { return m_val; }
    // ValT-typed one/zero rather than float literals, for non-float ValT.
    ValT diffEval(const int var_index) const override { return (var_index == VarIndex) ? ValT(1) : ValT(0); }
    int getVarIndex() const override { return VarIndex; }
    void set(const ValT& val) override { m_val = val; }
    ValT& get() override { return m_val; }
};
/// Common base for two-operand nodes. It shares ownership of both
/// sub-expressions; derived nodes implement eval()/diffEval() from m_lhs and m_rhs.
template<typename ValT>
struct BinaryExpr : public Expr<ValT>
{
    ptr<Expr<ValT>> m_lhs;
    ptr<Expr<ValT>> m_rhs;
    // Sink parameters: taken by value and moved into the members. Temporaries
    // (e.g. a freshly built Constant operand) are moved in without an extra
    // reference-count increment/decrement. Lvalue callers still pay exactly one copy.
    BinaryExpr(ptr<Expr<ValT>> lhs, ptr<Expr<ValT>> rhs) : m_lhs(std::move(lhs)), m_rhs(std::move(rhs)) {}
};
/// Sum node: value is lhs + rhs; derivative is d(lhs) + d(rhs).
template<typename ValT>
struct AddExpr : public BinaryExpr<ValT>
{
    using BinaryExpr<ValT>::BinaryExpr;

    ValT eval() const override
    {
        ValT left = this->m_lhs->eval();
        ValT right = this->m_rhs->eval();
        return left + right;
    }

    ValT diffEval(const int var_index) const override
    {
        ValT d_left = this->m_lhs->diffEval(var_index);
        ValT d_right = this->m_rhs->diffEval(var_index);
        return d_left + d_right;
    }
};
/// Difference node: value is lhs - rhs; derivative is d(lhs) - d(rhs).
template<typename ValT>
struct SubExpr : public BinaryExpr<ValT>
{
    using BinaryExpr<ValT>::BinaryExpr;

    ValT eval() const override
    {
        ValT minuend = this->m_lhs->eval();
        ValT subtrahend = this->m_rhs->eval();
        return minuend - subtrahend;
    }

    ValT diffEval(const int var_index) const override
    {
        ValT d_minuend = this->m_lhs->diffEval(var_index);
        ValT d_subtrahend = this->m_rhs->diffEval(var_index);
        return d_minuend - d_subtrahend;
    }
};
/// Product node: value is lhs * rhs; derivative follows the product rule
/// d(l * r) = l * dr + r * dl.
template<typename ValT>
struct MulExpr : public BinaryExpr<ValT>
{
    using BinaryExpr<ValT>::BinaryExpr;

    ValT eval() const override
    {
        ValT left = this->m_lhs->eval();
        ValT right = this->m_rhs->eval();
        return left * right;
    }

    ValT diffEval(const int var_index) const override
    {
        const auto& first = this->m_lhs;
        const auto& second = this->m_rhs;
        ValT first_val = first->eval();
        ValT first_grad = first->diffEval(var_index);
        ValT second_val = second->eval();
        ValT second_grad = second->diffEval(var_index);
        return first_val * second_grad + second_val * first_grad;
    }
};
/// Quotient node: value is lhs / rhs; derivative follows the quotient rule
/// d(l / r) = (r * dl - l * dr) / (r * r).
/// A zero divisor is not checked; the result is whatever ValT's operator/ yields.
template<typename ValT>
struct DivExpr : public BinaryExpr<ValT>
{
    using BinaryExpr<ValT>::BinaryExpr;
    // `override` added to match the sibling nodes, so the compiler rejects any
    // signature drift from Expr<ValT>.
    ValT eval() const override { return this->m_lhs->eval() / this->m_rhs->eval(); }
    ValT diffEval(const int var_index) const override
    {
        ValT l = this->m_lhs->eval();
        ValT dl = this->m_lhs->diffEval(var_index);
        ValT r = this->m_rhs->eval();
        ValT dr = this->m_rhs->diffEval(var_index);
        return (r * dl - l * dr) / (r * r);
    }
};
/// Wraps `val` in a ConstantExpr leaf and hands it back as a generic expression.
template<typename ValT>
ptr<Expr<ValT>>
Constant(const ValT& val)
{
    ptr<Expr<ValT>> node = alloc<ConstantExpr<ValT>>(val);
    return node;
}
/// Creates a VariableExpr leaf with compile-time index VarIndex via its default
/// constructor and returns it as a generic expression. Assign its value with set().
template<typename ValT, int VarIndex>
ptr<Expr<ValT>>
Variable()
{
    using Node = VariableExpr<ValT, VarIndex>;
    ptr<Expr<ValT>> node = alloc<Node>();
    return node;
}
/// Addition: builds an AddExpr node. A plain ValT operand on either side is
/// first wrapped in a constant leaf, then handed to the expression/expression form.
template<typename ValT>
ptr<Expr<ValT>>
operator+(const ptr<Expr<ValT>>& lhs, const ptr<Expr<ValT>>& rhs)
{
    return alloc<AddExpr<ValT>>(lhs, rhs);
}
template<typename ValT>
ptr<Expr<ValT>>
operator+(const ptr<Expr<ValT>>& lhs, const ValT& rhs)
{
    return lhs + Constant<ValT>(rhs);
}
template<typename ValT>
ptr<Expr<ValT>>
operator+(const ValT& lhs, const ptr<Expr<ValT>>& rhs)
{
    return Constant<ValT>(lhs) + rhs;
}
/// Subtraction: builds a SubExpr node. A plain ValT operand on either side is
/// first wrapped in a constant leaf, then handed to the expression/expression form.
template<typename ValT>
ptr<Expr<ValT>>
operator-(const ptr<Expr<ValT>>& lhs, const ptr<Expr<ValT>>& rhs)
{
    return alloc<SubExpr<ValT>>(lhs, rhs);
}
template<typename ValT>
ptr<Expr<ValT>>
operator-(const ptr<Expr<ValT>>& lhs, const ValT& rhs)
{
    return lhs - Constant<ValT>(rhs);
}
template<typename ValT>
ptr<Expr<ValT>>
operator-(const ValT& lhs, const ptr<Expr<ValT>>& rhs)
{
    return Constant<ValT>(lhs) - rhs;
}
/// Multiplication: builds a MulExpr node. A plain ValT operand on either side is
/// first wrapped in a constant leaf, then handed to the expression/expression form.
template<typename ValT>
ptr<Expr<ValT>>
operator*(const ptr<Expr<ValT>>& lhs, const ptr<Expr<ValT>>& rhs)
{
    return alloc<MulExpr<ValT>>(lhs, rhs);
}
template<typename ValT>
ptr<Expr<ValT>>
operator*(const ptr<Expr<ValT>>& lhs, const ValT& rhs)
{
    return lhs * Constant<ValT>(rhs);
}
template<typename ValT>
ptr<Expr<ValT>>
operator*(const ValT& lhs, const ptr<Expr<ValT>>& rhs)
{
    return Constant<ValT>(lhs) * rhs;
}
#define DECL_DIFF_OP(Op, OpType) \
template<typename ValT>\
ptr<Expr<ValT>> operator##Op##(const ptr<Expr<ValT>> & lhs, const ptr<Expr<ValT>> & rhs)\
{\
return alloc<OpType<ValT>>(lhs, rhs);\
}\
template<typename ValT>\
ptr<Expr<ValT>> operator##Op##(const ptr<Expr<ValT>> & lhs, const ValT & rhs)\
{\
return alloc<OpType<ValT>>(lhs, Constant<ValT>(rhs));\
}\
template<typename ValT>\
ptr<Expr<ValT>> operator##Op##(const ValT & lhs, const ptr<Expr<ValT>> & rhs)\
{\
return alloc<OpType<ValT>>(Constant<ValT>(lhs), rhs);\
}
DECL_DIFF_OP(+, AddExpr);
DECL_DIFF_OP(*, MulExpr);
DECL_DIFF_OP(/ , DivExpr);
DECL_DIFF_OP(-, SubExpr);
#undef DECL_DIFF_OP
/// User-facing handle for a differentiable expression: shared ownership of an
/// Expr<ValT> node, the same type the factories and operators return.
template<class ValT>
using Diff = ptr<Expr<ValT>>;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment