Skip to content

Instantly share code, notes, and snippets.

@smunix
Last active October 31, 2021 02:31
Show Gist options
  • Save smunix/15c4b6c5a4bb7e917e97b3085e5a2bc7 to your computer and use it in GitHub Desktop.
Save smunix/15c4b6c5a4bb7e917e97b3085e5a2bc7 to your computer and use it in GitHub Desktop.
Arithmetic Expressions
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
module Arithmetic.Expr where
import Data.Char (chr)
import GHC.Real (RealFrac (truncate))
{-
Expr :: Term
Term :: Int
| Factor + Factor
Factor :: Term
| Factor x Factor
-}
{-
data Term
= MkPrim Int
| MkAdd Factor Factor
deriving (Show, Eq)
-}
-- | A term of the first grammar: an integer literal or the sum of two factors.
-- The constructor signatures must be indented under @where@; at column 0 the
-- layout rule closes the declaration and the module does not parse.
data Term where
  -- Integer literal.
  MkPrim :: Int -> Term
  -- Sum of two factors.
  MkAdd :: Factor -> Factor -> Term
  deriving (Show, Eq)
-- | 'Term' generalised over the literal type @a@.
-- The constructor signatures are indented so that they belong to the
-- @where@ block of the declaration.
data TermG a where
  -- Literal of type @a@.
  MkPrimG :: a -> TermG a
  -- Sum of two generic factors.
  MkAddG :: FactorG a -> FactorG a -> TermG a
  deriving (Show, Eq)
-- | Maps a function over every literal stored in a term, recursing into both
-- operands of an addition. The method equations are indented so that they are
-- part of the instance body rather than stray top-level definitions.
instance Functor TermG where
  fmap f (MkPrimG a) = MkPrimG (f a)
  fmap f (MkAddG l r) = MkAddG (fmap f l) (fmap f r)
{- d = MkPrimG 1.0
i = MkPrimG 2
s = MkPrimG "Mo"
-}
-- | Returns the first element of a list, or 'Nothing' when the list is empty.
fnL :: forall a. [] a -> Maybe a
fnL xs = case xs of
  [] -> Nothing
  (firstElem : _) -> Just firstElem
-- | Evaluates a generic term: a literal is returned as is, an addition is the
-- sum of its two evaluated factors.
fnG :: forall a. (Num a) => TermG a -> a
fnG term = case term of
  MkPrimG x -> x
  MkAddG lhs rhs -> evalF lhs + evalF rhs
-- | Evaluates a generic factor: a wrapped term is evaluated with 'fnG', a
-- multiplication is the product of its two evaluated factors.
evalF :: forall a. Num a => FactorG a -> a
evalF factor = case factor of
  MkTermG t -> fnG t
  MkMulG lhs rhs -> evalF lhs * evalF rhs
{-
data Factor
= MkTerm Term
| MkMul Factor Factor
deriving (Show, Eq)
-}
-- | A factor of the first grammar: a wrapped term or the product of two
-- factors. The constructor signatures are indented so that they belong to the
-- @where@ block of the declaration.
data Factor where
  -- Embeds a term as a factor.
  MkTerm :: Term -> Factor
  -- Product of two factors.
  MkMul :: Factor -> Factor -> Factor
  deriving (Show, Eq)
-- | 'Factor' generalised over the literal type @a@.
-- The constructor signatures are indented so that they belong to the
-- @where@ block of the declaration.
data FactorG a where
  -- Embeds a generic term as a factor.
  MkTermG :: TermG a -> FactorG a
  -- Product of two generic factors.
  MkMulG :: FactorG a -> FactorG a -> FactorG a
  deriving (Show, Eq)
-- | Maps a function over every literal stored in a factor, recursing through
-- wrapped terms and both operands of a multiplication. The method equations
-- are indented so that they are part of the instance body.
instance Functor FactorG where
  fmap f (MkTermG t) = MkTermG (fmap f t)
  fmap f (MkMulG l r) = MkMulG (fmap f l) (fmap f r)
-- | An expression of the first grammar is just a 'Term'.
type Expr = Term
-- | Generic counterpart of 'Expr': an expression over literals of type @a@.
type ExprG a = TermG a
{- e = 1 + 2 x 3 -}
{-
+
1 x
2 3
-}
-- | The sample expression @1 + 2 x 3@ built from the monomorphic constructors.
-- The @where@ clause and its bindings are indented; a @where@ at column 0 is
-- a parse error.
e :: Expr
e = MkAdd f1 f2
  where
    -- Left operand: the literal 1.
    f1 :: Factor
    f1 = MkTerm (MkPrim 1)
    -- Right operand: the product 2 x 3.
    f2 :: Factor
    f2 = MkMul (MkTerm (MkPrim 2)) (MkTerm (MkPrim 3))
-- $> import Arithmetic.Expr
-- $> e
-- | The sample expression @1 + 2 x 3@ built from the generic constructors,
-- with the literal type fixed to 'Float' through the equality constraint.
-- The @where@ clause and its bindings are indented; a @where@ at column 0 is
-- a parse error. The local signatures reuse the @a@ bound by the outer
-- @forall@ (ScopedTypeVariables).
eg :: forall a. (a ~ Float) => ExprG a
eg = MkAddG f1 f2
  where
    -- Left operand: the literal 1.
    f1 :: FactorG a
    f1 = MkTermG (MkPrimG 1)
    -- Right operand: the product 2 x 3.
    f2 :: FactorG a
    f2 = MkMulG (MkTermG (MkPrimG 2)) (MkTermG (MkPrimG 3))
-- $> import Arithmetic.Expr
-- $> (fnG eg, eg)
-- $> import Data.Char
-- $> (fmap (chr) . (fmap truncate) $ eg, eg)
-- $> (fmap (chr . truncate) $ eg, eg)
{-
BNF grammar representation (Backus–Naur Form)
Expr : Expr + Term
| Expr - Term
| Term
Term : Term * Factor
| Term / Factor
| Factor
Factor : Int
-}
-- | Factor of the BNF grammar: an integer literal.
-- The constructor signature is indented so that it belongs to the @where@
-- block of the declaration.
data Factor where
  -- Integer literal (the constructor shares its name with the 'Int' type).
  Int :: Int -> Factor
  deriving (Show)
-- | Term of the BNF grammar: a product, a quotient, or a single factor.
-- The constructor signatures are indented so that they belong to the
-- @where@ block of the declaration.
data Term where
  -- Term * Factor
  (:*:) :: Term -> Factor -> Term
  -- Term / Factor
  (:/:) :: Term -> Factor -> Term
  -- A lone factor.
  Factor :: Factor -> Term
  deriving (Show)
-- | Expression of the BNF grammar: a sum, a difference, or a single term.
-- The constructor signatures are indented so that they belong to the
-- @where@ block of the declaration.
data Expr where
  -- Expr + Term
  (:+:) :: Expr -> Term -> Expr
  -- Expr - Term
  (:-:) :: Expr -> Term -> Expr
  -- A lone term.
  Term :: Term -> Expr
  deriving (Show)
-- 1 + 2 x 3
-- | The sample expression @1 + 2 x 3@ in the BNF-shaped data types.
v :: Expr
v = one :+: twoTimesThree
  where
    one = Term (Factor (Int 1))
    twoTimesThree = Factor (Int 2) :*: Int 3
-- $> v
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
module Regex.Expr where
{-
'a'
'ab' <=> ReConcat (ReChar 'a') (ReChar 'b') <=> 'a' * 'b'
'abc' <=> ReConcat (ReChar 'a') (ReConcat (ReChar 'b') (ReChar 'c')) <=> 'a' * 'b' * 'c'
'a|b' <=> 'a' + 'b'
'a*'
Kleene Star
\epsilon
-}
-- | Regular expressions over an alphabet @a@.
-- The constructor signatures are indented so that they belong to the
-- @where@ block of the declaration.
data RegExpr a where
  -- A single symbol, e.g. 'a'.
  ReChar :: a -> RegExpr a
  -- Epsilon (written \epsilon in the notes above).
  ReEpsilon :: RegExpr a
  -- Concatenation, e.g. 'ab'.
  ReConcat :: RegExpr a -> RegExpr a -> RegExpr a
  -- Alternation, e.g. 'a|b'.
  ReChoice :: RegExpr a -> RegExpr a -> RegExpr a
  -- Kleene star, e.g. 'a*'.
  ReStar :: RegExpr a -> RegExpr a
-- | Expression tree with three node kinds over names of type @a@.
-- NOTE(review): the constructor names suggest lambda-calculus terms
-- (variable, abstraction, application) -- confirm with the author.
-- The constructor signatures are indented so that they belong to the
-- @where@ block of the declaration.
data E a where
  Var :: a -> E a
  Fun :: a -> E a -> E a
  App :: E a -> E a -> E a
-- | Type-level counterpart of 'E', with the same three-constructor shape.
-- NOTE(review): presumably primitive types, parameterised ("template") types
-- and type application -- confirm with the author.
-- The constructor signatures are indented so that they belong to the
-- @where@ block of the declaration.
data Ty a where
  PrimTy :: a -> Ty a
  TempTy :: a -> Ty a -> Ty a
  AppTy :: Ty a -> Ty a -> Ty a
@mo-xiaoming
Copy link

I think the following code should have similar functionality, but it is much more verbose, for the following reasons:

  • C++ doesn't allow virtual member function templates, so the return type of fmap is a pain in the a**. I have to use a generic type such as std::any or std::variant
  • The order of declaration matters: a type cannot be used first and declared later
  • I swear there is a third one, I can't remember...
#include <cmath>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <variant>

/// Merges several callables into one overload set, so that a single object
/// can be handed to std::visit.
template <class... Fs> struct overloaded : Fs... {
  using Fs::operator()...;
};
/// Deduction guide: lets `overloaded{f, g, ...}` deduce its template arguments.
template <class... Fs> overloaded(Fs...) -> overloaded<Fs...>;

/// Immutable wrapper around an `int`, implicitly constructible from `int`.
struct [[nodiscard]] Int {
  Int(int i = 0) : value_{i} {}

  /// The wrapped integer.
  int value() const noexcept { return value_; }

  /// Sum of the two wrapped integers.
  friend Int operator+(Int a, Int b) noexcept { return Int{a.value_ + b.value_}; }
  /// Product of the two wrapped integers.
  friend Int operator*(Int a, Int b) noexcept { return Int{a.value_ * b.value_}; }

private:
  const int value_{};
};

/// Immutable wrapper around a `double`, implicitly constructible from `double`.
struct [[nodiscard]] Double {
  Double(double d = 0) : value_{d} {}

  /// The wrapped floating-point number.
  double value() const noexcept { return value_; }

  /// Sum of the two wrapped numbers.
  friend Double operator+(Double a, Double b) noexcept { return Double{a.value_ + b.value_}; }
  /// Product of the two wrapped numbers.
  friend Double operator*(Double a, Double b) noexcept { return Double{a.value_ * b.value_}; }

private:
  const double value_{};
};

/// Immutable wrapper around a `std::string`, implicitly constructible from one.
struct [[nodiscard]] String {
  /// Takes the text by value and moves it into place, so passing a temporary
  /// costs no extra copy (the parameter used to be copied into the member).
  String(std::string s = {}) : value_(std::move(s)) {}

  /// Borrowing accessor for lvalues.
  const std::string &value() const &noexcept { return value_; }
  /// Rvalue accessor: returns a copy so the result outlives the temporary.
  std::string value() const && { return value_; }

private:
  const std::string value_;
};

using BasicType = std::variant<Int, Double, String>;

std::ostream &operator<<(std::ostream &os, const BasicType &v) {
  return std::visit([&os](const auto &a) -> std::ostream & { return os << a.value(); }, v);
}

/// Adds two BasicType values that hold the same alternative.
///
/// Int + Int and Double + Double are computed numerically; String + String
/// yields the text "<lhs> + <rhs>".
///
/// @throws std::bad_variant_access when the operands hold different
///         alternatives. The former catch-all evaluated `l + r`, which
///         converted both operands back to BasicType and re-entered this very
///         operator, recursing until the stack overflowed.
[[nodiscard]] BasicType operator+(const BasicType &lhs, const BasicType &rhs) {
  return std::visit(
      overloaded{[](Int l, Int r) -> BasicType { return Int(l.value() + r.value()); },
                 [](Double l, Double r) -> BasicType { return Double(l.value() + r.value()); },
                 [](const String &l, const String &r) -> BasicType {
                   std::ostringstream oss;
                   oss << l.value() << " + " << r.value();
                   return String(oss.str());
                 },
                 // Mismatched alternatives: fail loudly instead of recursing.
                 [](const auto & /*l*/, const auto & /*r*/) -> BasicType {
                   throw std::bad_variant_access{};
                 }},
      lhs, rhs);
}

/// Multiplies two BasicType values that hold the same alternative.
///
/// Int * Int and Double * Double are computed numerically; String * String
/// yields the text "<lhs> * <rhs>".
///
/// @throws std::bad_variant_access when the operands hold different
///         alternatives. The former catch-all evaluated `l + r` (the wrong
///         operator), which converted both operands back to BasicType and
///         recursed through the BasicType operator+ until the stack overflowed.
[[nodiscard]] BasicType operator*(const BasicType &lhs, const BasicType &rhs) {
  return std::visit(
      overloaded{[](Int l, Int r) -> BasicType { return Int(l.value() * r.value()); },
                 [](Double l, Double r) -> BasicType { return Double(l.value() * r.value()); },
                 [](const String &l, const String &r) -> BasicType {
                   std::ostringstream oss;
                   oss << l.value() << " * " << r.value();
                   return String(oss.str());
                 },
                 // Mismatched alternatives: fail loudly instead of recursing.
                 [](const auto & /*l*/, const auto & /*r*/) -> BasicType {
                   throw std::bad_variant_access{};
                 }},
      lhs, rhs);
}

/// Owning pointer used for every node of the expression tree.
template <typename T> using Ptr = std::unique_ptr<T>;

/// Allocates a `T` from the forwarded constructor arguments and returns its owner.
template <typename T, typename... CtorArgs>
[[nodiscard]] Ptr<T> MkPtr(CtorArgs &&...ctorArgs) {
  Ptr<T> owner = std::make_unique<T>(std::forward<CtorArgs>(ctorArgs)...);
  return owner;
}

/// Abstract root of the expression tree.
///
/// Copy and move construction are deleted; the assignment operators are
/// defaulted and restricted to lvalues.
struct [[nodiscard]] Expr {
  /// Callback that maps one value to another.
  using TransformFn = std::function<BasicType(const BasicType &)>;

  Expr() = default;
  Expr(const Expr &) = delete;
  Expr(Expr &&) = delete;
  auto operator=(const Expr &) & -> Expr & = default;
  auto operator=(Expr &&) &noexcept -> Expr & = default;
  virtual ~Expr() = default;

  /// Evaluates the node. `tf` defaults to an empty function.
  virtual auto result(TransformFn &&tf = {}) const -> BasicType = 0;

  /// Textual rendering of the node.
  virtual auto toString() const -> std::string = 0;
};

/// Marker base for term nodes (TermPrim, TermAdd); adds nothing to Expr.
struct [[nodiscard]] Term : Expr {};

/// Marker base for factor nodes (FactorTerm, FactorMul); adds nothing to Expr.
struct [[nodiscard]] Factor : Expr {};

/// Leaf node holding a single value.
struct [[nodiscard]] TermPrim : Term {
  explicit TermPrim(BasicType t) : value(std::move(t)) {}

  /// Returns the stored value, passed through `fn` when `fn` is non-empty.
  auto result(TransformFn &&fn) const -> BasicType override {
    if (!fn) {
      return value;
    }
    return fn(value);
  }

  /// Renders as "(Prim <value>)".
  auto toString() const -> std::string override {
    std::ostringstream rendered;
    rendered << "(Prim ";
    rendered << value;
    rendered << ")";
    return rendered.str();
  }

private:
  BasicType value;
};

/// Addition node owning its two factor operands.
struct [[nodiscard]] TermAdd : Term {
  TermAdd(Ptr<const Factor> left, Ptr<const Factor> right)
      : lhs(std::move(left)), rhs(std::move(right)) {}

  /// Evaluates both operands with the same `fn` and adds the results.
  /// `fn` is only handed on as a reference here, so both calls see it intact.
  auto result(TransformFn &&fn) const -> BasicType override {
    return lhs->result(std::forward<TransformFn>(fn)) + rhs->result(std::forward<TransformFn>(fn));
  }

  /// Renders as "(Add <lhs> <rhs>)".
  auto toString() const -> std::string override {
    std::string text = "(Add ";
    text += lhs->toString();
    text += ' ';
    text += rhs->toString();
    text += ")";
    return text;
  }

private:
  Ptr<const Factor> lhs;
  Ptr<const Factor> rhs;
};

/// Factor node that wraps a single term.
struct [[nodiscard]] FactorTerm : Factor {
  explicit FactorTerm(Ptr<const Term> t) : term(std::move(t)) {}

  /// Delegates evaluation to the wrapped term.
  auto result(TransformFn &&fn) const -> BasicType override {
    const Term &wrapped = *term;
    return wrapped.result(std::forward<TransformFn>(fn));
  }

  /// Renders as "(Factor <term>)".
  auto toString() const -> std::string override {
    std::string text = "(Factor ";
    text += term->toString();
    text += ")";
    return text;
  }

private:
  Ptr<const Term> term;
};

/// Multiplication node owning its two factor operands.
struct [[nodiscard]] FactorMul : Factor {
  FactorMul(Ptr<const Factor> left, Ptr<const Factor> right)
      : lhs(std::move(left)), rhs(std::move(right)) {}

  /// Evaluates both operands with the same `fn` and multiplies the results.
  /// `fn` is only handed on as a reference here, so both calls see it intact.
  auto result(TransformFn &&fn) const -> BasicType override {
    return lhs->result(std::forward<TransformFn>(fn)) * rhs->result(std::forward<TransformFn>(fn));
  }

  /// Renders as "(Mul <lhs> <rhs>)".
  auto toString() const -> std::string override {
    std::string text = "(Mul ";
    text += lhs->toString();
    text += ' ';
    text += rhs->toString();
    text += ")";
    return text;
  }

private:
  Ptr<const Factor> lhs;
  Ptr<const Factor> rhs;
};

/// Builds the tree for 1.1 + 2.2 * 3.3 and returns its owner.
[[nodiscard]] Ptr<const Expr> createAST() {
  // Each literal is a TermPrim wrapped in a FactorTerm.
  const auto literal = [](double v) -> Ptr<const Factor> {
    return MkPtr<const FactorTerm>(MkPtr<const TermPrim>(Double(v)));
  };

  Ptr<const Factor> one = literal(1.1);
  Ptr<const Factor> product = MkPtr<const FactorMul>(literal(2.2), literal(3.3));

  return MkPtr<const TermAdd>(std::move(one), std::move(product));
}

auto main() -> int {
  const auto ast = createAST();

  std::cout << ast->toString() << '\n';
  std::cout << ast->result() << '\n';

  std::cout << "\n\n-- to int type\n";
  std::cout << "(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = ";
  std::cout << ast->result([](const BasicType &v) -> BasicType {
    const auto d = std::get<Double>(v).value();
    return Int(std::floor(d * d));
  }) << '\n';

  std::cout << "\n\n-- to string type\n";
  std::cout << ast->result([](const BasicType &v) -> BasicType {
    const auto d = std::get<Double>(v).value();
    return String(std::to_string(d));
  }) << '\n';
}

@mo-xiaoming
Copy link

#### 1 + 2 * 3
(+ (Expr (Term (Factor 1.1))) (* (Term (Factor 2.2)) (Factor 3.3)))
8.36


-- to int type
(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = 41


-- to string type
1.100000 + 2.200000 * 3.300000



#### 1 - 2 * 0
(- (Expr (Term (Factor 1))) (/ (Term (Factor 2)) (Factor 0)))
Error: DivByZero
#include <array>
#include <cmath>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string_view>
#include <utility>
#include <variant>

using namespace std::literals;

/// Streams every argument, in order, into one string.
template <typename... Args> [[nodiscard]] auto format(Args &&...args) -> std::string {
  std::ostringstream buffer;
  ((buffer << args), ...);
  return buffer.str();
}

/// Merges several callables into one overload set, so that a single object
/// can be handed to std::visit.
template <class... Fs> struct overloaded : Fs... {
  using Fs::operator()...;
};
/// Deduction guide: lets `overloaded{f, g, ...}` deduce its template arguments.
template <class... Fs> overloaded(Fs...) -> overloaded<Fs...>;

/// Error value carried through evaluation in place of a number.
struct [[nodiscard]] Error {
  enum Type : std::int8_t { NoIdea, DivByZero };

  explicit Error(Type t = NoIdea) : value_(t) {}

  /// Human-readable message for the stored error type.
  [[nodiscard]] auto value() const noexcept -> std::string_view { return typeToString(); }

private:
  /// Looks the message up in a table indexed by the enumerator value.
  [[nodiscard]] auto typeToString() const -> std::string_view {
    static constexpr std::array<std::string_view, 2> messages{
        std::string_view{"Error: I have no idea what's going on here"},
        std::string_view{"Error: DivByZero"}};
    const auto index = static_cast<std::size_t>(value_);
    return messages[index];
  }
  const Type value_ = NoIdea;
};

/// Immutable wrapper around an `int`; construction is explicit.
struct [[nodiscard]] Int {
  explicit Int(int i = 0) : value_{i} {}

  /// The wrapped integer.
  [[nodiscard]] auto value() const noexcept -> int { return value_; }

private:
  const int value_{};
};

/// Immutable wrapper around a `double`; construction is explicit.
struct [[nodiscard]] Double {
  explicit Double(double d = 0) : value_{d} {}

  /// The wrapped floating-point number.
  [[nodiscard]] auto value() const noexcept -> double { return value_; }

private:
  const double value_{};
};

/// Immutable wrapper around a `std::string`; construction is explicit.
struct [[nodiscard]] String {
  explicit String(std::string s = {}) : value_{std::move(s)} {}

  /// Borrowing accessor for lvalues.
  [[nodiscard]] auto value() const &noexcept -> const std::string & { return value_; }
  /// Rvalue accessor: returns a copy so the result outlives the temporary.
  [[nodiscard]] auto value() const && -> std::string { return std::string{value_}; }

private:
  const std::string value_;
};

using BasicType = std::variant<Error, Int, Double, String>;

auto operator<<(std::ostream &os, const BasicType &v) -> std::ostream & {
  return std::visit([&os](const auto &a) -> std::ostream & { return os << a.value(); }, v);
}

/// Adds two BasicType values.
///
/// - An Error operand is propagated unchanged; the left one wins if both are
///   errors.
/// - Int + Int and Double + Double are computed numerically.
/// - String + String yields the text "<lhs> + <rhs>".
/// - Any other combination yields `Error(NoIdea)`.
///
/// Every case is now terminal. Previously Error + Error and mismatched
/// non-error operands fell through to `l + r`, which converted both operands
/// back to BasicType and re-entered this operator until the stack overflowed.
[[nodiscard]] auto operator+(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return std::visit(
      [](const auto &l, const auto &r) -> BasicType {
        using LT = std::decay_t<decltype(l)>;
        using RT = std::decay_t<decltype(r)>;
        if constexpr (std::is_same_v<LT, Error>) {
          return l;
        } else if constexpr (std::is_same_v<RT, Error>) {
          return r;
        } else if constexpr (!std::is_same_v<LT, RT>) {
          return Error(Error::Type::NoIdea);
        } else if constexpr (std::is_same_v<LT, String>) {
          return String(format(l.value(), " + ", r.value()));
        } else {
          // Int or Double: same alternative on both sides.
          return LT(l.value() + r.value());
        }
      },
      lhs, rhs);
}

/// Subtracts two BasicType values.
///
/// - An Error operand is propagated unchanged; the left one wins if both are
///   errors.
/// - Int - Int and Double - Double are computed numerically.
/// - String - String yields the text "<lhs> - <rhs>".
/// - Any other combination yields `Error(NoIdea)`.
///
/// Every case is now terminal. Previously Error - Error and mismatched
/// non-error operands fell through to `l - r`, which converted both operands
/// back to BasicType and re-entered this operator until the stack overflowed.
[[nodiscard]] auto operator-(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return std::visit(
      [](const auto &l, const auto &r) -> BasicType {
        using LT = std::decay_t<decltype(l)>;
        using RT = std::decay_t<decltype(r)>;
        if constexpr (std::is_same_v<LT, Error>) {
          return l;
        } else if constexpr (std::is_same_v<RT, Error>) {
          return r;
        } else if constexpr (!std::is_same_v<LT, RT>) {
          return Error(Error::Type::NoIdea);
        } else if constexpr (std::is_same_v<LT, String>) {
          return String(format(l.value(), " - ", r.value()));
        } else {
          // Int or Double: same alternative on both sides.
          return LT(l.value() - r.value());
        }
      },
      lhs, rhs);
}

/// Multiplies two BasicType values.
///
/// - An Error operand is propagated unchanged; the left one wins if both are
///   errors.
/// - Int * Int and Double * Double are computed numerically.
/// - String * String yields the text "<lhs> * <rhs>".
/// - Any other combination yields `Error(NoIdea)`.
///
/// Every case is now terminal. Previously Error * Error and mismatched
/// non-error operands fell through to `l * r`, which converted both operands
/// back to BasicType and re-entered this operator until the stack overflowed.
[[nodiscard]] auto operator*(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return std::visit(
      [](const auto &l, const auto &r) -> BasicType {
        using LT = std::decay_t<decltype(l)>;
        using RT = std::decay_t<decltype(r)>;
        if constexpr (std::is_same_v<LT, Error>) {
          return l;
        } else if constexpr (std::is_same_v<RT, Error>) {
          return r;
        } else if constexpr (!std::is_same_v<LT, RT>) {
          return Error(Error::Type::NoIdea);
        } else if constexpr (std::is_same_v<LT, String>) {
          return String(format(l.value(), " * ", r.value()));
        } else {
          // Int or Double: same alternative on both sides.
          return LT(l.value() * r.value());
        }
      },
      lhs, rhs);
}

/// Divides two BasicType values.
///
/// - An Error operand is propagated unchanged; the left one wins if both are
///   errors.
/// - Int / Int yields `Error(DivByZero)` when the divisor is zero, otherwise
///   the integer quotient.
/// - Double / Double is computed numerically (no zero check, as before).
/// - String / String yields the text "<lhs> / <rhs>".
/// - Any other combination yields `Error(NoIdea)`.
///
/// Every case is now terminal. Previously Error / Error and mismatched
/// non-error operands fell through to `l * r` (the wrong operator), which
/// converted both operands back to BasicType and recursed through the
/// BasicType operator* until the stack overflowed.
[[nodiscard]] auto operator/(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return std::visit(
      [](const auto &l, const auto &r) -> BasicType {
        using LT = std::decay_t<decltype(l)>;
        using RT = std::decay_t<decltype(r)>;
        if constexpr (std::is_same_v<LT, Error>) {
          return l;
        } else if constexpr (std::is_same_v<RT, Error>) {
          return r;
        } else if constexpr (!std::is_same_v<LT, RT>) {
          return Error(Error::Type::NoIdea);
        } else if constexpr (std::is_same_v<LT, String>) {
          return String(format(l.value(), " / ", r.value()));
        } else if constexpr (std::is_same_v<LT, Int>) {
          if (r.value() == 0) {
            return Error(Error::Type::DivByZero);
          }
          return Int(l.value() / r.value());
        } else {
          return Double(l.value() / r.value());
        }
      },
      lhs, rhs);
}

/// Owning pointer used for every node of the expression tree.
template <typename T> using Ptr = std::unique_ptr<T>;
/// Allocates a `T` from the forwarded constructor arguments and returns its owner.
template <typename T, typename... CtorArgs>
[[nodiscard]] auto MkPtr(CtorArgs &&...ctorArgs) -> Ptr<T> {
  Ptr<T> owner = std::make_unique<T>(std::forward<CtorArgs>(ctorArgs)...);
  return owner;
}

/// Abstract root of the expression tree.
///
/// Copy and move construction are deleted; the assignment operators are
/// defaulted and restricted to lvalues.
struct [[nodiscard]] Base {
  /// Callback that maps one value to another.
  using TransformFn = std::function<BasicType(const BasicType &)>;

  Base() = default;
  Base(const Base &) = delete;
  Base(Base &&) = delete;
  auto operator=(const Base &) & -> Base & = default;
  auto operator=(Base &&) &noexcept -> Base & = default;
  virtual ~Base() = default;

  /// Evaluates the node. `tf` defaults to an empty function.
  virtual auto result(TransformFn &&tf = {}) const -> BasicType = 0;

  /// Textual rendering of the node.
  [[nodiscard]] virtual auto toString() const -> std::string = 0;
};

/// Leaf node holding a single value.
struct [[nodiscard]] Factor : Base {
  explicit Factor(BasicType v) : value_(std::move(v)) {}

  /// Returns the stored value, passed through `fn` when `fn` is non-empty.
  auto result(TransformFn &&fn) const -> BasicType override {
    if (!fn) {
      return value_;
    }
    return fn(value_);
  }

  /// Renders as "(Factor <value>)".
  [[nodiscard]] auto toString() const -> std::string override {
    return format("(Factor ", value_, ')');
  }

private:
  const BasicType value_;
};

/// CRTP mixin shared by the binary-operator nodes.
///
/// The derived class `T` supplies `lhs_`, `rhs_`, `act()` and `opToString()`.
template <typename T> struct [[nodiscard]] BaseOp {
  /// Evaluates both operands with the same `fn` and combines them via `T::act`.
  auto result(Base::TransformFn &&fn) const -> BasicType {
    const T &self = derived();
    return self.act(self.lhs_->result(std::forward<Base::TransformFn>(fn)),
                    self.rhs_->result(std::forward<Base::TransformFn>(fn)));
  }

  /// Renders as "(<op> <lhs> <rhs>)".
  [[nodiscard]] auto toString() const -> std::string {
    const T &self = derived();
    return format('(', self.opToString(), ' ', self.lhs_->toString(), ' ',
                  self.rhs_->toString(), ')');
  }

private:
  /// Downcast to the derived node type.
  [[nodiscard]] auto derived() const -> const T & { return static_cast<const T &>(*this); }
};

/// Marker base for term nodes (TermBinOp, TermFactor); adds nothing to Base.
struct [[nodiscard]] Term : Base {};

/// Term-level binary node: `Term * Factor` or `Term / Factor`.
struct [[nodiscard]] TermBinOp : Term, BaseOp<TermBinOp> {
  enum class Op : std::int8_t { Mul, Div };

  TermBinOp(Op op, Ptr<const Term> lhs, Ptr<const Factor> rhs)
      : op_(op), lhs_(std::move(lhs)), rhs_(std::move(rhs)) {}

  /// Evaluation is implemented by the BaseOp mixin.
  auto result(TransformFn &&fn) const -> BasicType override {
    return BaseOp<TermBinOp>::result(std::forward<TransformFn>(fn));
  }

  /// Rendering is implemented by the BaseOp mixin.
  [[nodiscard]] auto toString() const -> std::string override {
    return BaseOp<TermBinOp>::toString();
  }

private:
  /// Applies the stored operator to two already-evaluated operands.
  [[nodiscard]] auto act(const BasicType &l, const BasicType &r) const -> BasicType {
    if (op_ == Op::Mul) {
      return l * r;
    }
    if (op_ == Op::Div) {
      return l / r;
    }
    return {}; // shouldn't reach here
  }

  /// Operator symbol, looked up by enumerator value.
  [[nodiscard]] auto opToString() const -> std::string_view {
    static constexpr std::array symbols{"*"sv, "/"sv};
    const auto index = static_cast<std::size_t>(op_);
    return symbols[index];
  }

  friend BaseOp<TermBinOp>; // the mixin reads lhs_/rhs_ and calls act()/opToString()
  const Op op_;
  const Ptr<const Term> lhs_;
  const Ptr<const Factor> rhs_;
};

/// Term node that wraps a single factor.
struct [[nodiscard]] TermFactor : Term {
  explicit TermFactor(Ptr<const Factor> v) : factor_(std::move(v)) {}

  /// Delegates evaluation to the wrapped factor.
  auto result(TransformFn &&fn) const -> BasicType override {
    const Factor &wrapped = *factor_;
    return wrapped.result(std::forward<TransformFn>(fn));
  }

  /// Renders as "(Term <factor>)".
  [[nodiscard]] auto toString() const -> std::string override {
    const std::string inner = factor_->toString();
    return format("(Term ", inner, ')');
  }

private:
  const Ptr<const Factor> factor_;
};

/// Marker base for expression nodes (ExprBinOp, ExprTerm); adds nothing to Base.
struct [[nodiscard]] Expr : Base {};

/// Expression-level binary node: `Expr + Term` or `Expr - Term`.
struct [[nodiscard]] ExprBinOp : Expr, BaseOp<ExprBinOp> {
  enum class Op : std::int8_t { Add, Sub };

  ExprBinOp(Op op, Ptr<const Expr> lhs, Ptr<const Term> rhs)
      : op_(op), lhs_(std::move(lhs)), rhs_(std::move(rhs)) {}

  /// Evaluation is implemented by the BaseOp mixin.
  auto result(TransformFn &&fn) const -> BasicType override {
    return BaseOp<ExprBinOp>::result(std::forward<TransformFn>(fn));
  }

  /// Rendering is implemented by the BaseOp mixin.
  [[nodiscard]] auto toString() const -> std::string override {
    return BaseOp<ExprBinOp>::toString();
  }

private:
  /// Applies the stored operator to two already-evaluated operands.
  [[nodiscard]] auto act(const BasicType &l, const BasicType &r) const -> BasicType {
    if (op_ == Op::Add) {
      return l + r;
    }
    if (op_ == Op::Sub) {
      return l - r;
    }
    return {}; // shouldn't reach here
  }

  /// Operator symbol, looked up by enumerator value.
  [[nodiscard]] auto opToString() const -> std::string_view {
    static constexpr std::array symbols{"+"sv, "-"sv};
    const auto index = static_cast<std::size_t>(op_);
    return symbols[index];
  }

  friend BaseOp<ExprBinOp>; // the mixin reads lhs_/rhs_ and calls act()/opToString()
  const Op op_;
  const Ptr<const Expr> lhs_;
  const Ptr<const Term> rhs_;
};

struct [[nodiscard]] ExprTerm : Expr {
  explicit ExprTerm(Ptr<const Term> v) : factor_(std::move(v)) {}

  auto result(TransformFn &&fn) const -> BasicType override {
    return factor_->result(std::forward<TransformFn>(fn));
  }

  [[nodiscard]] auto toString() const -> std::string override {
    return format("(Expr ", factor_->toString(), ')');
  }

private:
  const Ptr<const Term> factor_;
};

void happy() {
  const auto createAST = []() -> Ptr<const Expr> {
    Ptr<const Factor> one = MkPtr<const Factor>(Double(1.1));
    Ptr<const Factor> two = MkPtr<const Factor>(Double(2.2));
    Ptr<const Factor> thr = MkPtr<const Factor>(Double(3.3));

    Ptr<const Term> mul = MkPtr<const TermBinOp>(
        TermBinOp::Op::Mul, MkPtr<const TermFactor>(std::move(two)), std::move(thr));

    Ptr<const Expr> add = MkPtr<const ExprBinOp>(
        ExprBinOp::Op::Add, MkPtr<const ExprTerm>(MkPtr<const TermFactor>(std::move(one))),
        std::move(mul));
    return add;
  };

  const auto ast = createAST();

  std::cout << "#### 1 + 2 * 3\n";
  std::cout << ast->toString() << '\n';
  std::cout << ast->result() << '\n';

  std::cout << "\n\n-- to int type\n";
  std::cout << "(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = ";
  std::cout << ast->result([](const BasicType &v) -> BasicType {
    const auto d = std::get<Double>(v).value();
    return Int(std::floor(d * d));
  }) << '\n';

  std::cout << "\n\n-- to string type\n";
  std::cout << ast->result([](const BasicType &v) -> BasicType {
    const auto d = std::get<Double>(v).value();
    return String(std::to_string(d));
  }) << '\n';
}

/// Builds 1 - 2 / 0 with Int leaves and prints the tree and its value.
/// The integer division by zero evaluates to an Error, which the subtraction
/// propagates.
void divByZero() {
  const auto createAST = []() -> Ptr<const Expr> {
    Ptr<const Factor> one = MkPtr<const Factor>(Int(1));
    Ptr<const Factor> two = MkPtr<const Factor>(Int(2));
    Ptr<const Factor> zer = MkPtr<const Factor>(Int(0));

    Ptr<const Term> div = MkPtr<const TermBinOp>(
        TermBinOp::Op::Div, MkPtr<const TermFactor>(std::move(two)), std::move(zer));

    // Named for what it holds: the node is built with Op::Sub.
    Ptr<const Expr> sub = MkPtr<const ExprBinOp>(
        ExprBinOp::Op::Sub, MkPtr<const ExprTerm>(MkPtr<const TermFactor>(std::move(one))),
        std::move(div));
    return sub;
  };

  const auto ast = createAST();

  // The banner now matches the tree: the right operand is a division.
  std::cout << "#### 1 - 2 / 0\n";
  std::cout << ast->toString() << '\n';
  std::cout << ast->result() << '\n';
}

/// Runs the two demos, separated by blank lines.
auto main() -> int {
  happy();
  std::cout << "\n\n\n";
  divByZero();
  return 0;
}

@smunix
Copy link
Author

smunix commented Oct 30, 2021

This is great @mo-xiaoming, thank you for sharing. Quick question: how would you handle the possibility of this call to fail?

const auto d = std::get<Double>(v).value();

Ideally, I'd want the error handling logic to not pollute much. I'm curious to know what your thoughts are on error handling.

@smunix
Copy link
Author

smunix commented Oct 30, 2021

Btw, I love your result implementation. Isn't that simply a sort of fmap? Less generic though?

@mo-xiaoming
Copy link

This is great @mo-xiaoming, thank you for sharing. Quick question: how would you handle the possibility of this call to fail?

const auto d = std::get<Double>(v).value();

Ideally, I'd want the error handling logic to not pollute much. I'm curious to know what your thoughts are on error handling.

I was thinking that when we create the AST, the underlying type is known to us. Using std::visit to pattern-match all input types is doable, but I think that if we get it wrong, it is a logic error and should not be handled by the program.

And I'm just copying what you've done in Haskell; otherwise I might simply throw exceptions here and at the point of division by zero. Once those things happen, this is an ill-formed program and there is no way for us to recover from the errors, so we might as well fail immediately, at least for the trivial functionality we have now.

Error handling in C++ has been quite a hot topic in the community for several years, because nobody likes what we have today. The popular approaches are all flawed. In the current C++ language, exceptions must not be thrown on a hot path; plain error returns or errno can be silently ignored (something like llvm::Error can solve this problem); std::optional is conceptually equivalent to Maybe, but its implementation in std makes it no different from a nullptr check, and it is not composable.

I'm not a big fan of exceptions, but look at what happens in Go when returned errors are not composable: nested ifs make the code hard to read.

Besides, there is Herb Sutter's effort to eliminate the overhead of exceptions: Zero-overhead deterministic exceptions: Throwing values.

There is a hope,

There is a paper whose name might interest you: P0798R0, Monadic operations for std::optional. There is also an implementation on GitHub that makes optional much more like a real Maybe.

Btw, I love your result implementation. Isn't that a sort of fmap simply? less generic though?

Yes, it is sort of fmap, and yes, sadly it is much less generic.

Because member function templates cannot be virtual (that is, template <typename Ret> virtual Ret result() is disallowed), some generic type such as std::any or std::variant has to be used as input/output so that the virtual functions need no template parameter.

The consequence is that if std::variant is used, there are going to be some large std::visit calls or many nested if constexpr branches. Sigh. They're ugly.

They are talking about adding pattern matching to C++, but there is no time table, I'm afraid...

@smunix
Copy link
Author

smunix commented Oct 30, 2021

Cool stuff!

Fyi, the BNF is a little too verbose although it serves well its purpose (of describing unambiguous grammars). The Expr data type can be simplified as below:

data Exp
  = Exp :+: Exp
   | Exp :-: Exp
   | Exp :*: Exp
   | Exp :/: Exp
   | Int i

Hopefully this simplifies your C++ implementation as well.

@mo-xiaoming
Copy link

mo-xiaoming commented Oct 31, 2021

UPDATE: I was wrong. I didn't consider that if you were using Expr for everything, the Haskell implementation would have been less than 10 lines, so now my C++ code is 30x longer than yours.

#### 1 + 2 * 3
(+ (Prim 1.1) (* (Prim 2.2) (Prim 3.3)))
8.36


-- to int type
(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = 41



#### 1 - 2 / 0
(- (Prim 1) (/ (Prim 2) (Prim 0)))
Error: DivByZero: 2/0

Using Expr to replace Term and Factor indeed makes the code a little shorter. I also removed the String type, because it doesn't make much sense at this moment.

The current code is slightly less than 20x the length of your code. Sigh, C++ is too verbose.

#include <array>
#include <cmath>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string_view>
#include <utility>
#include <variant>

using namespace std::literals;

/// Streams every argument, in order, into one string.
template <typename... Args> [[nodiscard]] auto format(Args &&...args) -> std::string {
  std::ostringstream buffer;
  ((buffer << args), ...);
  return buffer.str();
}

/// Error value carried through evaluation in place of a number: an error
/// type plus an optional detail message.
struct [[nodiscard]] Error {
  enum Type : std::int8_t { NoIdea, DivByZero };

  explicit Error(Type t = NoIdea) : value_(t) {}
  explicit Error(Type t, std::string msg) : value_(t), msg_(std::move(msg)) {}

  /// Returns "<type text>: <detail>", or just "<type text>" when no detail
  /// was given (so there is no dangling ": ").
  ///
  /// Not `noexcept`: building the std::string allocates, and a failed
  /// allocation inside a `noexcept` function would terminate the program.
  [[nodiscard]] auto value() const -> std::string {
    std::string text{typeToString()};
    if (!msg_.empty()) {
      text += ": ";
      text += msg_;
    }
    return text;
  }

private:
  /// Looks the type text up in a table indexed by the enumerator value.
  [[nodiscard]] auto typeToString() const -> std::string_view {
    static constexpr std::array<std::string_view, 2> ts{
        std::string_view{"Error: I have no idea what's going on here"},
        std::string_view{"Error: DivByZero"}};
    return ts[static_cast<std::size_t>(value_)];
  }
  const Type value_ = NoIdea;
  const std::string msg_;
};

/// Immutable wrapper around an `int`; construction is explicit.
struct [[nodiscard]] Int {
  explicit Int(int i = 0) : value_{i} {}

  /// The wrapped integer.
  [[nodiscard]] auto value() const noexcept -> int { return value_; }

private:
  const int value_{};
};

/// Immutable wrapper around a `double`; construction is explicit.
struct [[nodiscard]] Double {
  explicit Double(double d = 0) : value_{d} {}

  /// The wrapped floating-point number.
  [[nodiscard]] auto value() const noexcept -> double { return value_; }

private:
  const double value_{};
};

using BasicType = std::variant<Error, Int, Double>;

auto operator<<(std::ostream &os, const BasicType &v) -> std::ostream & {
  return std::visit([&os](const auto &a) -> std::ostream & { return os << a.value(); }, v);
}

// Int arithmetic. These cannot be hidden friends of Int, because BasicType is
// not yet known inside that class.

/// Int + Int.
[[nodiscard]] auto operator+(Int lhs, Int rhs) -> BasicType {
  const int sum = lhs.value() + rhs.value();
  return Int(sum);
}
/// Int - Int.
[[nodiscard]] auto operator-(Int lhs, Int rhs) -> BasicType {
  const int difference = lhs.value() - rhs.value();
  return Int(difference);
}
/// Int * Int.
[[nodiscard]] auto operator*(Int lhs, Int rhs) -> BasicType {
  const int product = lhs.value() * rhs.value();
  return Int(product);
}
/// Int / Int; a zero divisor yields a DivByZero error naming the operands.
[[nodiscard]] auto operator/(Int lhs, Int rhs) -> BasicType {
  const int divisor = rhs.value();
  if (divisor != 0) {
    return Int(lhs.value() / divisor);
  }
  return Error(Error::Type::DivByZero, format(lhs.value(), '/', divisor));
}

/// Double + Double.
[[nodiscard]] auto operator+(Double lhs, Double rhs) -> BasicType {
  const double sum = lhs.value() + rhs.value();
  return Double(sum);
}
/// Double - Double.
[[nodiscard]] auto operator-(Double lhs, Double rhs) -> BasicType {
  const double difference = lhs.value() - rhs.value();
  return Double(difference);
}
/// Double * Double.
[[nodiscard]] auto operator*(Double lhs, Double rhs) -> BasicType {
  const double product = lhs.value() * rhs.value();
  return Double(product);
}
/// Double / Double (no zero check).
[[nodiscard]] auto operator/(Double lhs, Double rhs) -> BasicType {
  const double quotient = lhs.value() / rhs.value();
  return Double(quotient);
}

// Error (op) Error. These are plain functions on purpose: they used to be
// declared `template <typename T>` with a T that appears in no parameter, so
// T could never be deduced and `error + error` never selected them. The call
// fell back to the BasicType operators instead, which recursed without end.

/// Error + Error: reports that both operands were errors.
[[nodiscard]] auto operator+(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}
/// Error - Error: reports that both operands were errors.
[[nodiscard]] auto operator-(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}
/// Error * Error: reports that both operands were errors.
[[nodiscard]] auto operator*(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}
/// Error / Error: reports that both operands were errors.
[[nodiscard]] auto operator/(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator+(Error lhs, const T & /*rhs*/) -> BasicType {
  return lhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator-(Error lhs, const T & /*rhs*/) -> BasicType {
  return lhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator*(Error lhs, const T & /*rhs*/) -> BasicType {
  return lhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator/(Error lhs, const T & /*rhs*/) -> BasicType {
  return lhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator+(const T & /*lhs*/, Error rhs) -> BasicType {
  return rhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator-(const T & /*lhs*/, Error rhs) -> BasicType {
  return rhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator*(const T & /*lhs*/, Error rhs) -> BasicType {
  return rhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator/(const T & /*lhs*/, Error rhs) -> BasicType {
  return rhs;
}

namespace detail {
/// Visits both operands and applies `op` only to combinations that have a
/// dedicated overload: identical alternatives, or exactly one Error operand.
///
/// Without this guard, Int (op) Double and Error (op) Error have no dedicated
/// overload that operator syntax can select. `l + r` then converts both
/// operands back to BasicType and re-enters the BasicType operator, recursing
/// until the stack overflows.
///
/// @param op generic callable invoked as `op(l, r)` on the active alternatives.
/// @return the result of `op`, or an `Error(NoIdea, ...)` describing why the
///         operands could not be combined.
template <typename Op>
[[nodiscard]] auto applyBinary(const BasicType &lhs, const BasicType &rhs, Op op) -> BasicType {
  return std::visit(
      [&op](const auto &l, const auto &r) -> BasicType {
        using L = std::decay_t<decltype(l)>;
        using R = std::decay_t<decltype(r)>;
        constexpr bool lhsIsError = std::is_same_v<L, Error>;
        constexpr bool rhsIsError = std::is_same_v<R, Error>;
        if constexpr (lhsIsError && rhsIsError) {
          return Error(Error::NoIdea, "two errors?");
        } else if constexpr (lhsIsError || rhsIsError || std::is_same_v<L, R>) {
          return op(l, r);
        } else {
          return Error(Error::NoIdea, "mismatched operand types");
        }
      },
      lhs, rhs);
}
} // namespace detail

/// BasicType + BasicType, dispatched on the active alternatives.
[[nodiscard]] auto operator+(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::applyBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l + r; });
}

/// BasicType - BasicType, dispatched on the active alternatives.
[[nodiscard]] auto operator-(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::applyBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l - r; });
}

/// BasicType * BasicType, dispatched on the active alternatives.
[[nodiscard]] auto operator*(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::applyBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l * r; });
}

/// BasicType / BasicType, dispatched on the active alternatives.
[[nodiscard]] auto operator/(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::applyBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l / r; });
}

// Owning pointer used for every heap-allocated node in this file.
template <typename T> using Ptr = std::unique_ptr<T>;

// Allocates a T on the heap, perfect-forwarding the given arguments to T's
// constructor, and returns the sole owner of the new object.
template <typename T, typename... CtorArgs>
[[nodiscard]] auto MkPtr(CtorArgs &&...ctorArgs) -> Ptr<T> {
  return std::make_unique<T>(std::forward<CtorArgs>(ctorArgs)...);
}

// Abstract base class of every node in the expression tree.
// A node can be evaluated to a BasicType (`result`) and rendered as text
// (`toString`). Nodes are destroyed through the base class, hence the virtual
// destructor.
struct [[nodiscard]] Expr {
  Expr() = default;
  Expr(const Expr &) = delete;
  Expr(Expr &&) = delete;
  // NOTE(review): copy/move construction is deleted while both assignments are
  // defaulted. Confirm this asymmetry is intended before relying on it.
  auto operator=(Expr &&) &noexcept -> Expr & = default;
  auto operator=(const Expr &) & -> Expr & = default;
  virtual ~Expr() = default;

  // Optional per-value transformation supplied by the caller of `result`.
  using TransformFn = std::function<BasicType(const BasicType &)>;

  // Evaluates this node and returns its value.
  // `tf` defaults to an empty function; how a non-empty `tf` is applied is up
  // to each implementation. The default argument is only visible when calling
  // through the Expr interface.
  // Marked [[nodiscard]] to match `toString` and the BasicType operators:
  // evaluating a node and dropping the value is almost certainly a mistake.
  [[nodiscard]] virtual auto result(TransformFn &&tf = {}) const -> BasicType = 0;

  // Returns a textual rendering of this node.
  [[nodiscard]] virtual auto toString() const -> std::string = 0;
};

// Leaf node holding a single, immutable BasicType value.
struct [[nodiscard]] Prim : Expr {
  // Takes ownership of `v` as the value this leaf evaluates to.
  explicit Prim(BasicType v) : value_(std::move(v)) {}

  // Returns the stored value, passed through `fn` when `fn` is non-empty.
  // An empty `fn` yields a copy of the stored value as-is.
  auto result(TransformFn &&fn) const -> BasicType override {
    if (!fn) {
      return value_;
    }
    return fn(value_);
  }

  // Renders the leaf as "(Prim <value>)" via the file's `format` helper.
  [[nodiscard]] auto toString() const -> std::string override {
    return format("(Prim ", value_, ")");
  }

private:
  const BasicType value_; // never modified after construction
};

// Interior node applying one binary arithmetic operator to two owned
// sub-expressions.
struct [[nodiscard]] ExprBinOp : Expr {
  // Supported operators. The declaration order is significant: the underlying
  // values index the symbol table in opToString().
  enum class Op : std::int8_t { Mul, Div, Add, Sub };

  // Takes ownership of both operands. Both are dereferenced unconditionally by
  // result() and toString(), so they are expected to be non-null.
  ExprBinOp(Op op, Ptr<const Expr> lhs, Ptr<const Expr> rhs)
      : op_(op), lhs_(std::move(lhs)), rhs_(std::move(rhs)) {}

  // Evaluates the left operand, then the right operand, handing the same
  // transform `fn` to each, and combines the two values with this node's
  // operator.
  auto result(TransformFn &&fn) const -> BasicType override {
    // The operands are evaluated in separate statements so the order is always
    // left-then-right. As arguments of a single call their relative order is
    // unspecified, which makes a stateful `fn` observe leaves in a
    // compiler-dependent order.
    //
    // `Expr::result` only accepts an rvalue reference, so the callable is
    // re-bound with an explicit cast for each operand. This is deliberately
    // NOT a move: the same function object must still be intact for the right
    // operand, so implementations of `result` must not consume `fn`.
    const BasicType left = lhs_->result(static_cast<TransformFn &&>(fn));
    const BasicType right = rhs_->result(static_cast<TransformFn &&>(fn));
    return act(left, right);
  }

  // Renders the node in prefix form: "(<op> <lhs> <rhs>)".
  [[nodiscard]] auto toString() const -> std::string override {
    return format('(', opToString(), ' ', lhs_->toString(), ' ', rhs_->toString(), ')');
  }

private:
  // Applies this node's operator to two already-evaluated operands.
  [[nodiscard]] auto act(const BasicType &l, const BasicType &r) const -> BasicType {
    switch (op_) {
    case Op::Mul:
      return l * r;
    case Op::Div:
      return l / r;
    case Op::Add:
      return l + r;
    case Op::Sub:
      return l - r;
    }
    return {}; // shouldn't reach here
  }

  // Returns the operator's symbol.
  [[nodiscard]] auto opToString() const -> std::string_view {
    // Entries must stay in the same order as the Op enumerators.
    static constexpr std::array os{"*"sv, "/"sv, "+"sv, "-"sv};
    return os[static_cast<std::size_t>(op_)];
  }

  const Op op_;
  const Ptr<const Expr> lhs_;
  const Ptr<const Expr> rhs_;
};

void happy() {
  const auto createAST = []() -> Ptr<const Expr> {
    Ptr<const Expr> one = MkPtr<const Prim>(Double(1.1));
    Ptr<const Expr> two = MkPtr<const Prim>(Double(2.2));
    Ptr<const Expr> thr = MkPtr<const Prim>(Double(3.3));

    Ptr<const Expr> mul =
        MkPtr<const ExprBinOp>(ExprBinOp::Op::Mul, std::move(two), std::move(thr));

    Ptr<const Expr> add =
        MkPtr<const ExprBinOp>(ExprBinOp::Op::Add, std::move(one), std::move(mul));
    return add;
  };

  const auto ast = createAST();

  std::cout << "#### 1 + 2 * 3\n";
  std::cout << ast->toString() << '\n';
  std::cout << ast->result() << '\n';

  std::cout << "\n\n-- to int type\n";
  std::cout << "(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = ";
  std::cout << ast->result([](const BasicType &v) -> BasicType {
    const auto d = std::get<Double>(v).value();
    return Int(std::floor(d * d));
  }) << '\n';
}

// Demo of the failure path.
// Builds the tree for 1 - 2 / 0 from Int leaves, then prints its textual form
// followed by whatever evaluating it yields.
// NOTE(review): the outcome of the zero division is decided by the Int
// operator/ defined elsewhere in the file -- presumably an Error that then
// propagates through the subtraction; confirm against that overload.
void divByZero() {
  const auto createAST = []() -> Ptr<const Expr> {
    auto one = MkPtr<const Prim>(Int(1));
    auto two = MkPtr<const Prim>(Int(2));
    auto zer = MkPtr<const Prim>(Int(0));

    auto quotient = MkPtr<const ExprBinOp>(ExprBinOp::Op::Div, std::move(two), std::move(zer));
    return MkPtr<const ExprBinOp>(ExprBinOp::Op::Sub, std::move(one), std::move(quotient));
  };

  const auto ast = createAST();

  std::cout << "#### 1 - 2 / 0\n" << ast->toString() << '\n';
  std::cout << ast->result() << '\n';
}

// Entry point: runs the success demo, prints three blank lines as a
// separator, then runs the division-by-zero demo.
auto main() -> int {
  happy();
  std::cout << "\n\n\n";
  divByZero();
  return 0;
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment