Skip to content

Instantly share code, notes, and snippets.

@nthery
Last active November 6, 2022 15:03
Show Gist options
  • Save nthery/5e00760f280ceee2ea9698c50c1f5271 to your computer and use it in GitHub Desktop.
Save nthery/5e00760f280ceee2ea9698c50c1f5271 to your computer and use it in GitHub Desktop.
Expression tree represented by a std::variant
// std::variant expression tree
#include <functional>
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <variant>
// Types representing an expression tree.
//
// Leaf nodes are plain aliases: an integer literal is an `int` and a variable
// is identified by its name (a `std::string`).
// Non-leaf nodes must be forward-declared because their full definitions
// refer back to the variant.
using IntLiteral = int;
using Variable = std::string;
struct Add;
struct Sub;
struct Mul;
struct Div;
// An expression holds exactly one of the node kinds listed below.
using Expr = std::variant<Add, Sub, Mul, Div, IntLiteral, Variable>;
// CRTP base shared by every binary operator node.
//
// Holds the two operand sub-expressions. Operands are heap-allocated and
// owned through std::unique_ptr, so a node (and therefore an Expr containing
// one) is movable but not copyable.
template<typename Derived>
struct BinaryOperator {
    // Left-hand operand.
    std::unique_ptr<Expr> lhs;
    // Right-hand operand.
    std::unique_ptr<Expr> rhs;

    // Factory returning a `Derived` node that owns the two given operands.
    // Only declared here: the definition needs `Derived` to be a complete
    // type, which it cannot be while this base is still being defined.
    static Derived make(Expr lhs, Expr rhs);
};
// Concrete binary operator nodes.
// Each one is an otherwise empty type that inherits its two operands and its
// `make` factory from BinaryOperator; the distinct types are what lets the
// variant (and the visitors) tell the operators apart.
struct Add : BinaryOperator<Add> {};
struct Sub : BinaryOperator<Sub> {};
struct Mul : BinaryOperator<Mul> {};
struct Div : BinaryOperator<Div> {};
// Builds a `Concrete` operator node from two operand expressions.
// Both operands are taken by value and moved onto the heap, so the returned
// node owns them. The left operand is allocated before the right one.
template<class Concrete>
Concrete BinaryOperator<Concrete>::make(Expr lhs, Expr rhs) {
    auto left = std::make_unique<Expr>(std::move(lhs));
    auto right = std::make_unique<Expr>(std::move(rhs));
    // The inner braces initialize the BinaryOperator base sub-object.
    return Concrete{{std::move(left), std::move(right)}};
}
// Associates variable names with integer values.
using Environment = std::unordered_map<Variable, IntLiteral>;
// Traits to factor out shared logic in visitors.
// Kept separate from expression tree types to emphasize that visitors are
// independant.
template<class T>
struct BinaryOperatorTraits;
template<>
struct BinaryOperatorTraits<Add> {
using EvalExpr = std::plus<IntLiteral>;
static constexpr char asString[] = "+";
};
template<>
struct BinaryOperatorTraits<Sub> {
using EvalExpr = std::minus<IntLiteral>;
static constexpr char asString[] = "-";
};
template<>
struct BinaryOperatorTraits<Mul> {
using EvalExpr = std::multiplies<IntLiteral>;
static constexpr char asString[] = "*";
};
template<>
struct BinaryOperatorTraits<Div> {
using EvalExpr = std::divides<IntLiteral>;
static constexpr char asString[] = "/";
};
// Expression evaluator.
//
// A std::visit visitor that reduces an Expr to an IntLiteral, looking up
// variables in an Environment. The evaluator keeps its own copy of the
// environment it is constructed with. Failures are reported by throwing
// Evaluator::Error.
struct Evaluator {
    // Thrown when an expression cannot be evaluated: unknown variable,
    // division by zero, or a division whose quotient is not representable.
    struct Error : std::runtime_error {
        explicit Error(const std::string& msg) : std::runtime_error(msg) {}
    };

    // Convenience entry point: evaluates `expr` with the bindings in `env`.
    // Returns the value of the expression; throws Error on failure.
    static IntLiteral eval(const Environment& env, const Expr& expr) {
        return std::visit(Evaluator(env), expr);
    }

    explicit Evaluator(const Environment& env) : env_(env) {}

    // An integer literal evaluates to itself.
    IntLiteral operator()(const IntLiteral& n) const { return n; }

    // A variable evaluates to its binding in the environment.
    // Throws Error if the variable has no binding.
    IntLiteral operator()(const Variable& var) const {
        if (const auto it = env_.find(var); it != env_.end()) {
            return it->second;
        }
        throw Error("Unknown variable " + var);
    }

    // Generic binary operator: evaluates both operands, then combines them
    // with the functor named by BinaryOperatorTraits<T>.
    // Participates in overload resolution only for types deriving from
    // BinaryOperator<T>.
    template<class T>
    std::enable_if_t<std::is_base_of_v<BinaryOperator<T>, T>, IntLiteral>
    operator()(const T& op) const {
        // Evaluate the operands in separate statements, left one first.
        // Passing both std::visit calls directly as function arguments would
        // leave their relative order unspecified, so which error surfaces
        // when both operands fail would depend on the compiler.
        const IntLiteral lhsEval = std::visit(*this, *op.lhs);
        const IntLiteral rhsEval = std::visit(*this, *op.rhs);
        auto evalExpr = typename BinaryOperatorTraits<T>::EvalExpr();
        return evalExpr(lhsEval, rhsEval);
    }

    // Division gets a dedicated (non-template, hence preferred) overload so
    // that invalid divisions are turned into Error instead of undefined
    // behaviour. Both operands are evaluated before the checks.
    IntLiteral operator()(const Div& div) const {
        const IntLiteral lhsEval = std::visit(*this, *div.lhs);
        const IntLiteral rhsEval = std::visit(*this, *div.rhs);
        if (rhsEval == 0) {
            throw Error("division by zero");
        }
        // The most negative value divided by -1 is not representable in
        // IntLiteral; performing the division would be undefined behaviour.
        if (rhsEval == -1 &&
            lhsEval == std::numeric_limits<IntLiteral>::min()) {
            throw Error("integer overflow in division");
        }
        return lhsEval / rhsEval;
    }

private:
    Environment env_;
};
// Expression not-so-pretty printer.
struct Printer {
static void print(std::ostream& os, const Expr& expr) {
std::visit(Printer(os), expr);
}
explicit Printer(std::ostream& os) : os_(os) {}
void operator()(const IntLiteral& n) { os_ << n; }
void operator()(const Variable& var) { os_ << var; }
template<class T>
std::enable_if_t<std::is_base_of_v<BinaryOperator<T>, T>, void>
operator()(const T& op) {
os_ << '(';
std::visit(*this, *op.lhs);
os_ << ' ' << BinaryOperatorTraits<T>::asString << ' ';
std::visit(*this, *op.rhs);
os_ << ')';
}
private:
std::ostream& os_;
};
int main() {
try {
const Expr e = Add::make(Mul::make("bar", 5), Sub::make("foo", Div::make(6, "baz")));
Environment env = { { "foo", 42 }, { "bar", 3 }, { "baz", 2 } };
Printer::print(std::cout, e);
std::cout << " = " << Evaluator::eval(env, e) << '\n';
} catch (const Evaluator::Error& e) {
std::cerr << "\nFatal error during evaluation: " << e.what() << '\n';
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment