Skip to content

Instantly share code, notes, and snippets.

@smunix
Last active October 31, 2021 02:31
Show Gist options
  • Save smunix/15c4b6c5a4bb7e917e97b3085e5a2bc7 to your computer and use it in GitHub Desktop.
Save smunix/15c4b6c5a4bb7e917e97b3085e5a2bc7 to your computer and use it in GitHub Desktop.
Arithmetic Expressions
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
module Arithmetic.Expr where
import Data.Char (chr)
import GHC.Real (RealFrac (truncate))
{-
Expr :: Term
Term :: Int
| Factor + Factor
Factor :: Term
| Factor x Factor
-}
{-
data Term
= MkPrim Int
| MkAdd Factor Factor
deriving (Show, Eq)
-}
-- | Term of the first grammar sketch: an integer literal or the sum of two
-- factors.
data Term where
  -- | Integer literal.
  MkPrim :: Int -> Term
  -- | Sum of two factors.
  MkAdd :: Factor -> Factor -> Term
  deriving (Show, Eq)
-- | Same shape as 'Term', but the literal type is a parameter instead of
-- being fixed to 'Int'.
data TermG a where
  -- | Literal of the parameter type.
  MkPrimG :: a -> TermG a
  -- | Sum of two generic factors.
  MkAddG :: FactorG a -> FactorG a -> TermG a
  deriving (Show, Eq)
-- | Maps a function over every literal stored in the term; the tree shape
-- is left untouched.
instance Functor TermG where
  fmap f term = case term of
    MkPrimG a -> MkPrimG (f a)
    MkAddG l r -> MkAddG (f <$> l) (f <$> r)
{- d = MkPrimG 1.0
i = MkPrimG 2
s = MkPrimG "Mo"
-}
-- | First element of a list, or 'Nothing' when the list is empty.
fnL :: forall a. [] a -> Maybe a
fnL xs = case xs of
  [] -> Nothing
  x : _ -> Just x
-- | Evaluate a generic term: a literal is its own value, an addition is the
-- sum of both evaluated factors (see 'evalF').
fnG :: forall a. (Num a) => TermG a -> a
fnG term = case term of
  MkPrimG a -> a
  MkAddG l r -> evalF l + evalF r
-- | Evaluate a generic factor: a wrapped term is evaluated with 'fnG', a
-- multiplication is the product of both evaluated factors.
evalF :: forall a. Num a => FactorG a -> a
evalF factor = case factor of
  MkTermG t -> fnG t
  MkMulG l r -> evalF l * evalF r
{-
data Factor
= MkTerm Term
| MkMul Factor Factor
deriving (Show, Eq)
-}
-- | Factor of the first grammar sketch: a wrapped term or the product of two
-- factors.
data Factor where
  -- | A term used as a factor.
  MkTerm :: Term -> Factor
  -- | Product of two factors.
  MkMul :: Factor -> Factor -> Factor
  deriving (Show, Eq)
-- | Same shape as 'Factor', parameterised over the literal type.
data FactorG a where
  -- | A generic term used as a factor.
  MkTermG :: TermG a -> FactorG a
  -- | Product of two generic factors.
  MkMulG :: FactorG a -> FactorG a -> FactorG a
  deriving (Show, Eq)
-- | Maps a function over every literal reachable from the factor; the tree
-- shape is left untouched.
instance Functor FactorG where
  fmap f factor = case factor of
    MkTermG t -> MkTermG (f <$> t)
    MkMulG l r -> MkMulG (f <$> l) (f <$> r)
-- | An expression is just a 'Term' in this sketch of the grammar.
type Expr = Term
-- | Generic counterpart of 'Expr': an expression over literals of type @a@.
type ExprG a = TermG a
{- e = 1 + 2 x 3 -}
{-
+
1 x
2 3
-}
-- | The sample expression @1 + 2 x 3@ built from the 'Int'-only types.
e :: Expr
e = MkAdd one twoTimesThree
  where
    one, twoTimesThree :: Factor
    one = MkTerm (MkPrim 1)
    twoTimesThree = MkMul (MkTerm (MkPrim 2)) (MkTerm (MkPrim 3))
-- $> import Arithmetic.Expr
-- $> e
-- | The sample expression @1 + 2 x 3@ built from the generic types, with the
-- literal type pinned to 'Float' by the equality constraint.
eg :: forall a. (a ~ Float) => ExprG a
eg = MkAddG one twoTimesThree
  where
    one, twoTimesThree :: FactorG a
    one = MkTermG (MkPrimG 1)
    twoTimesThree = MkMulG (MkTermG (MkPrimG 2)) (MkTermG (MkPrimG 3))
-- $> import Arithmetic.Expr
-- $> (fnG eg, eg)
-- $> import Data.Char
-- $> (fmap (chr) . (fmap truncate) $ eg, eg)
-- $> (fmap (chr . truncate) $ eg, eg)
{-
BNF grammar representation (Backus–Naur Form)
Expr : Expr + Term
| Expr - Term
| Term
Term : Term * Factor
| Term / Factor
| Factor
Factor : Int
-}
-- | Factor of the BNF grammar: an integer literal.
data Factor where
  -- | Integer literal.
  Int :: Int -> Factor
  deriving (Show)
-- | Term of the BNF grammar: a term multiplied or divided by a factor, or a
-- single factor.
data Term where
  -- | @Term * Factor@
  (:*:) :: Term -> Factor -> Term
  -- | @Term / Factor@
  (:/:) :: Term -> Factor -> Term
  -- | A factor used as a term.
  Factor :: Factor -> Term
  deriving (Show)
-- | Expression of the BNF grammar: an expression plus or minus a term, or a
-- single term.
data Expr where
  -- | @Expr + Term@
  (:+:) :: Expr -> Term -> Expr
  -- | @Expr - Term@
  (:-:) :: Expr -> Term -> Expr
  -- | A term used as an expression.
  Term :: Term -> Expr
  deriving (Show)
-- 1 + 2 x 3
-- | The sample expression @1 + 2 x 3@ written with the BNF-shaped types.
v :: Expr
v = Term one :+: twoTimesThree
  where
    one = Factor (Int 1)
    twoTimesThree = Factor (Int 2) :*: Int 3
-- $> v
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
module Regex.Expr where
{-
'a'
'ab' <=> ReConcat (ReChar 'a') (ReChar 'b') <=> 'a' * 'b'
'abc' <=> ReConcat (ReChar 'a') (ReConcat (ReChar 'b') (ReChar 'c')) <=> 'a' * 'b' * 'c'
'a|b' <=> 'a' + 'b'
'a*'
Kleene Star
\epsilon
-}
-- | Regular expression over symbols of type @a@.
data RegExpr a where
  -- | A single symbol.
  ReChar :: a -> RegExpr a
  -- | Epsilon (see the notes above).
  ReEpsilon :: RegExpr a
  -- | Concatenation, e.g. @ab@.
  ReConcat :: RegExpr a -> RegExpr a -> RegExpr a
  -- | Choice, e.g. @a|b@.
  ReChoice :: RegExpr a -> RegExpr a -> RegExpr a
  -- | Kleene star, e.g. @a*@.
  ReStar :: RegExpr a -> RegExpr a
-- | Expression tree with three constructors over names of type @a@.
-- NOTE(review): the shape looks like a lambda-calculus term (variable,
-- function, application) -- confirm with the author.
data E a where
  Var :: a -> E a
  Fun :: a -> E a -> E a
  App :: E a -> E a -> E a
-- | Type-level counterpart of 'E', with the same three-constructor shape.
-- NOTE(review): presumably primitive type, type abstraction and type
-- application -- confirm with the author.
data Ty a where
  PrimTy :: a -> Ty a
  TempTy :: a -> Ty a -> Ty a
  AppTy :: Ty a -> Ty a -> Ty a
@smunix
Copy link
Author

smunix commented Oct 30, 2021

Cool stuff!

FYI, the BNF is a little too verbose, although it serves its purpose well (describing unambiguous grammars). The Expr data type can be simplified as below:

data Exp
  = Exp :+: Exp
   | Exp :-: Exp
   | Exp :*: Exp
   | Exp :/: Exp
   | Int i

Hopefully this simplifies your C++ implementation as well.

@mo-xiaoming
Copy link

mo-xiaoming commented Oct 31, 2021

UPDATE: I was wrong. I didn't consider that if you used Expr for everything, the Haskell implementation would have been fewer than 10 lines, so my C++ code is now 30x longer than yours.

#### 1 + 2 * 3
(+ (Prim 1.1) (* (Prim 2.2) (Prim 3.3)))
8.36


-- to int type
(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = 41



#### 1 - 2 / 0
(- (Prim 1) (/ (Prim 2) (Prim 0)))
Error: DivByZero: 2/0

Using Expr to replace Term and Factor indeed makes the code a little shorter. I also removed the String type, because it doesn't make much sense at the moment.

The current code is slightly less than 20x the length of your code. Sigh, C++ is too verbose.

#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <variant>

using namespace std::literals;

/// Streams every argument, in order, into one string.
///
/// @param args values that can be written to a std::ostream with `<<`
/// @return the concatenated text; empty when called without arguments
template <typename... Args> [[nodiscard]] auto format(Args &&...args) -> std::string {
  std::ostringstream buffer;
  // Comma fold: each argument is written left to right.
  ((buffer << args), ...);
  return buffer.str();
}

/// Failure value that can be stored in a BasicType instead of a number.
struct [[nodiscard]] Error {
  /// Category of the failure. The numeric value indexes the label table in
  /// typeToString(), so the two must stay in the same order.
  enum Type : std::int8_t { NoIdea, DivByZero };

  explicit Error(Type t = NoIdea) : value_(t) {}
  explicit Error(Type t, std::string msg) : value_(t), msg_(std::move(msg)) {}

  /// Human-readable description: the category label, followed by ": " and
  /// the detail message when one was supplied.
  ///
  /// @return e.g. "Error: DivByZero: 2/0"; just the label when the message
  ///         is empty (no dangling ": ").
  /// @note Declared noexcept although it allocates: an allocation failure
  ///       terminates the program.
  [[nodiscard]] auto value() const noexcept -> std::string {
    std::string text{typeToString()};
    if (!msg_.empty()) {
      text += ": ";
      text += msg_;
    }
    return text;
  }

private:
  /// Label for the stored category. Values outside the table (e.g. produced
  /// by a cast) fall back to the NoIdea label instead of reading out of bounds.
  [[nodiscard]] auto typeToString() const -> std::string_view {
    static constexpr std::array<std::string_view, 2> ts{
        std::string_view{"Error: I have no idea what's going on here"},
        std::string_view{"Error: DivByZero"}};
    const auto idx = static_cast<std::size_t>(value_);
    return idx < ts.size() ? ts[idx] : ts[0];
  }
  const Type value_ = NoIdea;
  const std::string msg_;
};

/// Immutable wrapper around an `int`, one of the alternatives of BasicType.
struct [[nodiscard]] Int {
  /// @param i wrapped value; defaults to 0
  explicit Int(int i = 0) : value_{i} {}

  /// @return the wrapped integer
  [[nodiscard]] auto value() const noexcept -> int {
    return value_;
  }

private:
  const int value_{};
};

/// Immutable wrapper around a `double`, one of the alternatives of BasicType.
struct [[nodiscard]] Double {
  /// @param d wrapped value; defaults to 0
  explicit Double(double d = 0) : value_{d} {}

  /// @return the wrapped floating-point value
  [[nodiscard]] auto value() const noexcept -> double {
    return value_;
  }

private:
  const double value_{};
};

using BasicType = std::variant<Error, Int, Double>;

/// Writes whichever alternative `v` currently holds by streaming the result
/// of that alternative's `value()` member.
///
/// @return `os`, to allow chaining
auto operator<<(std::ostream &os, const BasicType &v) -> std::ostream & {
  std::visit([&os](const auto &alternative) { os << alternative.value(); }, v);
  return os;
}

// cannot make them hidden friends, because BasicType is unknown inside these classes
/// Int + Int, wrapped as a BasicType.
[[nodiscard]] auto operator+(Int lhs, Int rhs) -> BasicType {
  const int sum = lhs.value() + rhs.value();
  return Int(sum);
}
/// Int - Int, wrapped as a BasicType.
[[nodiscard]] auto operator-(Int lhs, Int rhs) -> BasicType {
  const int difference = lhs.value() - rhs.value();
  return Int(difference);
}
/// Int * Int, wrapped as a BasicType.
[[nodiscard]] auto operator*(Int lhs, Int rhs) -> BasicType {
  const int product = lhs.value() * rhs.value();
  return Int(product);
}
/// Integer division. A zero divisor yields a DivByZero Error whose message
/// is "<lhs>/<rhs>" instead of performing the division.
[[nodiscard]] auto operator/(Int lhs, Int rhs) -> BasicType {
  const int divisor = rhs.value();
  if (divisor != 0) {
    return Int(lhs.value() / divisor);
  }
  return Error(Error::Type::DivByZero, format(lhs.value(), '/', divisor));
}

/// Double + Double, wrapped as a BasicType.
[[nodiscard]] auto operator+(Double lhs, Double rhs) -> BasicType {
  const double sum = lhs.value() + rhs.value();
  return Double(sum);
}
/// Double - Double, wrapped as a BasicType.
[[nodiscard]] auto operator-(Double lhs, Double rhs) -> BasicType {
  const double difference = lhs.value() - rhs.value();
  return Double(difference);
}
/// Double * Double, wrapped as a BasicType.
[[nodiscard]] auto operator*(Double lhs, Double rhs) -> BasicType {
  const double product = lhs.value() * rhs.value();
  return Double(product);
}
/// Double / Double, wrapped as a BasicType. No zero check is made here,
/// unlike the Int overload.
[[nodiscard]] auto operator/(Double lhs, Double rhs) -> BasicType {
  const double quotient = lhs.value() / rhs.value();
  return Double(quotient);
}

template <typename T>
[[nodiscard]] auto operator+(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}
template <typename T>
[[nodiscard]] auto operator-(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}
template <typename T>
[[nodiscard]] auto operator*(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}
template <typename T>
[[nodiscard]] auto operator/(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator+(Error lhs, const T & /*rhs*/) -> BasicType {
  return lhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator-(Error lhs, const T & /*rhs*/) -> BasicType {
  return lhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator*(Error lhs, const T & /*rhs*/) -> BasicType {
  return lhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator/(Error lhs, const T & /*rhs*/) -> BasicType {
  return lhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator+(const T & /*lhs*/, Error rhs) -> BasicType {
  return rhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator-(const T & /*lhs*/, Error rhs) -> BasicType {
  return rhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator*(const T & /*lhs*/, Error rhs) -> BasicType {
  return rhs;
}
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator/(const T & /*lhs*/, Error rhs) -> BasicType {
  return rhs;
}

namespace detail {

/// Dispatches a binary operation over the alternatives held by two BasicType
/// values.
///
/// Not every pair of alternatives has a directly matching operator. When none
/// matches, `l op r` would implicitly convert both operands back to BasicType
/// and call the visiting operator again, recursing until the stack overflows.
/// Those pairs are therefore handled here:
///   - Error with Error: a NoIdea error with the message "two errors?";
///   - mixed Int/Double: both operands are promoted to Double first.
/// Every other pair is forwarded unchanged to `op`.
///
/// @param op callable invoked as `op(l, r)` with two alternatives; must
///           return something convertible to BasicType
template <typename Op>
[[nodiscard]] auto visitBinary(const BasicType &lhs, const BasicType &rhs, Op op) -> BasicType {
  return std::visit(
      [&op](const auto &l, const auto &r) -> BasicType {
        using L = std::decay_t<decltype(l)>;
        using R = std::decay_t<decltype(r)>;
        constexpr bool lhsIsError = std::is_same_v<L, Error>;
        constexpr bool rhsIsError = std::is_same_v<R, Error>;
        if constexpr (lhsIsError && rhsIsError) {
          return Error(Error::NoIdea, "two errors?");
        } else if constexpr (lhsIsError || rhsIsError || std::is_same_v<L, R>) {
          return op(l, r);
        } else {
          // Only Int and Double remain here, in either order.
          return op(Double(static_cast<double>(l.value())),
                    Double(static_cast<double>(r.value())));
        }
      },
      lhs, rhs);
}

} // namespace detail

/// Adds the values held by two BasicTypes; see detail::visitBinary for how
/// errors and mixed numeric types are treated.
[[nodiscard]] auto operator+(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::visitBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l + r; });
}

/// Subtracts the values held by two BasicTypes; see detail::visitBinary.
[[nodiscard]] auto operator-(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::visitBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l - r; });
}

/// Multiplies the values held by two BasicTypes; see detail::visitBinary.
[[nodiscard]] auto operator*(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::visitBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l * r; });
}

/// Divides the values held by two BasicTypes; see detail::visitBinary.
[[nodiscard]] auto operator/(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::visitBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l / r; });
}

/// Owning pointer used for every AST node.
template <typename T> using Ptr = std::unique_ptr<T>;

/// Allocates a `T` constructed from `args` and returns its owner.
template <typename T, typename... Args> [[nodiscard]] auto MkPtr(Args &&...args) -> Ptr<T> {
  Ptr<T> owner = std::make_unique<T>(std::forward<Args>(args)...);
  return owner;
}

/// Abstract node of the expression tree.
struct [[nodiscard]] Expr {
  /// Optional transformation handed to result(). In this file Prim applies it
  /// to its stored value and ExprBinOp passes it on to both operands.
  using TransformFn = std::function<BasicType(const BasicType &)>;

  Expr() = default;
  virtual ~Expr() = default;

  // Nodes cannot be copy- or move-constructed; assignment stays defaulted.
  Expr(const Expr &) = delete;
  Expr(Expr &&) = delete;
  auto operator=(const Expr &) & -> Expr & = default;
  auto operator=(Expr &&) &noexcept -> Expr & = default;

  /// Evaluates the node.
  /// @param tf transformation to use; the default is an empty function
  virtual auto result(TransformFn &&tf = {}) const -> BasicType = 0;

  /// Textual rendering of the node.
  [[nodiscard]] virtual auto toString() const -> std::string = 0;
};

/// Leaf node holding a single BasicType value.
struct [[nodiscard]] Prim : Expr {
  explicit Prim(BasicType v) : value_(std::move(v)) {}

  /// @return the stored value, passed through `fn` when `fn` is non-empty
  auto result(TransformFn &&fn) const -> BasicType override {
    if (!fn) {
      return value_;
    }
    return fn(value_);
  }

  /// @return "(Prim <value>)"
  [[nodiscard]] auto toString() const -> std::string override {
    return format("(Prim ", value_, ")");
  }

private:
  const BasicType value_;
};

/// Inner node applying one binary arithmetic operator to two sub-expressions.
struct [[nodiscard]] ExprBinOp : Expr {
  /// Supported operators. The numeric value indexes the symbol table in
  /// opToString(), so the two must stay in the same order.
  enum class Op : std::int8_t { Mul, Div, Add, Sub };

  /// Takes ownership of both operands.
  ExprBinOp(Op op, Ptr<const Expr> lhs, Ptr<const Expr> rhs)
      : op_(op), lhs_(std::move(lhs)), rhs_(std::move(rhs)) {}

  /// Evaluates the left operand, then the right operand, then combines them.
  ///
  /// The operands are evaluated in separate statements so the order is always
  /// left to right; as function arguments their order would be unspecified,
  /// which is observable when `fn` has side effects.
  ///
  /// @param fn transformation passed on to both operands
  auto result(TransformFn &&fn) const -> BasicType override {
    // The same function object is handed to both operands on purpose: the
    // result() overrides in this file only call it and never move from it.
    const BasicType left = lhs_->result(std::move(fn));
    const BasicType right = rhs_->result(std::move(fn));
    return act(left, right);
  }

  /// @return "(<op> <lhs> <rhs>)", prefix notation
  [[nodiscard]] auto toString() const -> std::string override {
    return format('(', opToString(), ' ', lhs_->toString(), ' ', rhs_->toString(), ')');
  }

private:
  /// Applies the stored operator to two already evaluated operands.
  [[nodiscard]] auto act(const BasicType &l, const BasicType &r) const -> BasicType {
    switch (op_) {
    case Op::Mul:
      return l * r;
    case Op::Div:
      return l / r;
    case Op::Add:
      return l + r;
    case Op::Sub:
      return l - r;
    }
    // Only reachable when op_ holds a value outside the enumerators.
    return Error(Error::NoIdea, "unknown binary operator");
  }

  /// Symbol used by toString() for the stored operator.
  [[nodiscard]] auto opToString() const -> std::string_view {
    static constexpr std::array os{"*"sv, "/"sv, "+"sv, "-"sv};
    return os[static_cast<std::size_t>(op_)];
  }

  const Op op_;
  const Ptr<const Expr> lhs_;
  const Ptr<const Expr> rhs_;
};

void happy() {
  const auto createAST = []() -> Ptr<const Expr> {
    Ptr<const Expr> one = MkPtr<const Prim>(Double(1.1));
    Ptr<const Expr> two = MkPtr<const Prim>(Double(2.2));
    Ptr<const Expr> thr = MkPtr<const Prim>(Double(3.3));

    Ptr<const Expr> mul =
        MkPtr<const ExprBinOp>(ExprBinOp::Op::Mul, std::move(two), std::move(thr));

    Ptr<const Expr> add =
        MkPtr<const ExprBinOp>(ExprBinOp::Op::Add, std::move(one), std::move(mul));
    return add;
  };

  const auto ast = createAST();

  std::cout << "#### 1 + 2 * 3\n";
  std::cout << ast->toString() << '\n';
  std::cout << ast->result() << '\n';

  std::cout << "\n\n-- to int type\n";
  std::cout << "(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = ";
  std::cout << ast->result([](const BasicType &v) -> BasicType {
    const auto d = std::get<Double>(v).value();
    return Int(std::floor(d * d));
  }) << '\n';
}

/// Demo: builds 1 - 2 / 0 from Int leaves and prints the tree followed by
/// its evaluation result.
void divByZero() {
  const auto createAST = []() -> Ptr<const Expr> {
    auto quotient = MkPtr<const ExprBinOp>(ExprBinOp::Op::Div, MkPtr<const Prim>(Int(2)),
                                           MkPtr<const Prim>(Int(0)));
    return MkPtr<const ExprBinOp>(ExprBinOp::Op::Sub, MkPtr<const Prim>(Int(1)),
                                  std::move(quotient));
  };

  const auto ast = createAST();

  std::cout << "#### 1 - 2 / 0\n" << ast->toString() << '\n';
  std::cout << ast->result() << '\n';
}

/// Runs both demos, separated by three blank lines.
auto main() -> int {
  happy();
  std::cout << "\n\n\n";
  divByZero();
  return 0;
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment