Last active
November 6, 2022 15:03
-
-
Save nthery/5e00760f280ceee2ea9698c50c1f5271 to your computer and use it in GitHub Desktop.
Expression tree represented by a std::variant
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// std::variant expression tree | |
#include <cstdlib>
#include <functional>
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <variant>
// Expression tree node types.
//
// Leaves are plain aliases (an integer literal or a variable name). Interior
// nodes mention the Expr variant inside their own definition, so they can only
// be forward-declared before Expr itself is formed.
using IntLiteral = int;
using Variable = std::string;

struct Add;
struct Sub;
struct Mul;
struct Div;

// NOTE: the order of alternatives is significant: it fixes Expr::index().
using Expr = std::variant<Add, Sub, Mul, Div, IntLiteral, Variable>;

// CRTP base shared by every binary operator node; it owns both operand
// subtrees through heap-allocated Expr values.
template <typename Concrete>
struct BinaryOperator {
    std::unique_ptr<Expr> lhs;
    std::unique_ptr<Expr> rhs;

    // Builds a Concrete node that takes ownership of both operands.
    // Only declared here: the out-of-class definition below needs Concrete
    // to be a complete type.
    static Concrete make(Expr lhs, Expr rhs);
};

// The concrete operator node types: empty tags over the shared CRTP base.
struct Add : BinaryOperator<Add> {};
struct Sub : BinaryOperator<Sub> {};
struct Mul : BinaryOperator<Mul> {};
struct Div : BinaryOperator<Div> {};

template <typename Concrete>
Concrete BinaryOperator<Concrete>::make(Expr lhs, Expr rhs) {
    // Allocate the left operand first, then the right one.
    Concrete node;
    node.lhs = std::make_unique<Expr>(std::move(lhs));
    node.rhs = std::make_unique<Expr>(std::move(rhs));
    return node;
}

// Variable bindings consulted when an expression is evaluated.
using Environment = std::unordered_map<Variable, IntLiteral>;
// Traits to factor out shared logic in visitors. | |
// Kept separate from expression tree types to emphasize that visitors are | |
// independant. | |
template<class T> | |
struct BinaryOperatorTraits; | |
template<> | |
struct BinaryOperatorTraits<Add> { | |
using EvalExpr = std::plus<IntLiteral>; | |
static constexpr char asString[] = "+"; | |
}; | |
template<> | |
struct BinaryOperatorTraits<Sub> { | |
using EvalExpr = std::minus<IntLiteral>; | |
static constexpr char asString[] = "-"; | |
}; | |
template<> | |
struct BinaryOperatorTraits<Mul> { | |
using EvalExpr = std::multiplies<IntLiteral>; | |
static constexpr char asString[] = "*"; | |
}; | |
template<> | |
struct BinaryOperatorTraits<Div> { | |
using EvalExpr = std::divides<IntLiteral>; | |
static constexpr char asString[] = "/"; | |
}; | |
// Expression evaluator. | |
struct Evaluator { | |
struct Error : std::runtime_error { | |
explicit Error(const std::string& msg) : std::runtime_error(msg) {} | |
}; | |
static IntLiteral eval(const Environment& env, const Expr& expr) { | |
return std::visit(Evaluator(env), expr); | |
} | |
explicit Evaluator(const Environment& env) : env_(env) {} | |
IntLiteral operator()(const IntLiteral& n) { return n; } | |
IntLiteral operator()(const Variable& var) { | |
if (auto it = env_.find(var); it != env_.end()) { | |
return it->second; | |
} | |
std::ostringstream oss; | |
oss << "Unknown variable " << var; | |
throw Error(oss.str()); | |
} | |
template<class T> | |
std::enable_if_t<std::is_base_of_v<BinaryOperator<T>, T>, IntLiteral> | |
operator()(const T& op) { | |
auto evalExpr = typename BinaryOperatorTraits<T>::EvalExpr(); | |
return evalExpr(std::visit(*this, *op.lhs), std::visit(*this, *op.rhs)); | |
} | |
IntLiteral operator()(const Div& div) { | |
const auto lhsEval = std::visit(*this, *div.lhs); | |
const auto rhsEval = std::visit(*this, *div.rhs); | |
if (rhsEval == 0) { | |
throw Error("division by zero"); | |
} | |
return lhsEval / rhsEval; | |
} | |
private: | |
Environment env_; | |
}; | |
// Expression not-so-pretty printer. | |
struct Printer { | |
static void print(std::ostream& os, const Expr& expr) { | |
std::visit(Printer(os), expr); | |
} | |
explicit Printer(std::ostream& os) : os_(os) {} | |
void operator()(const IntLiteral& n) { os_ << n; } | |
void operator()(const Variable& var) { os_ << var; } | |
template<class T> | |
std::enable_if_t<std::is_base_of_v<BinaryOperator<T>, T>, void> | |
operator()(const T& op) { | |
os_ << '('; | |
std::visit(*this, *op.lhs); | |
os_ << ' ' << BinaryOperatorTraits<T>::asString << ' '; | |
std::visit(*this, *op.rhs); | |
os_ << ')'; | |
} | |
private: | |
std::ostream& os_; | |
}; | |
int main() { | |
try { | |
const Expr e = Add::make(Mul::make("bar", 5), Sub::make("foo", Div::make(6, "baz"))); | |
Environment env = { { "foo", 42 }, { "bar", 3 }, { "baz", 2 } }; | |
Printer::print(std::cout, e); | |
std::cout << " = " << Evaluator::eval(env, e) << '\n'; | |
} catch (const Evaluator::Error& e) { | |
std::cerr << "\nFatal error during evaluation: " << e.what() << '\n'; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment