Created
June 21, 2021 21:13
-
-
Save jamornsriwasansak/6fa8e783ec910bd0f755975663565213 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <cassert>
#include <limits>
#include <memory>
#include <stdexcept>
#include <utility>
// Kinds of expression nodes: constant, variable, and the four binary
// arithmetic operations. Enumerators keep their original order, so their
// underlying values are 0 (cnst) through 5 (div).
enum class Op { cnst, var, add, sub, mul, div };
// `ptr<T>` is the header-wide shorthand for std::shared_ptr<T>; expression
// nodes hold their children through it, so one sub-expression may be
// referenced from several parents.
template<class T>
using ptr = std::shared_ptr<T>;
// Shorthand for std::make_shared: constructs a ValT from `args` in a single
// allocation and returns it as a shared pointer (the same type as ptr<ValT>).
// Arguments are perfectly forwarded, so rvalue arguments are moved rather than
// copied and move-only arguments are accepted.
template<typename ValT, class... Args>
inline std::shared_ptr<ValT>
alloc(Args&&... args)
{
    return std::make_shared<ValT>(std::forward<Args>(args)...);
}
// Abstract node of an expression tree that can be evaluated and
// differentiated with respect to an indexed variable.
template<typename ValT>
struct Expr
{
    // Polymorphic base. make_shared-created owners already destroy the dynamic
    // type, but deleting a node through any other Expr<ValT>* owner would be
    // undefined behavior without a virtual destructor.
    virtual ~Expr() = default;

    // Value of the (sub)expression rooted at this node.
    virtual ValT eval() const = 0;

    // Partial derivative of the (sub)expression rooted at this node with
    // respect to the variable whose index is `var_index`.
    virtual ValT diffEval(const int var_index) const = 0;

    // Index of this node if it is a variable. Non-variable nodes return the
    // sentinel std::numeric_limits<int>::min().
    virtual int getVarIndex() const { return std::numeric_limits<int>::min(); }

    // Assigns the value of a variable node; the default implementation is a
    // deliberate no-op for nodes that hold no assignable value.
    virtual void set(const ValT& /*val*/) {}

    // Mutable access to a variable node's stored value. Other nodes have no
    // storage to hand out a reference to, so this throws instead of returning
    // without a value.
    virtual ValT& get()
    {
        throw std::logic_error("Expr::get() called on an expression that is not a variable");
    }
};
template<typename ValT> | |
struct ConstantExpr : public Expr<ValT> | |
{ | |
ValT m_const_val; | |
ConstantExpr(const ValT& val) : m_const_val(val) {} | |
ValT eval() const override { return m_const_val; } | |
ValT diffEval(const int var_index) const override { return 0.0f; } | |
}; | |
template<typename ValT, int VarIndex> | |
struct VariableExpr : public Expr<ValT> | |
{ | |
ValT m_val; | |
VariableExpr() {} | |
VariableExpr(const ValT& val) : m_val(val) {} | |
ValT eval() const override { return m_val; } | |
ValT diffEval(const int var_index) const override { return (var_index == VarIndex) ? 1.0f : 0.0f; } | |
int getVarIndex() const override { return VarIndex; } | |
void set(const ValT& val) override { m_val = val; } | |
ValT& get() override { return m_val; } | |
}; | |
template<typename ValT> | |
struct BinaryExpr : public Expr<ValT> | |
{ | |
ptr<Expr<ValT>> m_lhs; | |
ptr<Expr<ValT>> m_rhs; | |
BinaryExpr(const ptr<Expr<ValT>>& lhs, const ptr<Expr<ValT>>& rhs) : m_lhs(lhs), m_rhs(rhs) {} | |
}; | |
template<typename ValT> | |
struct AddExpr : public BinaryExpr<ValT> | |
{ | |
using BinaryExpr<ValT>::BinaryExpr; | |
ValT eval() const override { return this->m_lhs->eval() + this->m_rhs->eval(); } | |
ValT diffEval(const int var_index) const override { return this->m_lhs->diffEval(var_index) + this->m_rhs->diffEval(var_index); } | |
}; | |
template<typename ValT> | |
struct SubExpr : public BinaryExpr<ValT> | |
{ | |
using BinaryExpr<ValT>::BinaryExpr; | |
ValT eval() const override { return this->m_lhs->eval() - this->m_rhs->eval(); } | |
ValT diffEval(const int var_index) const override { return this->m_lhs->diffEval(var_index) - this->m_rhs->diffEval(var_index); } | |
}; | |
template<typename ValT> | |
struct MulExpr : public BinaryExpr<ValT> | |
{ | |
using BinaryExpr<ValT>::BinaryExpr; | |
ValT eval() const override { return this->m_lhs->eval() * this->m_rhs->eval(); } | |
ValT diffEval(const int var_index) const override | |
{ | |
ValT l = this->m_lhs->eval(); | |
ValT dl = this->m_lhs->diffEval(var_index); | |
ValT r = this->m_rhs->eval(); | |
ValT dr = this->m_rhs->diffEval(var_index); | |
return l * dr + r * dl; | |
} | |
}; | |
template<typename ValT> | |
struct DivExpr : public BinaryExpr<ValT> | |
{ | |
using BinaryExpr<ValT>::BinaryExpr; | |
ValT eval() const { return this->m_lhs->eval() / this->m_rhs->eval(); } | |
ValT diffEval(const int var_index) const | |
{ | |
ValT l = this->m_lhs->eval(); | |
ValT dl = this->m_lhs->diffEval(var_index); | |
ValT r = this->m_rhs->eval(); | |
ValT dr = this->m_rhs->diffEval(var_index); | |
return (r * dl - l * dr) / (r * r); | |
} | |
}; | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
Constant(const ValT& val) | |
{ | |
return alloc<ConstantExpr<ValT>>(val); | |
} | |
template<typename ValT, int VarIndex> | |
ptr<Expr<ValT>> | |
Variable() | |
{ | |
return alloc<VariableExpr<ValT, VarIndex>>(); | |
} | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
operator+(const ptr<Expr<ValT>>& lhs, const ptr<Expr<ValT>>& rhs) | |
{ | |
return alloc<AddExpr<ValT>>(lhs, rhs); | |
} | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
operator+(const ptr<Expr<ValT>>& lhs, const ValT& rhs) | |
{ | |
return alloc<AddExpr<ValT>>(lhs, Constant<ValT>(rhs)); | |
} | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
operator+(const ValT& lhs, const ptr<Expr<ValT>>& rhs) | |
{ | |
return alloc<AddExpr<ValT>>(Constant<ValT>(lhs), rhs); | |
} | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
operator-(const ptr<Expr<ValT>>& lhs, const ptr<Expr<ValT>>& rhs) | |
{ | |
return alloc<SubExpr<ValT>>(lhs, rhs); | |
} | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
operator-(const ptr<Expr<ValT>>& lhs, const ValT& rhs) | |
{ | |
return alloc<SubExpr<ValT>>(lhs, Constant<ValT>(rhs)); | |
} | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
operator-(const ValT& lhs, const ptr<Expr<ValT>>& rhs) | |
{ | |
return alloc<SubExpr<ValT>>(Constant<ValT>(lhs), rhs); | |
} | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
operator*(const ptr<Expr<ValT>>& lhs, const ptr<Expr<ValT>>& rhs) | |
{ | |
return alloc<MulExpr<ValT>>(lhs, rhs); | |
} | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
operator*(const ptr<Expr<ValT>>& lhs, const ValT& rhs) | |
{ | |
return alloc<MulExpr<ValT>>(lhs, Constant<ValT>(rhs)); | |
} | |
template<typename ValT> | |
ptr<Expr<ValT>> | |
operator*(const ValT& lhs, const ptr<Expr<ValT>>& rhs) | |
{ | |
return alloc<MulExpr<ValT>>(Constant<ValT>(lhs), rhs); | |
} | |
#define DECL_DIFF_OP(Op, OpType) \ | |
template<typename ValT>\ | |
ptr<Expr<ValT>> operator##Op##(const ptr<Expr<ValT>> & lhs, const ptr<Expr<ValT>> & rhs)\ | |
{\ | |
return alloc<OpType<ValT>>(lhs, rhs);\ | |
}\ | |
template<typename ValT>\ | |
ptr<Expr<ValT>> operator##Op##(const ptr<Expr<ValT>> & lhs, const ValT & rhs)\ | |
{\ | |
return alloc<OpType<ValT>>(lhs, Constant<ValT>(rhs));\ | |
}\ | |
template<typename ValT>\ | |
ptr<Expr<ValT>> operator##Op##(const ValT & lhs, const ptr<Expr<ValT>> & rhs)\ | |
{\ | |
return alloc<OpType<ValT>>(Constant<ValT>(lhs), rhs);\ | |
} | |
DECL_DIFF_OP(+, AddExpr); | |
DECL_DIFF_OP(*, MulExpr); | |
DECL_DIFF_OP(/ , DivExpr); | |
DECL_DIFF_OP(-, SubExpr); | |
#undef DECL_DIFF_OP | |
template<typename ValT> | |
using Diff = ptr<Expr<ValT>>; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment